[Fix] Enforce do_sample=False in HF model (#506)

* update hf model wrapper

* patch llama

---------

Co-authored-by: bot <bot@bot.com>
This commit is contained in:
Fengzhe Zhou 2023-10-27 16:54:19 +08:00 committed by GitHub
parent b62842335d
commit df07391ed8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -100,25 +100,33 @@ class HuggingFace(BaseModel):
if self.pad_token_id < 0: if self.pad_token_id < 0:
self.pad_token_id += self.tokenizer.vocab_size self.pad_token_id += self.tokenizer.vocab_size
if self.tokenizer.pad_token_id is None: if self.tokenizer.pad_token_id is None:
self.logger.warning( self.logger.debug(f'Using {self.pad_token_id} as pad_token_id')
f'Using {self.pad_token_id} as pad_token_id')
elif self.tokenizer.pad_token_id != self.pad_token_id: elif self.tokenizer.pad_token_id != self.pad_token_id:
self.logger.warning( self.logger.warning(
f'pad_token_id is not consistent with the tokenizer. Using {self.pad_token_id} as pad_token_id' # noqa 'pad_token_id is not consistent with the tokenizer. Using '
) f'{self.pad_token_id} as pad_token_id')
self.tokenizer.pad_token_id = self.pad_token_id self.tokenizer.pad_token_id = self.pad_token_id
elif self.tokenizer.pad_token_id is None: elif self.tokenizer.pad_token_id is None:
self.logger.warning('pad_token_id is not set for the tokenizer.') self.logger.warning('pad_token_id is not set for the tokenizer.')
if self.tokenizer.eos_token is not None: if self.tokenizer.eos_token is not None:
self.logger.warning('Using eos_token_id as pad_token_id.')
self.logger.warning( self.logger.warning(
f'{self.tokenizer.eos_token} la {self.tokenizer.eos_token is None}' # noqa f'Using eos_token_id {self.tokenizer.eos_token} '
) 'as pad_token_id.')
self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token = self.tokenizer.eos_token
else: else:
raise ValueError( from transformers.generation import GenerationConfig
'pad_token_id is not set for this tokenizer. Try to set pad_token_id via passing `pad_token_id={PAD_TOKEN_ID}` in model_cfg. You may find pad_token_id in `generation.json`' # noqa gcfg = GenerationConfig.from_pretrained(path)
)
if gcfg.pad_token_id is not None:
self.logger.warning(
f'Using pad_token_id {gcfg.pad_token_id} '
'as pad_token_id.')
self.tokenizer.pad_token_id = gcfg.pad_token_id
else:
raise ValueError(
'pad_token_id is not set for this tokenizer. Try to '
'set pad_token_id via passing '
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
# A patch for llama when batch_padding = True # A patch for llama when batch_padding = True
if 'decapoda-research/llama' in path or \ if 'decapoda-research/llama' in path or \
@ -165,6 +173,7 @@ class HuggingFace(BaseModel):
peft_path, peft_path,
is_trainable=False) is_trainable=False)
self.model.eval() self.model.eval()
self.model.generation_config.do_sample = False
# A patch for llama when batch_padding = True # A patch for llama when batch_padding = True
if 'decapoda-research/llama' in path: if 'decapoda-research/llama' in path:
@ -432,3 +441,4 @@ class HuggingFaceCausalLM(HuggingFace):
peft_path, peft_path,
is_trainable=False) is_trainable=False)
self.model.eval() self.model.eval()
self.model.generation_config.do_sample = False