mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Fix] Update stop_words in huggingface_above_v4_33 (#1160)
This commit is contained in:
parent
80f831b425
commit
f10dd48f9c
@ -156,7 +156,7 @@ class HuggingFacewithChatTemplate(BaseModel):
|
|||||||
self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
|
self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
|
||||||
self.generation_kwargs = generation_kwargs
|
self.generation_kwargs = generation_kwargs
|
||||||
self.fastchat_template = fastchat_template
|
self.fastchat_template = fastchat_template
|
||||||
self.stop_words = stop_words
|
self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
|
||||||
|
|
||||||
for k, v in other_kwargs.items():
|
for k, v in other_kwargs.items():
|
||||||
if v is not None:
|
if v is not None:
|
||||||
@ -213,6 +213,19 @@ class HuggingFacewithChatTemplate(BaseModel):
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.model.generation_config.do_sample = False
|
self.model.generation_config.do_sample = False
|
||||||
|
|
||||||
|
def _get_potential_stop_words(self, path: Optional[str]):
|
||||||
|
from transformers import GenerationConfig
|
||||||
|
potential_stop_words = []
|
||||||
|
try:
|
||||||
|
generation_config = GenerationConfig.from_pretrained(path)
|
||||||
|
for token_id in generation_config.eos_token_id:
|
||||||
|
potential_stop_words.append(self.tokenizer.decode(token_id))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
potential_stop_words.append(self.tokenizer.eos_token)
|
||||||
|
potential_stop_words = list(set(potential_stop_words))
|
||||||
|
return potential_stop_words
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
inputs: List[str],
|
inputs: List[str],
|
||||||
max_out_len: int,
|
max_out_len: int,
|
||||||
|
Loading…
Reference in New Issue
Block a user