From e3d4901bedfee5b042a287a1c878f9526e52bcf1 Mon Sep 17 00:00:00 2001
From: Fengzhe Zhou
Date: Fri, 27 Oct 2023 11:45:41 +0800
Subject: [PATCH] [Feat] Add _set_model_kwargs_torch_dtype for HF model (#507)

* add _set_model_kwargs_torch_dtype for hf models

* add logger
---
 opencompass/models/huggingface.py | 19 +++++++++++++++++--
 run.py                            | 10 ++++++++--
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/opencompass/models/huggingface.py b/opencompass/models/huggingface.py
index 5addd39e..50f03349 100644
--- a/opencompass/models/huggingface.py
+++ b/opencompass/models/huggingface.py
@@ -131,13 +131,28 @@ class HuggingFace(BaseModel):
             self.tokenizer.eos_token = '</s>'
             self.tokenizer.pad_token_id = 0
 
+    def _set_model_kwargs_torch_dtype(self, model_kwargs):
+        if 'torch_dtype' not in model_kwargs:
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = {
+                'torch.float16': torch.float16,
+                'torch.bfloat16': torch.bfloat16,
+                'torch.float': torch.float,
+                'auto': 'auto',
+                'None': None
+            }.get(model_kwargs['torch_dtype'])
+        self.logger.debug(f'HF using torch_dtype: {torch_dtype}')
+        if torch_dtype is not None:
+            model_kwargs['torch_dtype'] = torch_dtype
+
     def _load_model(self,
                     path: str,
                     model_kwargs: dict,
                     peft_path: Optional[str] = None):
         from transformers import AutoModel, AutoModelForCausalLM
 
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 path, **model_kwargs)
@@ -409,7 +424,7 @@ class HuggingFaceCausalLM(HuggingFace):
                     peft_path: Optional[str] = None):
         from transformers import AutoModelForCausalLM
 
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
         if peft_path is not None:
             from peft import PeftModel
diff --git a/run.py b/run.py
index 1e3bc7ed..f5512a67 100644
--- a/run.py
+++ b/run.py
@@ -175,8 +175,14 @@ def parse_hf_args(hf_parser):
     hf_parser.add_argument('--hf-path', type=str)
     hf_parser.add_argument('--peft-path', type=str)
     hf_parser.add_argument('--tokenizer-path', type=str)
-    hf_parser.add_argument('--model-kwargs', nargs='+', action=DictAction)
-    hf_parser.add_argument('--tokenizer-kwargs', nargs='+', action=DictAction)
+    hf_parser.add_argument('--model-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
+    hf_parser.add_argument('--tokenizer-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
     hf_parser.add_argument('--max-out-len', type=int)
     hf_parser.add_argument('--max-seq-len', type=int)
     hf_parser.add_argument('--no-batch-padding',