[Bug] Update api with generation_kwargs (#614)

* update api

* update generation_kwargs impl

---------

Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
Songyang Zhang 2023-11-22 10:02:57 +08:00 committed by GitHub
parent eb56fd6d16
commit 721a45c68f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 20 additions and 17 deletions

View File

@@ -19,6 +19,8 @@ class BaseModel:
meta_template (Dict, optional): The model's meta prompt meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or template if needed, in case the requirement of injecting or
wrapping of any meta instructions. wrapping of any meta instructions.
generation_kwargs (Dict, optional): The generation kwargs for the
model. Defaults to dict().
""" """
is_api: bool = False is_api: bool = False
@@ -27,7 +29,8 @@ class BaseModel:
path: str, path: str,
max_seq_len: int = 2048, max_seq_len: int = 2048,
tokenizer_only: bool = False, tokenizer_only: bool = False,
meta_template: Optional[Dict] = None): meta_template: Optional[Dict] = None,
generation_kwargs: Optional[Dict] = dict()):
self.path = path self.path = path
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.tokenizer_only = tokenizer_only self.tokenizer_only = tokenizer_only
@@ -36,6 +39,7 @@ class BaseModel:
self.eos_token_id = None self.eos_token_id = None
if meta_template and 'eos_token_id' in meta_template: if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id'] self.eos_token_id = meta_template['eos_token_id']
self.generation_kwargs = generation_kwargs
@abstractmethod @abstractmethod
def generate(self, inputs: List[str], max_out_len: int) -> List[str]: def generate(self, inputs: List[str], max_out_len: int) -> List[str]:

View File

@@ -28,6 +28,8 @@ class BaseAPIModel(BaseModel):
meta_template (Dict, optional): The model's meta prompt meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or template if needed, in case the requirement of injecting or
wrapping of any meta instructions. wrapping of any meta instructions.
generation_kwargs (Dict, optional): The generation kwargs for the
model. Defaults to dict().
""" """
is_api: bool = True is_api: bool = True
@@ -37,7 +39,8 @@ class BaseAPIModel(BaseModel):
query_per_second: int = 1, query_per_second: int = 1,
retry: int = 2, retry: int = 2,
max_seq_len: int = 2048, max_seq_len: int = 2048,
meta_template: Optional[Dict] = None): meta_template: Optional[Dict] = None,
generation_kwargs: Dict = dict()):
self.path = path self.path = path
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.meta_template = meta_template self.meta_template = meta_template
@@ -46,6 +49,7 @@ class BaseAPIModel(BaseModel):
self.token_bucket = TokenBucket(query_per_second) self.token_bucket = TokenBucket(query_per_second)
self.template_parser = APITemplateParser(meta_template) self.template_parser = APITemplateParser(meta_template)
self.logger = get_logger() self.logger = get_logger()
self.generation_kwargs = generation_kwargs
@abstractmethod @abstractmethod
def generate(self, inputs: List[PromptType], def generate(self, inputs: List[PromptType],

View File

@@ -16,25 +16,22 @@ class LightllmAPI(BaseAPIModel):
is_api: bool = True is_api: bool = True
def __init__( def __init__(
self, self,
path: str = 'LightllmAPI', path: str = 'LightllmAPI',
url: str = 'http://localhost:8080/generate', url: str = 'http://localhost:8080/generate',
max_seq_len: int = 2048, max_seq_len: int = 2048,
meta_template: Optional[Dict] = None, meta_template: Optional[Dict] = None,
retry: int = 2, retry: int = 2,
generation_kwargs: Optional[Dict] = None, generation_kwargs: Optional[Dict] = dict(),
): ):
super().__init__(path=path, super().__init__(path=path,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
meta_template=meta_template, meta_template=meta_template,
retry=retry) retry=retry,
generation_kwargs=generation_kwargs)
self.logger = get_logger() self.logger = get_logger()
self.url = url self.url = url
if generation_kwargs is not None:
self.generation_kwargs = generation_kwargs
else:
self.generation_kwargs = {}
self.do_sample = self.generation_kwargs.get('do_sample', False) self.do_sample = self.generation_kwargs.get('do_sample', False)
self.ignore_eos = self.generation_kwargs.get('ignore_eos', False) self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)

View File

@@ -54,7 +54,6 @@ class TurboMindModel(BaseModel):
tm_model.create_instance() for i in range(concurrency) tm_model.create_instance() for i in range(concurrency)
] ]
self.generator_ids = [i + 1 for i in range(concurrency)] self.generator_ids = [i + 1 for i in range(concurrency)]
self.generation_kwargs = dict()
def generate( def generate(
self, self,

View File

@@ -53,7 +53,6 @@ class TurboMindTisModel(BaseModel):
if meta_template and 'eos_token_id' in meta_template: if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id'] self.eos_token_id = meta_template['eos_token_id']
self.tis_addr = tis_addr self.tis_addr = tis_addr
self.generation_kwargs = dict()
def generate( def generate(
self, self,

View File

@@ -130,7 +130,7 @@ class GenInferencer(BaseInferencer):
entry, max_out_len=self.max_out_len) entry, max_out_len=self.max_out_len)
generated = results generated = results
num_return_sequences = self.model.generation_kwargs.get( num_return_sequences = self.model.get('generation_kwargs', {}).get(
'num_return_sequences', 1) 'num_return_sequences', 1)
# 5-3. Save current output # 5-3. Save current output
for prompt, prediction, gold in zip( for prompt, prediction, gold in zip(