Mirror of https://github.com/open-compass/opencompass.git
Synced 2025-05-30 16:03:24 +08:00
[Bug] Update api with generation_kwargs (#614)
* update api

* update generation_kwargs impl

---------

Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent eb56fd6d16
commit 721a45c68f
opencompass/models/base.py

```diff
@@ -19,6 +19,8 @@ class BaseModel:
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
     """

     is_api: bool = False
@@ -27,7 +29,8 @@ class BaseModel:
                  path: str,
                  max_seq_len: int = 2048,
                  tokenizer_only: bool = False,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 generation_kwargs: Optional[Dict] = dict()):
         self.path = path
         self.max_seq_len = max_seq_len
         self.tokenizer_only = tokenizer_only
@@ -36,6 +39,7 @@ class BaseModel:
         self.eos_token_id = None
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
+        self.generation_kwargs = generation_kwargs

     @abstractmethod
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
```
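Taken together, the `BaseModel` hunks mean every local model carries a `generation_kwargs` attribute from construction onward, instead of each wrapper improvising one. A minimal sketch of the resulting contract; `EchoModel` and its trivial `generate` are invented stand-ins, and only the kwargs plumbing mirrors the diff:

```python
from typing import Dict, List, Optional


class EchoModel:
    """Hypothetical stand-in for a BaseModel subclass."""

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 generation_kwargs: Optional[Dict] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        # Mirrors the new BaseModel.__init__: the kwargs live on the
        # instance so downstream code (e.g. an inferencer) can inspect them.
        self.generation_kwargs = generation_kwargs or {}

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        # A real model would forward self.generation_kwargs to its backend.
        n = self.generation_kwargs.get('num_return_sequences', 1)
        return [text[:max_out_len] for text in inputs for _ in range(n)]


model = EchoModel('toy', generation_kwargs=dict(num_return_sequences=2))
assert model.generate(['hi'], max_out_len=8) == ['hi', 'hi']
```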
opencompass/models/base_api.py

```diff
@@ -28,6 +28,8 @@ class BaseAPIModel(BaseModel):
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
     """

     is_api: bool = True
@@ -37,7 +39,8 @@ class BaseAPIModel(BaseModel):
                  query_per_second: int = 1,
                  retry: int = 2,
                  max_seq_len: int = 2048,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 generation_kwargs: Dict = dict()):
         self.path = path
         self.max_seq_len = max_seq_len
         self.meta_template = meta_template
@@ -46,6 +49,7 @@ class BaseAPIModel(BaseModel):
         self.token_bucket = TokenBucket(query_per_second)
         self.template_parser = APITemplateParser(meta_template)
         self.logger = get_logger()
+        self.generation_kwargs = generation_kwargs

     @abstractmethod
     def generate(self, inputs: List[PromptType],
```
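`BaseAPIModel` gains the same attribute, so API wrappers stop needing per-class handling. As a hedged illustration of what a subclass can do with it, here is an invented `build_payload` helper that folds the stored kwargs into a request body; the field names are illustrative, not any backend's actual API:

```python
from typing import Dict


def build_payload(prompt: str, generation_kwargs: Dict) -> Dict:
    """Invented helper: merge construction-time kwargs into an HTTP body,
    as an API model subclass might after this change."""
    payload = {'prompt': prompt}
    # Whatever was passed at construction time rides along unchanged.
    payload.update(generation_kwargs)
    return payload


print(build_payload('2 + 2 =', {'temperature': 0.0, 'max_new_tokens': 16}))
# -> {'prompt': '2 + 2 =', 'temperature': 0.0, 'max_new_tokens': 16}
```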
opencompass/models/lightllm_api.py

```diff
@@ -16,25 +16,22 @@ class LightllmAPI(BaseAPIModel):

     is_api: bool = True

     def __init__(
-            self,
-            path: str = 'LightllmAPI',
-            url: str = 'http://localhost:8080/generate',
-            max_seq_len: int = 2048,
-            meta_template: Optional[Dict] = None,
-            retry: int = 2,
-            generation_kwargs: Optional[Dict] = None,
+            self,
+            path: str = 'LightllmAPI',
+            url: str = 'http://localhost:8080/generate',
+            max_seq_len: int = 2048,
+            meta_template: Optional[Dict] = None,
+            retry: int = 2,
+            generation_kwargs: Optional[Dict] = dict(),
     ):

         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template,
-                         retry=retry)
+                         retry=retry,
+                         generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        if generation_kwargs is not None:
-            self.generation_kwargs = generation_kwargs
-        else:
-            self.generation_kwargs = {}
         self.do_sample = self.generation_kwargs.get('do_sample', False)
         self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)
```
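`LightllmAPI` previously defaulted `generation_kwargs` to `None` and normalized it locally; it now defaults to `dict()`, forwards the value to `super().__init__`, and reads `do_sample`/`ignore_eos` straight off the shared attribute, which is why the `None` check disappears. A sketch of an OpenCompass-style model config exercising this path (values are illustrative):

```python
# Illustrative config entry; the generation_kwargs keys shown are exactly
# the ones LightllmAPI reads in the diff above.
lightllm_model = dict(
    type='LightllmAPI',          # in a real config this is the class itself
    path='LightllmAPI',
    url='http://localhost:8080/generate',
    max_seq_len=2048,
    retry=2,
    generation_kwargs=dict(do_sample=False, ignore_eos=False),
)
```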
opencompass/models/turbomind.py

```diff
@@ -54,7 +54,6 @@ class TurboMindModel(BaseModel):
             tm_model.create_instance() for i in range(concurrency)
         ]
         self.generator_ids = [i + 1 for i in range(concurrency)]
-        self.generation_kwargs = dict()

     def generate(
         self,
```
opencompass/models/turbomind_tis.py

```diff
@@ -53,7 +53,6 @@ class TurboMindTisModel(BaseModel):
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
         self.tis_addr = tis_addr
-        self.generation_kwargs = dict()

     def generate(
         self,
```
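Both TurboMind wrappers had pinned `self.generation_kwargs = dict()` by hand; once the base class owns the attribute, those lines are redundant. A toy sketch of why (both class names are placeholders, not the real implementations):

```python
class Base:
    """Stands in for the updated BaseModel."""

    def __init__(self, generation_kwargs=None):
        self.generation_kwargs = generation_kwargs or {}


class Turbo(Base):
    """Stands in for TurboMindModel after this commit."""

    def __init__(self):
        # No manual `self.generation_kwargs = dict()` needed anymore:
        # the base class already guarantees the attribute exists.
        super().__init__()


assert Turbo().generation_kwargs == {}
```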
opencompass/openicl/icl_inferencer/icl_gen_inferencer.py

```diff
@@ -130,7 +130,7 @@ class GenInferencer(BaseInferencer):
                     entry, max_out_len=self.max_out_len)
                 generated = results

-            num_return_sequences = self.model.generation_kwargs.get(
+            num_return_sequences = self.model.get('generation_kwargs', {}).get(
                 'num_return_sequences', 1)
             # 5-3. Save current output
             for prompt, prediction, gold in zip(
```
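`GenInferencer` consumes the setting to know how many candidates come back per prompt. A hedged sketch of the downstream bookkeeping, assuming the backend returns a flat list of `len(prompts) * num_return_sequences` strings (`group_predictions` is invented for illustration, not the inferencer's actual helper):

```python
from typing import List


def group_predictions(flat: List[str],
                      num_return_sequences: int) -> List[List[str]]:
    """Invented helper: regroup a flat generation list so each prompt
    maps to its own num_return_sequences candidates."""
    assert len(flat) % num_return_sequences == 0
    return [
        flat[i:i + num_return_sequences]
        for i in range(0, len(flat), num_return_sequences)
    ]


# Two prompts with two return sequences each -> regrouped 2 x 2.
print(group_predictions(['a1', 'a2', 'b1', 'b2'], 2))
# -> [['a1', 'a2'], ['b1', 'b2']]
```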