Fix LightllmApi ppl test (#951)
parent 107e022cf4 · commit 3829be87b1
@@ -1,4 +1,5 @@
 import json
+import re
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional
 
@@ -35,6 +36,7 @@ class LightllmAPI(BaseModel):
         self.retry = retry
         self.generation_kwargs = generation_kwargs
         self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.meta_template = meta_template
         self.token_bucket = TokenBucket(rate_per_worker, False)
 
     def generate(self, inputs: List[str], max_out_len: int,
@@ -158,3 +160,26 @@ class LightllmAPI(BaseModel):
         Applicable in both single-thread and multi-thread environments.
         """
         return self.token_bucket.get_token()
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get lengths of the tokenized string. Only English and Chinese
+        characters are counted for now. Users are encouraged to override this
+        method if more accurate length is needed.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens
+        """
+
+        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
+        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
+
+        # Count English words
+        english_count = sum(len(part.split()) for part in english_parts)
+
+        # Count Chinese words
+        chinese_count = sum(len(part) for part in chinese_parts)
+
+        return english_count + chinese_count
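The new get_token_len is a rough heuristic rather than a real tokenizer: each run of ASCII letters and digits counts as one token (a regex match contains no whitespace, so len(part.split()) is always 1), and each CJK character in the U+4E00-U+9FFF range counts as one token. A minimal standalone sketch of the same heuristic follows; the function name and sample string are illustrative, not from the commit:

    import re

    def approx_token_len(prompt: str) -> int:
        # One token per run of ASCII letters/digits; each matched run
        # contains no whitespace, so it contributes exactly 1 to the sum.
        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
        english_count = sum(len(part.split()) for part in english_parts)
        # One token per CJK Unified Ideograph (U+4E00-U+9FFF).
        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
        chinese_count = sum(len(part) for part in chinese_parts)
        return english_count + chinese_count

    print(approx_token_len('LightllmApi 支持 ppl 测试'))  # 2 English + 4 Chinese = 6

Punctuation and other scripts contribute nothing to the count, which is why the docstring encourages overriding the method when an accurate length matters.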
@@ -108,9 +108,9 @@ class PPLInferencer(BaseInferencer):
                     ice_template=ice_template,
                     prompt_template=prompt_template,
                     remain_sep=normalizing_str is not None)
+                prompt_token_num = self.model.get_token_len_from_template(
+                    prompt, mode='ppl')
                 if self.max_seq_len is not None:
-                    prompt_token_num = self.model.get_token_len_from_template(
-                        prompt, mode='ppl')
                     while len(ice_idx_list[idx]
                               ) > 0 and prompt_token_num > self.max_seq_len:
                         ice_idx_list[idx] = ice_idx_list[idx][:-1]
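The PPLInferencer change hoists the token count out of the max_seq_len guard so that prompt_token_num is bound on every path; the hunk suggests the value is also read after the truncation branch, which under the old placement would raise UnboundLocalError whenever max_seq_len was None. A minimal sketch of the control-flow pattern, with illustrative names and numbers rather than the inferencer's internals:

    # Before (sketch): the name is only bound inside the branch.
    def before(max_seq_len=None):
        if max_seq_len is not None:
            prompt_token_num = 42
        return prompt_token_num  # UnboundLocalError when max_seq_len is None

    # After (sketch): bind unconditionally, then truncate only if a limit is set.
    def after(max_seq_len=None):
        prompt_token_num = 42
        if max_seq_len is not None:
            prompt_token_num = min(prompt_token_num, max_seq_len)
        return prompt_token_num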