Support LightllmApi input_format (#888)

This commit is contained in:
Yang Yong 2024-02-19 10:02:59 +08:00 committed by GitHub
parent 08133e060a
commit b6e21ece38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 0 deletions

View File

@ -14,6 +14,7 @@ models = [
abbr='LightllmAPI',
type=LightllmAPI,
url='http://localhost:8080/generate',
input_format='<input_text_to_replace>',
max_seq_len=2048,
batch_size=32,
generation_kwargs=dict(

View File

@ -20,6 +20,7 @@ class LightllmAPI(BaseAPIModel):
self,
path: str = 'LightllmAPI',
url: str = 'http://localhost:8080/generate',
input_format: str = '<input_text_to_replace>',
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2,
@ -33,6 +34,7 @@ class LightllmAPI(BaseAPIModel):
generation_kwargs=generation_kwargs)
self.logger = get_logger()
self.url = url
self.input_format = input_format
self.generation_kwargs = generation_kwargs
self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
@ -62,6 +64,8 @@ class LightllmAPI(BaseAPIModel):
self.wait()
header = {'content-type': 'application/json'}
try:
input = self.input_format.replace('<input_text_to_replace>',
input)
data = dict(inputs=input, parameters=self.generation_kwargs)
raw_response = requests.post(self.url,
headers=header,
@ -114,6 +118,8 @@ class LightllmAPI(BaseAPIModel):
self.wait()
header = {'content-type': 'application/json'}
try:
input = self.input_format.replace('<input_text_to_replace>',
input)
data = dict(inputs=input, parameters=self.generation_kwargs)
raw_response = requests.post(self.url,
headers=header,