fix LightllmApi workers bug (#1113)

This commit is contained in:
Yang Yong 2024-04-30 22:09:22 +08:00 committed by GitHub
parent baed2ed9b8
commit 53fe390454
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 4 deletions

View File

@ -6,7 +6,7 @@ from opencompass.tasks import OpenICLInferTask
with read_base():
from .summarizers.leaderboard import summarizer
from .datasets.humaneval.humaneval_gen import humaneval_datasets
from .datasets.humaneval.humaneval_gen_a82cae import humaneval_datasets
datasets = [*humaneval_datasets]
@ -32,7 +32,8 @@ models = [
url='http://localhost:1030/generate',
meta_template=_meta_template,
batch_size=32,
rate_per_worker=32,
max_workers_per_task=128,
rate_per_worker=1024,
retry=4,
generation_kwargs=dict(
do_sample=False,

View File

@ -23,6 +23,7 @@ class LightllmAPI(BaseModel):
path: str = 'LightllmAPI',
url: str = 'http://localhost:8080/generate',
meta_template: Optional[Dict] = None,
max_workers_per_task: int = 2,
rate_per_worker: int = 2,
retry: int = 2,
generation_kwargs: Optional[Dict] = dict(),
@ -37,6 +38,7 @@ class LightllmAPI(BaseModel):
self.generation_kwargs = generation_kwargs
self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
self.meta_template = meta_template
self.max_workers_per_task = max_workers_per_task
self.token_bucket = TokenBucket(rate_per_worker, False)
def generate(self, inputs: List[str], max_out_len: int,
@ -53,7 +55,8 @@ class LightllmAPI(BaseModel):
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(
max_workers=self.max_workers_per_task) as executor:
results = list(
executor.map(self._generate, inputs,
[self.max_out_len] * len(inputs)))
@ -103,7 +106,8 @@ class LightllmAPI(BaseModel):
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(
max_workers=self.max_workers_per_task) as executor:
results = list(
executor.map(self._get_ppl, inputs,
[self.max_out_len] * len(inputs)))