From f10dd48f9c75cb64dfb8ff7812c13ebc1826850d Mon Sep 17 00:00:00 2001
From: Fengzhe Zhou
Date: Wed, 15 May 2024 14:10:33 +0800
Subject: [PATCH] [Fix] Update stop_words in huggingface_above_v4_33 (#1160)

---
 opencompass/models/huggingface_above_v4_33.py | 35 ++++++++++++++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py
index f7ce622b..41356341 100644
--- a/opencompass/models/huggingface_above_v4_33.py
+++ b/opencompass/models/huggingface_above_v4_33.py
@@ -156,7 +156,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
         self.generation_kwargs = generation_kwargs
         self.fastchat_template = fastchat_template
-        self.stop_words = stop_words
+        self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
 
         for k, v in other_kwargs.items():
             if v is not None:
@@ -213,6 +213,39 @@ class HuggingFacewithChatTemplate(BaseModel):
         self.model.eval()
         self.model.generation_config.do_sample = False
 
+    def _get_potential_stop_words(self, path: Optional[str]):
+        """Collect strings that should end generation for this model.
+
+        Best-effort: decodes each ``eos_token_id`` listed in the
+        ``GenerationConfig`` found at ``path`` (when one can be loaded)
+        and adds the tokenizer's own ``eos_token``.
+
+        Args:
+            path: Location passed to ``GenerationConfig.from_pretrained``.
+
+        Returns:
+            list: De-duplicated, non-empty stop words in arbitrary order.
+        """
+        from transformers import GenerationConfig
+        try:
+            generation_config = GenerationConfig.from_pretrained(path)
+        except Exception:
+            # Deliberate fallback: if no generation config can be loaded,
+            # rely on the tokenizer's EOS token alone.
+            generation_config = None
+
+        eos_token_ids = getattr(generation_config, 'eos_token_id', None)
+        if eos_token_ids is None:
+            eos_token_ids = []
+        elif isinstance(eos_token_ids, int):
+            # A bare int cannot be iterated; wrap it so the id is kept.
+            eos_token_ids = [eos_token_ids]
+
+        words = {self.tokenizer.decode(token_id) for token_id in eos_token_ids}
+        words.add(self.tokenizer.eos_token)
+        # Drop falsy entries such as a None eos_token or an empty decode.
+        return [word for word in words if word]
+
     def generate(self,
                  inputs: List[str],
                  max_out_len: int,