mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Fix LightllmAPI workers bug (#1113)
This commit is contained in:
parent
baed2ed9b8
commit
53fe390454
@ -6,7 +6,7 @@ from opencompass.tasks import OpenICLInferTask
|
||||
|
||||
with read_base():
|
||||
from .summarizers.leaderboard import summarizer
|
||||
from .datasets.humaneval.humaneval_gen import humaneval_datasets
|
||||
from .datasets.humaneval.humaneval_gen_a82cae import humaneval_datasets
|
||||
|
||||
datasets = [*humaneval_datasets]
|
||||
|
||||
@ -32,7 +32,8 @@ models = [
|
||||
url='http://localhost:1030/generate',
|
||||
meta_template=_meta_template,
|
||||
batch_size=32,
|
||||
rate_per_worker=32,
|
||||
max_workers_per_task=128,
|
||||
rate_per_worker=1024,
|
||||
retry=4,
|
||||
generation_kwargs=dict(
|
||||
do_sample=False,
|
||||
|
@ -23,6 +23,7 @@ class LightllmAPI(BaseModel):
|
||||
path: str = 'LightllmAPI',
|
||||
url: str = 'http://localhost:8080/generate',
|
||||
meta_template: Optional[Dict] = None,
|
||||
max_workers_per_task: int = 2,
|
||||
rate_per_worker: int = 2,
|
||||
retry: int = 2,
|
||||
generation_kwargs: Optional[Dict] = dict(),
|
||||
@ -37,6 +38,7 @@ class LightllmAPI(BaseModel):
|
||||
self.generation_kwargs = generation_kwargs
|
||||
self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
|
||||
self.meta_template = meta_template
|
||||
self.max_workers_per_task = max_workers_per_task
|
||||
self.token_bucket = TokenBucket(rate_per_worker, False)
|
||||
|
||||
def generate(self, inputs: List[str], max_out_len: int,
|
||||
@ -53,7 +55,8 @@ class LightllmAPI(BaseModel):
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=self.max_workers_per_task) as executor:
|
||||
results = list(
|
||||
executor.map(self._generate, inputs,
|
||||
[self.max_out_len] * len(inputs)))
|
||||
@ -103,7 +106,8 @@ class LightllmAPI(BaseModel):
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=self.max_workers_per_task) as executor:
|
||||
results = list(
|
||||
executor.map(self._get_ppl, inputs,
|
||||
[self.max_out_len] * len(inputs)))
|
||||
|
Loading…
Reference in New Issue
Block a user