commit f7d2f8de1f
Author: liushz
Date:   2025-03-25 02:57:46 +00:00 (committed by GitHub)

@@ -124,20 +124,48 @@ def _get_meta_template(meta_template):
     return APITemplateParser(meta_template or default_meta_template)
-def _set_model_kwargs_torch_dtype(model_kwargs):
+def _set_model_kwargs_torch_dtype(model_kwargs, path=None):
     import torch
-    if 'torch_dtype' not in model_kwargs:
-        torch_dtype = torch.float16
+    from transformers import AutoConfig
+    # If torch_dtype already exists and is not a string, return directly
+    if 'torch_dtype' in model_kwargs and not isinstance(model_kwargs['torch_dtype'], str):
+        return model_kwargs
+    # Mapping from string to torch data types
+    dtype_map = {
+        'torch.float16': torch.float16, 'float16': torch.float16,
+        'torch.bfloat16': torch.bfloat16, 'bfloat16': torch.bfloat16,
+        'torch.float': torch.float, 'float': torch.float,
+        'torch.float32': torch.float32, 'float32': torch.float32,
+        'auto': 'auto', 'None': None
+    }
+    # 1. Priority: Use torch_dtype from model_kwargs if available
+    if 'torch_dtype' in model_kwargs:
+        torch_dtype = dtype_map.get(model_kwargs['torch_dtype'], torch.float16)
+    # 2. Secondary: Try to read from model config
+    elif path is not None:
+        try:
+            config = AutoConfig.from_pretrained(path)
+            if hasattr(config, 'torch_dtype'):
+                config_dtype = config.torch_dtype
+                if isinstance(config_dtype, str):
+                    torch_dtype = dtype_map.get(config_dtype, torch.float16)
+                else:
+                    torch_dtype = config_dtype
+            else:
+                torch_dtype = torch.float16
+        except Exception:
+            torch_dtype = torch.float16
+    # 3. Default: Use float16 as fallback
     else:
-        torch_dtype = {
-            'torch.float16': torch.float16,
-            'torch.bfloat16': torch.bfloat16,
-            'torch.float': torch.float,
-            'auto': 'auto',
-            'None': None,
-        }.get(model_kwargs['torch_dtype'])
-    if torch_dtype is not None:
-        model_kwargs['torch_dtype'] = torch_dtype
+        torch_dtype = torch.float16
+    # Update model_kwargs with the resolved torch_dtype
+    model_kwargs['torch_dtype'] = torch_dtype
     return model_kwargs
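
Outside the diff, the new helper resolves torch_dtype in three steps: an explicit entry in model_kwargs wins, then the dtype recorded in the checkpoint's config, then a float16 fallback. The sketch below is a minimal illustration of that order; the import path and the model id are assumptions for illustration, not part of this commit.

# Minimal sketch of the new resolution order. The import path and the repo id
# below are assumed; the diff does not name the module that defines the helper.
import torch
from opencompass.models.huggingface_above_v4_33 import _set_model_kwargs_torch_dtype  # assumed path

# 1. An explicit string in model_kwargs is mapped through dtype_map.
assert _set_model_kwargs_torch_dtype({'torch_dtype': 'bfloat16'})['torch_dtype'] is torch.bfloat16

# 2. Without torch_dtype, the dtype stored in the checkpoint's config.json is used.
kwargs = _set_model_kwargs_torch_dtype({}, path='meta-llama/Llama-2-7b-hf')  # placeholder repo id

# 3. Without torch_dtype and without a path, the old float16 default applies.
assert _set_model_kwargs_torch_dtype({})['torch_dtype'] is torch.float16
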
@@ -218,12 +246,12 @@ class HuggingFacewithChatTemplate(BaseModel):
             raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
     def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
-        from transformers import AutoModel, AutoModelForCausalLM
+        from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
         DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
         model_kwargs = DEFAULT_MODEL_KWARGS
         model_kwargs.update(kwargs)
-        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
+        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs, path)
         self.logger.debug(f'using model_kwargs: {model_kwargs}')
         if is_npu_available():
             model_kwargs['device_map'] = 'npu'
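
In practice, because _load_model now forwards path, a model config no longer has to pin a torch_dtype in model_kwargs for checkpoints whose config.json already records one. A hedged config sketch; the abbr and path values are placeholders, not from this commit.

# Sketch of a HuggingFacewithChatTemplate entry relying on the new behaviour.
# abbr and path are placeholders; only the omission of torch_dtype is the point.
from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='example-chat-model-hf',          # placeholder
        path='org/example-bf16-checkpoint',    # placeholder HF repo id
        max_out_len=1024,
        batch_size=8,
        run_cfg=dict(num_gpus=1),
        # No torch_dtype in model_kwargs: it is now resolved from the
        # checkpoint's config via _set_model_kwargs_torch_dtype(model_kwargs, path).
    )
]
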