mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] Add _set_model_kwargs_torch_dtype for HF model (#507)
* add _set_model_kwargs_torch_dtype for hf models * add logger
This commit is contained in:
parent
6405cd2db5
commit
e3d4901bed
@ -131,13 +131,28 @@ class HuggingFace(BaseModel):
|
||||
self.tokenizer.eos_token = '</s>'
|
||||
self.tokenizer.pad_token_id = 0
|
||||
|
||||
def _set_model_kwargs_torch_dtype(self, model_kwargs):
|
||||
if 'torch_dtype' not in model_kwargs:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = {
|
||||
'torch.float16': torch.float16,
|
||||
'torch.bfloat16': torch.bfloat16,
|
||||
'torch.float': torch.float,
|
||||
'auto': 'auto',
|
||||
'None': None
|
||||
}.get(model_kwargs['torch_dtype'])
|
||||
self.logger.debug(f'HF using torch_dtype: {torch_dtype}')
|
||||
if torch_dtype is not None:
|
||||
model_kwargs['torch_dtype'] = torch_dtype
|
||||
|
||||
def _load_model(self,
|
||||
path: str,
|
||||
model_kwargs: dict,
|
||||
peft_path: Optional[str] = None):
|
||||
from transformers import AutoModel, AutoModelForCausalLM
|
||||
|
||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||
self._set_model_kwargs_torch_dtype(model_kwargs)
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
path, **model_kwargs)
|
||||
@ -409,7 +424,7 @@ class HuggingFaceCausalLM(HuggingFace):
|
||||
peft_path: Optional[str] = None):
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||
self._set_model_kwargs_torch_dtype(model_kwargs)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||
if peft_path is not None:
|
||||
from peft import PeftModel
|
||||
|
10
run.py
10
run.py
@ -175,8 +175,14 @@ def parse_hf_args(hf_parser):
|
||||
hf_parser.add_argument('--hf-path', type=str)
|
||||
hf_parser.add_argument('--peft-path', type=str)
|
||||
hf_parser.add_argument('--tokenizer-path', type=str)
|
||||
hf_parser.add_argument('--model-kwargs', nargs='+', action=DictAction)
|
||||
hf_parser.add_argument('--tokenizer-kwargs', nargs='+', action=DictAction)
|
||||
hf_parser.add_argument('--model-kwargs',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
default={})
|
||||
hf_parser.add_argument('--tokenizer-kwargs',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
default={})
|
||||
hf_parser.add_argument('--max-out-len', type=int)
|
||||
hf_parser.add_argument('--max-seq-len', type=int)
|
||||
hf_parser.add_argument('--no-batch-padding',
|
||||
|
Loading…
Reference in New Issue
Block a user