Fix LightllmApi ppl test (#951)

Yang Yong authored on 2024-03-08 12:04:44 +08:00, committed by GitHub
parent 107e022cf4
commit 3829be87b1
2 changed files with 27 additions and 2 deletions

opencompass/models/lightllm_api.py

@@ -1,4 +1,5 @@
 import json
+import re
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional
 
@@ -35,6 +36,7 @@ class LightllmAPI(BaseModel):
         self.retry = retry
         self.generation_kwargs = generation_kwargs
         self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.meta_template = meta_template
         self.token_bucket = TokenBucket(rate_per_worker, False)
 
     def generate(self, inputs: List[str], max_out_len: int,
@@ -158,3 +160,26 @@ class LightllmAPI(BaseModel):
         Applicable in both single-thread and multi-thread environments.
         """
         return self.token_bucket.get_token()
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get lengths of the tokenized string. Only English and Chinese
+        characters are counted for now. Users are encouraged to override this
+        method if a more accurate length is needed.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens
+
+        """
+        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
+        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
+
+        # Count English words
+        english_count = sum(len(part.split()) for part in english_parts)
+
+        # Count Chinese words
+        chinese_count = sum(len(part) for part in chinese_parts)
+
+        return english_count + chinese_count
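
The new get_token_len gives LightllmAPI a local token estimate for prompt-length checks: each run of ASCII letters/digits counts as one word, and each CJK character in U+4E00..U+9FFF counts as one token. A minimal standalone sketch of the same heuristic (estimate_token_len is an illustrative name, not part of the API):

import re

def estimate_token_len(prompt: str) -> int:
    # Each [A-Za-z0-9]+ match contains no whitespace, so summing
    # len(part.split()) just counts the number of alphanumeric runs.
    english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
    english_count = sum(len(part.split()) for part in english_parts)
    # Each character in the CJK Unified Ideographs block counts once.
    chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
    chinese_count = sum(len(part) for part in chinese_parts)
    return english_count + chinese_count

assert estimate_token_len('hello world') == 2
assert estimate_token_len('你好世界') == 4
assert estimate_token_len('GPT-4 模型') == 4  # 'GPT', '4', '模', '型'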

opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py

@@ -108,9 +108,9 @@ class PPLInferencer(BaseInferencer):
                     ice_template=ice_template,
                     prompt_template=prompt_template,
                     remain_sep=normalizing_str is not None)
-                prompt_token_num = self.model.get_token_len_from_template(
-                    prompt, mode='ppl')
                 if self.max_seq_len is not None:
+                    prompt_token_num = self.model.get_token_len_from_template(
+                        prompt, mode='ppl')
                     while len(ice_idx_list[idx]
                               ) > 0 and prompt_token_num > self.max_seq_len:
                         ice_idx_list[idx] = ice_idx_list[idx][:-1]
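
Moving the prompt_token_num computation inside the max_seq_len guard means get_token_len_from_template is only invoked when a length limit is actually configured; together with the new get_token_len above, this is presumably what un-breaks the LightllmApi PPL test. A rough sketch of the resulting truncation logic, where build_prompt and count_tokens are hypothetical stand-ins for the retriever's prompt-template call and self.model.get_token_len_from_template:

def truncate_ice_to_fit(ice_idx, build_prompt, count_tokens, max_seq_len):
    # ice_idx: indices of the in-context examples for one test case.
    prompt = build_prompt(ice_idx)
    if max_seq_len is not None:  # token counting only happens under a limit
        num = count_tokens(prompt)
        while len(ice_idx) > 0 and num > max_seq_len:
            ice_idx = ice_idx[:-1]          # drop the last in-context example
            prompt = build_prompt(ice_idx)  # rebuild the shorter prompt
            num = count_tokens(prompt)
    return prompt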