mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Fix LightllmApi ppl test (#951)
This commit is contained in:
parent
107e022cf4
commit
3829be87b1
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@ -35,6 +36,7 @@ class LightllmAPI(BaseModel):
|
||||
self.retry = retry
|
||||
self.generation_kwargs = generation_kwargs
|
||||
self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
|
||||
self.meta_template = meta_template
|
||||
self.token_bucket = TokenBucket(rate_per_worker, False)
|
||||
|
||||
def generate(self, inputs: List[str], max_out_len: int,
|
||||
@ -158,3 +160,26 @@ class LightllmAPI(BaseModel):
|
||||
Applicable in both single-thread and multi-thread environments.
|
||||
"""
|
||||
return self.token_bucket.get_token()
|
||||
|
||||
def get_token_len(self, prompt: str) -> int:
|
||||
"""Get lengths of the tokenized string. Only English and Chinese
|
||||
characters are counted for now. Users are encouraged to override this
|
||||
method if more accurate length is needed.
|
||||
|
||||
Args:
|
||||
prompt (str): Input string.
|
||||
|
||||
Returns:
|
||||
int: Length of the input tokens
|
||||
"""
|
||||
|
||||
english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
|
||||
chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
|
||||
|
||||
# Count English words
|
||||
english_count = sum(len(part.split()) for part in english_parts)
|
||||
|
||||
# Count Chinese words
|
||||
chinese_count = sum(len(part) for part in chinese_parts)
|
||||
|
||||
return english_count + chinese_count
|
||||
|
@ -108,9 +108,9 @@ class PPLInferencer(BaseInferencer):
|
||||
ice_template=ice_template,
|
||||
prompt_template=prompt_template,
|
||||
remain_sep=normalizing_str is not None)
|
||||
prompt_token_num = self.model.get_token_len_from_template(
|
||||
prompt, mode='ppl')
|
||||
if self.max_seq_len is not None:
|
||||
prompt_token_num = self.model.get_token_len_from_template(
|
||||
prompt, mode='ppl')
|
||||
while len(ice_idx_list[idx]
|
||||
) > 0 and prompt_token_num > self.max_seq_len:
|
||||
ice_idx_list[idx] = ice_idx_list[idx][:-1]
|
||||
|
Loading…
Reference in New Issue
Block a user