diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py
index 5cd38b4a..8e9bfe94 100644
--- a/opencompass/models/huggingface_above_v4_33.py
+++ b/opencompass/models/huggingface_above_v4_33.py
@@ -124,20 +124,48 @@ def _get_meta_template(meta_template):
     return APITemplateParser(meta_template or default_meta_template)
 
 
-def _set_model_kwargs_torch_dtype(model_kwargs):
+def _set_model_kwargs_torch_dtype(model_kwargs, path=None):
     import torch
-    if 'torch_dtype' not in model_kwargs:
-        torch_dtype = torch.float16
+    from transformers import AutoConfig
+
+    # If torch_dtype already exists and is not a string, return directly
+    if 'torch_dtype' in model_kwargs and not isinstance(model_kwargs['torch_dtype'], str):
+        return model_kwargs
+
+    # Mapping from string to torch data types
+    dtype_map = {
+        'torch.float16': torch.float16, 'float16': torch.float16,
+        'torch.bfloat16': torch.bfloat16, 'bfloat16': torch.bfloat16,
+        'torch.float': torch.float, 'float': torch.float,
+        'torch.float32': torch.float32, 'float32': torch.float32,
+        'auto': 'auto', 'None': None
+    }
+
+    # 1. Priority: Use torch_dtype from model_kwargs if available
+    if 'torch_dtype' in model_kwargs:
+        torch_dtype = dtype_map.get(model_kwargs['torch_dtype'], torch.float16)
+
+    # 2. Secondary: Try to read from model config
+    elif path is not None:
+        try:
+            config = AutoConfig.from_pretrained(path)
+            if hasattr(config, 'torch_dtype'):
+                config_dtype = config.torch_dtype
+                if isinstance(config_dtype, str):
+                    torch_dtype = dtype_map.get(config_dtype, torch.float16)
+                else:
+                    torch_dtype = config_dtype
+            else:
+                torch_dtype = torch.float16
+        except Exception:
+            torch_dtype = torch.float16
+
+    # 3. Default: Use float16 as fallback
     else:
-        torch_dtype = {
-            'torch.float16': torch.float16,
-            'torch.bfloat16': torch.bfloat16,
-            'torch.float': torch.float,
-            'auto': 'auto',
-            'None': None,
-        }.get(model_kwargs['torch_dtype'])
-    if torch_dtype is not None:
-        model_kwargs['torch_dtype'] = torch_dtype
+        torch_dtype = torch.float16
+
+    # Update model_kwargs with the resolved torch_dtype
+    model_kwargs['torch_dtype'] = torch_dtype
     return model_kwargs
 
 
@@ -218,12 +246,12 @@ class HuggingFacewithChatTemplate(BaseModel):
             raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
 
     def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
-        from transformers import AutoModel, AutoModelForCausalLM
+        from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 
         DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
         model_kwargs = DEFAULT_MODEL_KWARGS
         model_kwargs.update(kwargs)
-        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
+        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs, path)
         self.logger.debug(f'using model_kwargs: {model_kwargs}')
         if is_npu_available():
             model_kwargs['device_map'] = 'npu'
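
For reference, a minimal sketch of the resolution order the patch introduces, exercised against the patched helper. The model path in case 2 is a hypothetical example, and that case fetches the checkpoint's `config.json` via `AutoConfig.from_pretrained`, so it needs the model to be reachable:

```python
import torch

from opencompass.models.huggingface_above_v4_33 import \
    _set_model_kwargs_torch_dtype

# 1. An explicit string in model_kwargs wins and is mapped via dtype_map;
#    bare names like 'bfloat16' now resolve too.
kwargs = _set_model_kwargs_torch_dtype({'torch_dtype': 'bfloat16'})
assert kwargs['torch_dtype'] is torch.bfloat16

# An already-resolved dtype object is returned untouched (early return).
kwargs = _set_model_kwargs_torch_dtype({'torch_dtype': torch.float32})
assert kwargs['torch_dtype'] is torch.float32

# 2. With no torch_dtype in kwargs, the checkpoint's config.json decides
#    ('internlm/internlm2-chat-7b' is a hypothetical example path).
kwargs = _set_model_kwargs_torch_dtype({}, path='internlm/internlm2-chat-7b')

# 3. With neither a kwargs entry nor a path, float16 is the fallback.
kwargs = _set_model_kwargs_torch_dtype({})
assert kwargs['torch_dtype'] is torch.float16
```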