From 53fe3904540c049e259492016942cbd39f13a7a2 Mon Sep 17 00:00:00 2001 From: Yang Yong Date: Tue, 30 Apr 2024 22:09:22 +0800 Subject: [PATCH] fix LightllmApi workers bug (#1113) --- configs/eval_lightllm.py | 5 +++-- opencompass/models/lightllm_api.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/configs/eval_lightllm.py b/configs/eval_lightllm.py index d84cd2c2..fe45c14c 100644 --- a/configs/eval_lightllm.py +++ b/configs/eval_lightllm.py @@ -6,7 +6,7 @@ from opencompass.tasks import OpenICLInferTask with read_base(): from .summarizers.leaderboard import summarizer - from .datasets.humaneval.humaneval_gen import humaneval_datasets + from .datasets.humaneval.humaneval_gen_a82cae import humaneval_datasets datasets = [*humaneval_datasets] @@ -32,7 +32,8 @@ models = [ url='http://localhost:1030/generate', meta_template=_meta_template, batch_size=32, - rate_per_worker=32, + max_workers_per_task=128, + rate_per_worker=1024, retry=4, generation_kwargs=dict( do_sample=False, diff --git a/opencompass/models/lightllm_api.py b/opencompass/models/lightllm_api.py index d06ab93f..58fbae72 100644 --- a/opencompass/models/lightllm_api.py +++ b/opencompass/models/lightllm_api.py @@ -23,6 +23,7 @@ class LightllmAPI(BaseModel): path: str = 'LightllmAPI', url: str = 'http://localhost:8080/generate', meta_template: Optional[Dict] = None, + max_workers_per_task: int = 2, rate_per_worker: int = 2, retry: int = 2, generation_kwargs: Optional[Dict] = dict(), @@ -37,6 +38,7 @@ class LightllmAPI(BaseModel): self.generation_kwargs = generation_kwargs self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024) self.meta_template = meta_template + self.max_workers_per_task = max_workers_per_task self.token_bucket = TokenBucket(rate_per_worker, False) def generate(self, inputs: List[str], max_out_len: int, @@ -53,7 +55,8 @@ class LightllmAPI(BaseModel): List[str]: A list of generated strings. """ - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor( + max_workers=self.max_workers_per_task) as executor: results = list( executor.map(self._generate, inputs, [self.max_out_len] * len(inputs))) @@ -103,7 +106,8 @@ class LightllmAPI(BaseModel): List[str]: A list of generated strings. """ - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor( + max_workers=self.max_workers_per_task) as executor: results = list( executor.map(self._get_ppl, inputs, [self.max_out_len] * len(inputs)))