From 6e8adf5221cd012f650d7a54380890ad009935ee Mon Sep 17 00:00:00 2001
From: Lyu Han
Date: Sat, 19 Oct 2024 20:03:47 +0800
Subject: [PATCH] [Bug] Remove prefix bos_token from messages when using
 lmdeploy as the accelerator (#1623)

* remove prefix bos_token from messages when using lmdeploy as the accelerator

* update
---
 opencompass/models/turbomind_with_tf_above_v4_33.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/opencompass/models/turbomind_with_tf_above_v4_33.py b/opencompass/models/turbomind_with_tf_above_v4_33.py
index ab6801c9..7d4e3891 100644
--- a/opencompass/models/turbomind_with_tf_above_v4_33.py
+++ b/opencompass/models/turbomind_with_tf_above_v4_33.py
@@ -102,7 +102,13 @@ class TurboMindModelwithChatTemplate(BaseModel):
             messages = _format_with_fast_chat_template(messages, self.fastchat_template)
         else:
             messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
-
+            # LMDeploy tokenizes prompts with AutoTokenizer using its default "add_special_tokens=True"
+            # OC already adds bos_token to the prompt, which requires tokenizing prompts with "add_special_tokens=False"
+            # But LMDeploy's pipeline API doesn't expose "add_special_tokens". So, we remove bos_token
+            # from messages as a workaround
+            if self.tokenizer.bos_token:
+                bos_token = self.tokenizer.bos_token
+                messages = [message.removeprefix(bos_token) if message.startswith(bos_token) else message for message in messages]
         stop_words = list(set(self.stop_words + stopping_criteria))
 
         DEFAULT_GEN_CONFIG = {
@@ -129,8 +135,7 @@ class TurboMindModelwithChatTemplate(BaseModel):
         results = []
         outputs = self.pipe(messages, gen_config=gen_config, do_preprocess=False)
         for output in outputs:
-            text = self.tokenizer.decode(output.token_ids)
-            results.append(text)
+            results.append(output.text)
 
         for s in stop_words:
             results = [r.split(s)[0] for r in results]
@@ -162,4 +167,4 @@ class TurboMindModelwithChatTemplate(BaseModel):
         else:
             filtered = {k: v for k, v in engine_config.items() if hasattr(PytorchEngineConfig, k)}
             backend_config = PytorchEngineConfig(**filtered)
-        return pipeline(model_path, backend_config=backend_config, log_level='INFO', max_log_len=10)
+        return pipeline(model_path, backend_config=backend_config, log_level='WARNING')
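
Note (illustrative only, not part of the patch): a minimal sketch of the duplicated bos_token problem this change works around, using only the public HF tokenizer API. 'your/chat-model' is a placeholder for any chat model whose chat template already emits bos_token; the exact token ids printed depend on that model.

from transformers import AutoTokenizer

# Placeholder model id; substitute a chat model whose template starts with bos_token.
tokenizer = AutoTokenizer.from_pretrained('your/chat-model')

# Render the conversation to text; for such models the string already begins with
# the textual bos_token (e.g. '<s>').
prompt = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'hi'}],
    add_generation_prompt=True,
    tokenize=False,
)

# Re-tokenizing that text with the default add_special_tokens=True prepends a second BOS id.
# LMDeploy's pipeline tokenizes this way and does not expose add_special_tokens, which is
# the duplication the patch avoids.
print(tokenizer(prompt).input_ids[:3])
print(tokenizer(prompt, add_special_tokens=False).input_ids[:3])

# The workaround applied in the patch: strip the textual bos_token up front so the default
# tokenization inside the pipeline adds it back exactly once.
if tokenizer.bos_token and prompt.startswith(tokenizer.bos_token):
    prompt = prompt.removeprefix(tokenizer.bos_token)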