mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Merge 5cbd8b7324
into 07930b854a
This commit is contained in:
commit
f7d2f8de1f
@ -124,20 +124,48 @@ def _get_meta_template(meta_template):
|
|||||||
return APITemplateParser(meta_template or default_meta_template)
|
return APITemplateParser(meta_template or default_meta_template)
|
||||||
|
|
||||||
|
|
||||||
def _set_model_kwargs_torch_dtype(model_kwargs):
|
def _set_model_kwargs_torch_dtype(model_kwargs, path=None):
|
||||||
import torch
|
import torch
|
||||||
if 'torch_dtype' not in model_kwargs:
|
from transformers import AutoConfig
|
||||||
torch_dtype = torch.float16
|
|
||||||
|
# If torch_dtype already exists and is not a string, return directly
|
||||||
|
if 'torch_dtype' in model_kwargs and not isinstance(model_kwargs['torch_dtype'], str):
|
||||||
|
return model_kwargs
|
||||||
|
|
||||||
|
# Mapping from string to torch data types
|
||||||
|
dtype_map = {
|
||||||
|
'torch.float16': torch.float16, 'float16': torch.float16,
|
||||||
|
'torch.bfloat16': torch.bfloat16, 'bfloat16': torch.bfloat16,
|
||||||
|
'torch.float': torch.float, 'float': torch.float,
|
||||||
|
'torch.float32': torch.float32, 'float32': torch.float32,
|
||||||
|
'auto': 'auto', 'None': None
|
||||||
|
}
|
||||||
|
|
||||||
|
# 1. Priority: Use torch_dtype from model_kwargs if available
|
||||||
|
if 'torch_dtype' in model_kwargs:
|
||||||
|
torch_dtype = dtype_map.get(model_kwargs['torch_dtype'], torch.float16)
|
||||||
|
|
||||||
|
# 2. Secondary: Try to read from model config
|
||||||
|
elif path is not None:
|
||||||
|
try:
|
||||||
|
config = AutoConfig.from_pretrained(path)
|
||||||
|
if hasattr(config, 'torch_dtype'):
|
||||||
|
config_dtype = config.torch_dtype
|
||||||
|
if isinstance(config_dtype, str):
|
||||||
|
torch_dtype = dtype_map.get(config_dtype, torch.float16)
|
||||||
|
else:
|
||||||
|
torch_dtype = config_dtype
|
||||||
|
else:
|
||||||
|
torch_dtype = torch.float16
|
||||||
|
except Exception:
|
||||||
|
torch_dtype = torch.float16
|
||||||
|
|
||||||
|
# 3. Default: Use float16 as fallback
|
||||||
else:
|
else:
|
||||||
torch_dtype = {
|
torch_dtype = torch.float16
|
||||||
'torch.float16': torch.float16,
|
|
||||||
'torch.bfloat16': torch.bfloat16,
|
# Update model_kwargs with the resolved torch_dtype
|
||||||
'torch.float': torch.float,
|
model_kwargs['torch_dtype'] = torch_dtype
|
||||||
'auto': 'auto',
|
|
||||||
'None': None,
|
|
||||||
}.get(model_kwargs['torch_dtype'])
|
|
||||||
if torch_dtype is not None:
|
|
||||||
model_kwargs['torch_dtype'] = torch_dtype
|
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
|
|
||||||
@ -218,12 +246,12 @@ class HuggingFacewithChatTemplate(BaseModel):
|
|||||||
raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
|
raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
|
||||||
|
|
||||||
def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
|
def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
|
||||||
from transformers import AutoModel, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||||
|
|
||||||
DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
|
DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
|
||||||
model_kwargs = DEFAULT_MODEL_KWARGS
|
model_kwargs = DEFAULT_MODEL_KWARGS
|
||||||
model_kwargs.update(kwargs)
|
model_kwargs.update(kwargs)
|
||||||
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
|
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs, path)
|
||||||
self.logger.debug(f'using model_kwargs: {model_kwargs}')
|
self.logger.debug(f'using model_kwargs: {model_kwargs}')
|
||||||
if is_npu_available():
|
if is_npu_available():
|
||||||
model_kwargs['device_map'] = 'npu'
|
model_kwargs['device_map'] = 'npu'
|
||||||
|
Loading…
Reference in New Issue
Block a user