[Fix] Update stop_words in huggingface_above_v4_33 (#1160)

This commit is contained in:
Fengzhe Zhou 2024-05-15 14:10:33 +08:00 committed by GitHub
parent 80f831b425
commit f10dd48f9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -156,7 +156,7 @@ class HuggingFacewithChatTemplate(BaseModel):
self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs) self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
self.generation_kwargs = generation_kwargs self.generation_kwargs = generation_kwargs
self.fastchat_template = fastchat_template self.fastchat_template = fastchat_template
self.stop_words = stop_words self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
for k, v in other_kwargs.items(): for k, v in other_kwargs.items():
if v is not None: if v is not None:
@ -213,6 +213,19 @@ class HuggingFacewithChatTemplate(BaseModel):
self.model.eval() self.model.eval()
self.model.generation_config.do_sample = False self.model.generation_config.do_sample = False
def _get_potential_stop_words(self, path: Optional[str]):
    """Collect candidate stop words from the model's EOS tokens.

    Candidates come from two places: the ``eos_token_id`` entry of the
    generation config loaded from ``path`` (decoded with
    ``self.tokenizer``) and ``self.tokenizer.eos_token``.

    Args:
        path: Location handed to ``GenerationConfig.from_pretrained``.

    Returns:
        A de-duplicated list of non-empty stop-word strings, in no
        particular order.
    """
    from transformers import GenerationConfig

    try:
        generation_config = GenerationConfig.from_pretrained(path)
    except Exception:
        # Best effort: if the generation config cannot be loaded, fall
        # back to the tokenizer's EOS token alone.
        generation_config = None

    # ``eos_token_id`` may be missing, a single id, or a list of ids.
    # Normalise it to a list so a lone int is not silently dropped.
    eos_token_ids = getattr(generation_config, 'eos_token_id', None)
    if eos_token_ids is None:
        eos_token_ids = []
    elif isinstance(eos_token_ids, int):
        eos_token_ids = [eos_token_ids]

    candidates = {
        self.tokenizer.decode(token_id) for token_id in eos_token_ids
    }
    candidates.add(self.tokenizer.eos_token)

    # Drop ``None`` (tokenizer without an EOS token) and empty strings;
    # neither is usable as a stop word.
    return [word for word in candidates if word]
def generate(self, def generate(self,
inputs: List[str], inputs: List[str],
max_out_len: int, max_out_len: int,