From 3829be87b106cef9550264c9eccb6237215bce5b Mon Sep 17 00:00:00 2001
From: Yang Yong
Date: Fri, 8 Mar 2024 12:04:44 +0800
Subject: [PATCH] Fix LightllmApi ppl test (#951)

---
 opencompass/models/lightllm_api.py                  | 25 +++++++++++++++++++
 .../openicl/icl_inferencer/icl_ppl_inferencer.py    |  4 +--
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/opencompass/models/lightllm_api.py b/opencompass/models/lightllm_api.py
index 3bb02229..d06ab93f 100644
--- a/opencompass/models/lightllm_api.py
+++ b/opencompass/models/lightllm_api.py
@@ -1,4 +1,5 @@
 import json
+import re
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional
 
@@ -35,6 +36,7 @@ class LightllmAPI(BaseModel):
         self.retry = retry
         self.generation_kwargs = generation_kwargs
         self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.meta_template = meta_template
         self.token_bucket = TokenBucket(rate_per_worker, False)
 
     def generate(self, inputs: List[str], max_out_len: int,
@@ -158,3 +160,26 @@ class LightllmAPI(BaseModel):
         Applicable in both single-thread and multi-thread environments.
         """
         return self.token_bucket.get_token()
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get lengths of the tokenized string. Only English and Chinese
+        characters are counted for now. Users are encouraged to override this
+        method if more accurate length is needed.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens
+        """
+
+        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
+        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
+
+        # Count English words
+        english_count = sum(len(part.split()) for part in english_parts)
+
+        # Count Chinese words
+        chinese_count = sum(len(part) for part in chinese_parts)
+
+        return english_count + chinese_count
diff --git a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
index e82a015d..e48a8a2f 100644
--- a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
@@ -108,9 +108,9 @@ class PPLInferencer(BaseInferencer):
                     ice_template=ice_template,
                     prompt_template=prompt_template,
                     remain_sep=normalizing_str is not None)
+                prompt_token_num = self.model.get_token_len_from_template(
+                    prompt, mode='ppl')
                 if self.max_seq_len is not None:
-                    prompt_token_num = self.model.get_token_len_from_template(
-                        prompt, mode='ppl')
                     while len(ice_idx_list[idx]
                               ) > 0 and prompt_token_num > self.max_seq_len:
                         ice_idx_list[idx] = ice_idx_list[idx][:-1]