mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Merge 5cbd8b7324
into 07930b854a
This commit is contained in:
commit
f7d2f8de1f
@ -124,20 +124,48 @@ def _get_meta_template(meta_template):
|
||||
return APITemplateParser(meta_template or default_meta_template)
|
||||
|
||||
|
||||
def _set_model_kwargs_torch_dtype(model_kwargs):
|
||||
def _set_model_kwargs_torch_dtype(model_kwargs, path=None):
|
||||
import torch
|
||||
if 'torch_dtype' not in model_kwargs:
|
||||
torch_dtype = torch.float16
|
||||
from transformers import AutoConfig
|
||||
|
||||
# If torch_dtype already exists and is not a string, return directly
|
||||
if 'torch_dtype' in model_kwargs and not isinstance(model_kwargs['torch_dtype'], str):
|
||||
return model_kwargs
|
||||
|
||||
# Mapping from string to torch data types
|
||||
dtype_map = {
|
||||
'torch.float16': torch.float16, 'float16': torch.float16,
|
||||
'torch.bfloat16': torch.bfloat16, 'bfloat16': torch.bfloat16,
|
||||
'torch.float': torch.float, 'float': torch.float,
|
||||
'torch.float32': torch.float32, 'float32': torch.float32,
|
||||
'auto': 'auto', 'None': None
|
||||
}
|
||||
|
||||
# 1. Priority: Use torch_dtype from model_kwargs if available
|
||||
if 'torch_dtype' in model_kwargs:
|
||||
torch_dtype = dtype_map.get(model_kwargs['torch_dtype'], torch.float16)
|
||||
|
||||
# 2. Secondary: Try to read from model config
|
||||
elif path is not None:
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(path)
|
||||
if hasattr(config, 'torch_dtype'):
|
||||
config_dtype = config.torch_dtype
|
||||
if isinstance(config_dtype, str):
|
||||
torch_dtype = dtype_map.get(config_dtype, torch.float16)
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
torch_dtype = torch.float16
|
||||
except Exception:
|
||||
torch_dtype = torch.float16
|
||||
|
||||
# 3. Default: Use float16 as fallback
|
||||
else:
|
||||
torch_dtype = {
|
||||
'torch.float16': torch.float16,
|
||||
'torch.bfloat16': torch.bfloat16,
|
||||
'torch.float': torch.float,
|
||||
'auto': 'auto',
|
||||
'None': None,
|
||||
}.get(model_kwargs['torch_dtype'])
|
||||
if torch_dtype is not None:
|
||||
model_kwargs['torch_dtype'] = torch_dtype
|
||||
torch_dtype = torch.float16
|
||||
|
||||
# Update model_kwargs with the resolved torch_dtype
|
||||
model_kwargs['torch_dtype'] = torch_dtype
|
||||
return model_kwargs
|
||||
|
||||
|
||||
@ -218,12 +246,12 @@ class HuggingFacewithChatTemplate(BaseModel):
|
||||
raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
|
||||
|
||||
def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
|
||||
from transformers import AutoModel, AutoModelForCausalLM
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
|
||||
DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
|
||||
model_kwargs = DEFAULT_MODEL_KWARGS
|
||||
model_kwargs.update(kwargs)
|
||||
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
|
||||
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs, path)
|
||||
self.logger.debug(f'using model_kwargs: {model_kwargs}')
|
||||
if is_npu_available():
|
||||
model_kwargs['device_map'] = 'npu'
|
||||
|
Loading…
Reference in New Issue
Block a user