From 721a45c68f197f1425b4463e74c2a55ea371a842 Mon Sep 17 00:00:00 2001
From: Songyang Zhang
Date: Wed, 22 Nov 2023 10:02:57 +0800
Subject: [PATCH] [Bug] Update api with generation_kwargs (#614)

* update api

* update generation_kwargs impl

---------

Co-authored-by: Leymore
---
 opencompass/models/base.py                  |  6 +++++-
 opencompass/models/base_api.py              |  6 +++++-
 opencompass/models/lightllm_api.py          | 21 ++++++++-----------
 opencompass/models/turbomind.py             |  1 -
 opencompass/models/turbomind_tis.py         |  1 -
 .../icl_inferencer/icl_gen_inferencer.py    |  2 +-
 6 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/opencompass/models/base.py b/opencompass/models/base.py
index cc53e11e..0115cb0a 100644
--- a/opencompass/models/base.py
+++ b/opencompass/models/base.py
@@ -19,6 +19,8 @@ class BaseModel:
         meta_template (Dict, optional): The model's meta prompt template
             if needed, in case the requirement of injecting or wrapping of
             any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
     """
 
     is_api: bool = False
@@ -27,7 +29,8 @@ class BaseModel:
                  path: str,
                  max_seq_len: int = 2048,
                  tokenizer_only: bool = False,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 generation_kwargs: Optional[Dict] = dict()):
         self.path = path
         self.max_seq_len = max_seq_len
         self.tokenizer_only = tokenizer_only
@@ -36,6 +39,7 @@ class BaseModel:
         self.eos_token_id = None
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
+        self.generation_kwargs = generation_kwargs
 
     @abstractmethod
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
diff --git a/opencompass/models/base_api.py b/opencompass/models/base_api.py
index 04ce9323..9c7f2b95 100644
--- a/opencompass/models/base_api.py
+++ b/opencompass/models/base_api.py
@@ -28,6 +28,8 @@ class BaseAPIModel(BaseModel):
         meta_template (Dict, optional): The model's meta prompt template
             if needed, in case the requirement of injecting or wrapping of
             any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
""" is_api: bool = True @@ -37,7 +39,8 @@ class BaseAPIModel(BaseModel): query_per_second: int = 1, retry: int = 2, max_seq_len: int = 2048, - meta_template: Optional[Dict] = None): + meta_template: Optional[Dict] = None, + generation_kwargs: Dict = dict()): self.path = path self.max_seq_len = max_seq_len self.meta_template = meta_template @@ -46,6 +49,7 @@ class BaseAPIModel(BaseModel): self.token_bucket = TokenBucket(query_per_second) self.template_parser = APITemplateParser(meta_template) self.logger = get_logger() + self.generation_kwargs = generation_kwargs @abstractmethod def generate(self, inputs: List[PromptType], diff --git a/opencompass/models/lightllm_api.py b/opencompass/models/lightllm_api.py index f4e05711..73f677c3 100644 --- a/opencompass/models/lightllm_api.py +++ b/opencompass/models/lightllm_api.py @@ -16,25 +16,22 @@ class LightllmAPI(BaseAPIModel): is_api: bool = True def __init__( - self, - path: str = 'LightllmAPI', - url: str = 'http://localhost:8080/generate', - max_seq_len: int = 2048, - meta_template: Optional[Dict] = None, - retry: int = 2, - generation_kwargs: Optional[Dict] = None, + self, + path: str = 'LightllmAPI', + url: str = 'http://localhost:8080/generate', + max_seq_len: int = 2048, + meta_template: Optional[Dict] = None, + retry: int = 2, + generation_kwargs: Optional[Dict] = dict(), ): super().__init__(path=path, max_seq_len=max_seq_len, meta_template=meta_template, - retry=retry) + retry=retry, + generation_kwargs=generation_kwargs) self.logger = get_logger() self.url = url - if generation_kwargs is not None: - self.generation_kwargs = generation_kwargs - else: - self.generation_kwargs = {} self.do_sample = self.generation_kwargs.get('do_sample', False) self.ignore_eos = self.generation_kwargs.get('ignore_eos', False) diff --git a/opencompass/models/turbomind.py b/opencompass/models/turbomind.py index 99f3d30c..3dd9db0d 100644 --- a/opencompass/models/turbomind.py +++ b/opencompass/models/turbomind.py @@ -54,7 +54,6 @@ class TurboMindModel(BaseModel): tm_model.create_instance() for i in range(concurrency) ] self.generator_ids = [i + 1 for i in range(concurrency)] - self.generation_kwargs = dict() def generate( self, diff --git a/opencompass/models/turbomind_tis.py b/opencompass/models/turbomind_tis.py index d1a41fbc..53de25eb 100644 --- a/opencompass/models/turbomind_tis.py +++ b/opencompass/models/turbomind_tis.py @@ -53,7 +53,6 @@ class TurboMindTisModel(BaseModel): if meta_template and 'eos_token_id' in meta_template: self.eos_token_id = meta_template['eos_token_id'] self.tis_addr = tis_addr - self.generation_kwargs = dict() def generate( self, diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py index 72d8fff9..342f9e06 100644 --- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py @@ -130,7 +130,7 @@ class GenInferencer(BaseInferencer): entry, max_out_len=self.max_out_len) generated = results - num_return_sequences = self.model.generation_kwargs.get( + num_return_sequences = self.model.get('generation_kwargs', {}).get( 'num_return_sequences', 1) # 5-3. Save current output for prompt, prediction, gold in zip(