mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Fix] Enforce do_sample=False
in HF model (#506)
* update hf model wrapper * patch llama --------- Co-authored-by: bot <bot@bot.com>
This commit is contained in:
parent
b62842335d
commit
df07391ed8
@ -100,25 +100,33 @@ class HuggingFace(BaseModel):
|
||||
if self.pad_token_id < 0:
|
||||
self.pad_token_id += self.tokenizer.vocab_size
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.logger.warning(
|
||||
f'Using {self.pad_token_id} as pad_token_id')
|
||||
self.logger.debug(f'Using {self.pad_token_id} as pad_token_id')
|
||||
elif self.tokenizer.pad_token_id != self.pad_token_id:
|
||||
self.logger.warning(
|
||||
f'pad_token_id is not consistent with the tokenizer. Using {self.pad_token_id} as pad_token_id' # noqa
|
||||
)
|
||||
'pad_token_id is not consistent with the tokenizer. Using '
|
||||
f'{self.pad_token_id} as pad_token_id')
|
||||
self.tokenizer.pad_token_id = self.pad_token_id
|
||||
elif self.tokenizer.pad_token_id is None:
|
||||
self.logger.warning('pad_token_id is not set for the tokenizer.')
|
||||
if self.tokenizer.eos_token is not None:
|
||||
self.logger.warning('Using eos_token_id as pad_token_id.')
|
||||
self.logger.warning(
|
||||
f'{self.tokenizer.eos_token} la {self.tokenizer.eos_token is None}' # noqa
|
||||
)
|
||||
f'Using eos_token_id {self.tokenizer.eos_token} '
|
||||
'as pad_token_id.')
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
else:
|
||||
from transformers.generation import GenerationConfig
|
||||
gcfg = GenerationConfig.from_pretrained(path)
|
||||
|
||||
if gcfg.pad_token_id is not None:
|
||||
self.logger.warning(
|
||||
f'Using pad_token_id {gcfg.pad_token_id} '
|
||||
'as pad_token_id.')
|
||||
self.tokenizer.pad_token_id = gcfg.pad_token_id
|
||||
else:
|
||||
raise ValueError(
|
||||
'pad_token_id is not set for this tokenizer. Try to set pad_token_id via passing `pad_token_id={PAD_TOKEN_ID}` in model_cfg. You may find pad_token_id in `generation.json`' # noqa
|
||||
)
|
||||
'pad_token_id is not set for this tokenizer. Try to '
|
||||
'set pad_token_id via passing '
|
||||
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
|
||||
|
||||
# A patch for llama when batch_padding = True
|
||||
if 'decapoda-research/llama' in path or \
|
||||
@ -165,6 +173,7 @@ class HuggingFace(BaseModel):
|
||||
peft_path,
|
||||
is_trainable=False)
|
||||
self.model.eval()
|
||||
self.model.generation_config.do_sample = False
|
||||
|
||||
# A patch for llama when batch_padding = True
|
||||
if 'decapoda-research/llama' in path:
|
||||
@ -432,3 +441,4 @@ class HuggingFaceCausalLM(HuggingFace):
|
||||
peft_path,
|
||||
is_trainable=False)
|
||||
self.model.eval()
|
||||
self.model.generation_config.do_sample = False
|
||||
|
Loading…
Reference in New Issue
Block a user