[Feat] Add _set_model_kwargs_torch_dtype for HF model (#507)

* add _set_model_kwargs_torch_dtype for hf models

* add logger
Author: Fengzhe Zhou, 2023-10-27 11:45:41 +08:00 (committed by GitHub)
parent 6405cd2db5
commit e3d4901bed
2 changed files with 25 additions and 4 deletions

@@ -131,13 +131,28 @@ class HuggingFace(BaseModel):
             self.tokenizer.eos_token = '</s>'
             self.tokenizer.pad_token_id = 0
 
+    def _set_model_kwargs_torch_dtype(self, model_kwargs):
+        if 'torch_dtype' not in model_kwargs:
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = {
+                'torch.float16': torch.float16,
+                'torch.bfloat16': torch.bfloat16,
+                'torch.float': torch.float,
+                'auto': 'auto',
+                'None': None
+            }.get(model_kwargs['torch_dtype'])
+        self.logger.debug(f'HF using torch_dtype: {torch_dtype}')
+        if torch_dtype is not None:
+            model_kwargs['torch_dtype'] = torch_dtype
+
     def _load_model(self,
                     path: str,
                     model_kwargs: dict,
                     peft_path: Optional[str] = None):
         from transformers import AutoModel, AutoModelForCausalLM
 
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 path, **model_kwargs)
@@ -409,7 +424,7 @@ class HuggingFaceCausalLM(HuggingFace):
                     model_kwargs: dict,
                     peft_path: Optional[str] = None):
         from transformers import AutoModelForCausalLM
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
         if peft_path is not None:
             from peft import PeftModel
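
For illustration, here is a minimal standalone sketch of how the new dtype resolution behaves (the helper name resolve_torch_dtype is hypothetical, not part of the patch): string values passed through model_kwargs are mapped to real torch dtypes, an absent key falls back to torch.float16 as before, and any unrecognized value is left untouched for transformers to reject later.

import torch

# Hypothetical standalone version of the _set_model_kwargs_torch_dtype logic.
def resolve_torch_dtype(model_kwargs: dict) -> None:
    if 'torch_dtype' not in model_kwargs:
        torch_dtype = torch.float16
    else:
        # Unknown strings fall through to None via .get(), so the original
        # value in model_kwargs is left unchanged.
        torch_dtype = {
            'torch.float16': torch.float16,
            'torch.bfloat16': torch.bfloat16,
            'torch.float': torch.float,
            'auto': 'auto',
            'None': None,
        }.get(model_kwargs['torch_dtype'])
    if torch_dtype is not None:
        model_kwargs['torch_dtype'] = torch_dtype

kwargs = {'torch_dtype': 'torch.bfloat16'}
resolve_torch_dtype(kwargs)
assert kwargs['torch_dtype'] is torch.bfloat16

kwargs = {}
resolve_torch_dtype(kwargs)
assert kwargs['torch_dtype'] is torch.float16  # same default as the old setdefault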

run.py (10 changes)

@@ -175,8 +175,14 @@ def parse_hf_args(hf_parser):
     hf_parser.add_argument('--hf-path', type=str)
     hf_parser.add_argument('--peft-path', type=str)
     hf_parser.add_argument('--tokenizer-path', type=str)
-    hf_parser.add_argument('--model-kwargs', nargs='+', action=DictAction)
-    hf_parser.add_argument('--tokenizer-kwargs', nargs='+', action=DictAction)
+    hf_parser.add_argument('--model-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
+    hf_parser.add_argument('--tokenizer-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
     hf_parser.add_argument('--max-out-len', type=int)
     hf_parser.add_argument('--max-seq-len', type=int)
     hf_parser.add_argument('--no-batch-padding',
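
A short sketch of why default={} matters here (plain argparse stands in for DictAction, which parses KEY=VALUE pairs into a dict): without a default, an omitted flag yields None, and the membership test 'torch_dtype' not in model_kwargs performed by _set_model_kwargs_torch_dtype would raise a TypeError.

import argparse

parser = argparse.ArgumentParser()
# Simplified stand-in for the patched definition; DictAction omitted.
parser.add_argument('--model-kwargs', nargs='+', default={})

args = parser.parse_args([])    # flag not given on the command line
assert args.model_kwargs == {}  # previously None, so expressions like
                                # 'torch_dtype' in args.model_kwargs
                                # raised TypeError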