[Bug] Update API with generation_kwargs (#614)

* update api

* update generation_kwargs impl

---------

Co-authored-by: Leymore <zfz-960727@163.com>
Songyang Zhang 2023-11-22 10:02:57 +08:00 committed by GitHub
parent eb56fd6d16
commit 721a45c68f
6 changed files with 20 additions and 17 deletions

opencompass/models/base.py

@@ -19,6 +19,8 @@ class BaseModel:
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
     """

     is_api: bool = False
@@ -27,7 +29,8 @@ class BaseModel:
                  path: str,
                  max_seq_len: int = 2048,
                  tokenizer_only: bool = False,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 generation_kwargs: Optional[Dict] = dict()):
         self.path = path
         self.max_seq_len = max_seq_len
         self.tokenizer_only = tokenizer_only
@@ -36,6 +39,7 @@ class BaseModel:
         self.eos_token_id = None
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
+        self.generation_kwargs = generation_kwargs

     @abstractmethod
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
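The net effect of these three hunks is that every model wrapper now gets a generation_kwargs attribute from the base constructor instead of each subclass defining its own. A self-contained sketch of the pattern, using simplified stand-in classes rather than the actual OpenCompass ones:

from typing import Dict, Optional

class Base:
    # Mirrors BaseModel.__init__ after this commit: the kwargs are stored
    # unconditionally, so subclasses can rely on the attribute existing.
    def __init__(self, path: str, generation_kwargs: Optional[Dict] = dict()):
        self.path = path
        self.generation_kwargs = generation_kwargs

class Child(Base):
    def __init__(self, path: str, generation_kwargs: Optional[Dict] = dict()):
        # Subclasses just forward the kwargs instead of re-implementing
        # their own None-handling (cf. the LightllmAPI hunk below).
        super().__init__(path=path, generation_kwargs=generation_kwargs)

print(Child('dummy').generation_kwargs)                       # {}
print(Child('dummy', {'do_sample': True}).generation_kwargs)  # {'do_sample': True}

One caveat worth knowing: a dict() default is evaluated once at function definition, so every instance built without explicit kwargs shares the same dict object. That is harmless as long as the stored kwargs are only read, never mutated in place.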

opencompass/models/base_api.py

@@ -28,6 +28,8 @@ class BaseAPIModel(BaseModel):
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
     """

     is_api: bool = True
@@ -37,7 +39,8 @@ class BaseAPIModel(BaseModel):
                  query_per_second: int = 1,
                  retry: int = 2,
                  max_seq_len: int = 2048,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 generation_kwargs: Dict = dict()):
         self.path = path
         self.max_seq_len = max_seq_len
         self.meta_template = meta_template
@@ -46,6 +49,7 @@ class BaseAPIModel(BaseModel):
         self.token_bucket = TokenBucket(query_per_second)
         self.template_parser = APITemplateParser(meta_template)
         self.logger = get_logger()
+        self.generation_kwargs = generation_kwargs

     @abstractmethod
     def generate(self, inputs: List[PromptType],
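For API-backed models the stored kwargs are typically folded into each outgoing request. A hypothetical illustration of that use (the payload fields are illustrative, not a documented OpenCompass request format):

import json
from typing import Dict

def build_request(prompt: str, max_out_len: int, generation_kwargs: Dict) -> str:
    # Start from the per-call arguments, then layer the model-level
    # generation settings on top.
    body = {'prompt': prompt, 'max_new_tokens': max_out_len}
    body.update(generation_kwargs)
    return json.dumps(body)

print(build_request('Hello', 32, {'do_sample': False, 'ignore_eos': False}))
# {"prompt": "Hello", "max_new_tokens": 32, "do_sample": false, "ignore_eos": false}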

opencompass/models/lightllm_api.py

@@ -22,19 +22,16 @@ class LightllmAPI(BaseAPIModel):
             max_seq_len: int = 2048,
             meta_template: Optional[Dict] = None,
             retry: int = 2,
-            generation_kwargs: Optional[Dict] = None,
+            generation_kwargs: Optional[Dict] = dict(),
     ):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template,
-                         retry=retry)
+                         retry=retry,
+                         generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        if generation_kwargs is not None:
-            self.generation_kwargs = generation_kwargs
-        else:
-            self.generation_kwargs = {}
         self.do_sample = self.generation_kwargs.get('do_sample', False)
         self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)
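The deleted if/else is made redundant by the dict() defaults: generation_kwargs can no longer arrive as None through the defaults, so the .get() calls above are safe immediately. A condensed before/after of just that normalization logic (simplified, not the literal class):

from typing import Dict, Optional

def old_init(generation_kwargs: Optional[Dict] = None) -> Dict:
    # Before: the subclass normalized None itself.
    return generation_kwargs if generation_kwargs is not None else {}

def new_init(generation_kwargs: Optional[Dict] = dict()) -> Dict:
    # After: the empty-dict default makes the value directly usable.
    return generation_kwargs

for init in (old_init, new_init):
    assert init().get('do_sample', False) is False

Note that a caller passing generation_kwargs=None explicitly would still break the new version; the change relies on configs either omitting the key or passing a dict.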

opencompass/models/turbomind.py

@@ -54,7 +54,6 @@ class TurboMindModel(BaseModel):
             tm_model.create_instance() for i in range(concurrency)
         ]
         self.generator_ids = [i + 1 for i in range(concurrency)]
-        self.generation_kwargs = dict()

     def generate(
         self,

opencompass/models/turbomind_tis.py

@@ -53,7 +53,6 @@ class TurboMindTisModel(BaseModel):
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
         self.tis_addr = tis_addr
-        self.generation_kwargs = dict()

     def generate(
         self,

opencompass/openicl/icl_inferencer/icl_gen_inferencer.py

@@ -130,7 +130,7 @@ class GenInferencer(BaseInferencer):
                     entry, max_out_len=self.max_out_len)
                 generated = results

-            num_return_sequences = self.model.get('generation_kwargs', {}).get(
+            num_return_sequences = self.model.generation_kwargs.get(
                 'num_return_sequences', 1)
             # 5-3. Save current output
             for prompt, prediction, gold in zip(
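This one-line change is the failure the commit title points at: self.model is a model object, not a dict, so the dict-style .get call raised AttributeError before any kwargs could be read. With generation_kwargs now guaranteed by the base constructors, plain attribute access is safe. A minimal reproduction (Model here is a stand-in, not the real class):

class Model:
    def __init__(self):
        self.generation_kwargs = {'num_return_sequences': 2}

model = Model()
try:
    # Old code: dict-style access on an object fails.
    model.get('generation_kwargs', {}).get('num_return_sequences', 1)
except AttributeError as err:
    print(err)  # 'Model' object has no attribute 'get'

# New code: attribute access, safe because the base class always sets it.
print(model.generation_kwargs.get('num_return_sequences', 1))  # 2

Downstream, the count is used to regroup the model's outputs per prompt when more than one sequence comes back.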