mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
fix LightllmApi workers bug (#1113)
This commit is contained in:
parent
baed2ed9b8
commit
53fe390454
@ -6,7 +6,7 @@ from opencompass.tasks import OpenICLInferTask
|
|||||||
|
|
||||||
with read_base():
|
with read_base():
|
||||||
from .summarizers.leaderboard import summarizer
|
from .summarizers.leaderboard import summarizer
|
||||||
from .datasets.humaneval.humaneval_gen import humaneval_datasets
|
from .datasets.humaneval.humaneval_gen_a82cae import humaneval_datasets
|
||||||
|
|
||||||
datasets = [*humaneval_datasets]
|
datasets = [*humaneval_datasets]
|
||||||
|
|
||||||
@ -32,7 +32,8 @@ models = [
|
|||||||
url='http://localhost:1030/generate',
|
url='http://localhost:1030/generate',
|
||||||
meta_template=_meta_template,
|
meta_template=_meta_template,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
rate_per_worker=32,
|
max_workers_per_task=128,
|
||||||
|
rate_per_worker=1024,
|
||||||
retry=4,
|
retry=4,
|
||||||
generation_kwargs=dict(
|
generation_kwargs=dict(
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
|
@ -23,6 +23,7 @@ class LightllmAPI(BaseModel):
|
|||||||
path: str = 'LightllmAPI',
|
path: str = 'LightllmAPI',
|
||||||
url: str = 'http://localhost:8080/generate',
|
url: str = 'http://localhost:8080/generate',
|
||||||
meta_template: Optional[Dict] = None,
|
meta_template: Optional[Dict] = None,
|
||||||
|
max_workers_per_task: int = 2,
|
||||||
rate_per_worker: int = 2,
|
rate_per_worker: int = 2,
|
||||||
retry: int = 2,
|
retry: int = 2,
|
||||||
generation_kwargs: Optional[Dict] = dict(),
|
generation_kwargs: Optional[Dict] = dict(),
|
||||||
@ -37,6 +38,7 @@ class LightllmAPI(BaseModel):
|
|||||||
self.generation_kwargs = generation_kwargs
|
self.generation_kwargs = generation_kwargs
|
||||||
self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
|
self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
|
||||||
self.meta_template = meta_template
|
self.meta_template = meta_template
|
||||||
|
self.max_workers_per_task = max_workers_per_task
|
||||||
self.token_bucket = TokenBucket(rate_per_worker, False)
|
self.token_bucket = TokenBucket(rate_per_worker, False)
|
||||||
|
|
||||||
def generate(self, inputs: List[str], max_out_len: int,
|
def generate(self, inputs: List[str], max_out_len: int,
|
||||||
@ -53,7 +55,8 @@ class LightllmAPI(BaseModel):
|
|||||||
List[str]: A list of generated strings.
|
List[str]: A list of generated strings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor(
|
||||||
|
max_workers=self.max_workers_per_task) as executor:
|
||||||
results = list(
|
results = list(
|
||||||
executor.map(self._generate, inputs,
|
executor.map(self._generate, inputs,
|
||||||
[self.max_out_len] * len(inputs)))
|
[self.max_out_len] * len(inputs)))
|
||||||
@ -103,7 +106,8 @@ class LightllmAPI(BaseModel):
|
|||||||
List[str]: A list of generated strings.
|
List[str]: A list of generated strings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor(
|
||||||
|
max_workers=self.max_workers_per_task) as executor:
|
||||||
results = list(
|
results = list(
|
||||||
executor.map(self._get_ppl, inputs,
|
executor.map(self._get_ppl, inputs,
|
||||||
[self.max_out_len] * len(inputs)))
|
[self.max_out_len] * len(inputs)))
|
||||||
|
Loading…
Reference in New Issue
Block a user