[Fix] Fix BailingAPI model (#1707)

* [fix] preserve the result sequence under multiple samples

* resolve the lint problems

* change the parameter name

* add another error code for retry

* log invalid responses

* format correction

* update

* add two model python files

* update the default parameter

* use a random delay for retries

* update the Bailing API example

* remove the unnecessary parameter
Yi Ding 2024-11-26 19:24:47 +08:00 committed by GitHub
parent ef695e28e5
commit bcb707dbfc
6 changed files with 60 additions and 62 deletions


@@ -15,13 +15,19 @@ datasets = [
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Lite-1116',
         token='xxxxxx',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
-        generation_kwargs={},
-        query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
+        batch_size=1,
+        generation_kwargs={
+            'temperature': 0.01,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+            'logprobs': 1,
+        },
     ),
 ]
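How these settings reach the API: the request-construction change later in this commit builds a base body of 'model', 'messages', and 'max_tokens', then merges generation_kwargs into it with request.update(...). So the config above yields a payload shaped like the following sketch; the 'messages' content is a made-up example, not from this commit.

# Illustrative request body produced from the config above: three base
# fields plus everything in generation_kwargs, merged via request.update().
request = {
    'model': 'Bailing-Lite-1116',
    'messages': [{'role': 'user', 'content': 'Hello'}],  # example content
    'max_tokens': 11264,
    # merged in from generation_kwargs:
    'temperature': 0.01,
    'top_p': 1.0,
    'top_k': -1,
    'n': 1,
    'logprobs': 1,
}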


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Pro-0920',
+        path='Bailing-Lite-1116',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
-        query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Pro-0920',
+        path='Bailing-Pro-1120',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
-        query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Lite-1116',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
-        query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Pro-1120',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
-        query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]
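The four model-config diffs above share the same shape and differ only in the model path. For context, a config like this is typically pulled into an evaluation run via mmengine's read_base; the module path below is an assumption for illustration, not taken from this commit.

# Hedged sketch: importing one of these model configs into an eval config.
from mmengine.config import read_base

with read_base():
    # Assumed module path; check the actual file locations in the repo.
    from opencompass.configs.models.bailing_api.bailing_lite_1116 import \
        models as bailing_lite_models

models = [*bailing_lite_models]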


@ -1,13 +1,14 @@
import concurrent import concurrent
import concurrent.futures import concurrent.futures
import os import os
import random
import socket import socket
import time import time
import traceback
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import requests import requests
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from urllib3.connection import HTTPConnection from urllib3.connection import HTTPConnection
try: try:
@@ -21,8 +22,6 @@ from .base_api import BaseAPIModel

 PromptType = Union[PromptList, str]

-BAILING_RETRY_DELAY: int = 30
-

 class HTTPAdapterWithSocketOptions(HTTPAdapter):
@@ -104,7 +103,7 @@ class BailingAPI(BaseAPIModel):
     def generate(
         self,
         inputs: Union[List[str], PromptList],
-        max_out_len: int = 4096,
+        max_out_len: int = 11264,
     ) -> List[str]:
         """Generate results given a list of inputs.
@@ -128,7 +127,7 @@ class BailingAPI(BaseAPIModel):
                 ): i
                 for i, input in enumerate(inputs)
             }
-            results = []
+            results = [''] * len(inputs)
             for future in concurrent.futures.as_completed(future_to_m):
                 m = future_to_m[future]  # noqa F841
                 resp = future.result()
@@ -136,16 +135,25 @@ class BailingAPI(BaseAPIModel):
                     try:
                         result = resp.json()
                     except Exception as e:  # noqa F841
-                        results.append('')
+                        self.logger.error(f'Fail to inference; '
+                                          f'model_name={self.path}; '
+                                          f'error={e}, '
+                                          f'request={inputs[m]}')
                     else:
                         if (result.get('choices')
                                 and result['choices'][0].get('message') and
                                 result['choices'][0]['message'].get('content')
                                 is not None):
-                            results.append(
-                                result['choices'][0]['message']['content'])
+                            results[m] = \
+                                result['choices'][0]['message']['content']
                         else:
-                            results.append('')
+                            self.logger.error(f'Receive invalid result. '
+                                              f'result={result}; '
+                                              f'request={inputs[m]}')
+                else:
+                    self.logger.error(f'Receive invalid response. '
+                                      f'response={resp}; '
+                                      f'request={inputs[m]}')
         self.flush()
         return results
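This hunk is the core of the "preserve the result sequence" fix: as_completed() yields futures in completion order, so the old results.append(...) could pair output j with input i. Pre-sizing the list and writing through the index stored in future_to_m keeps output i aligned with input i. A self-contained illustration of the same pattern (fake_infer is an illustrative stand-in, not from this commit):

import concurrent.futures
import random
import time


def fake_infer(text: str) -> str:
    """Stand-in for one API call; sleeps so completion order is arbitrary."""
    time.sleep(random.uniform(0, 0.05))
    return text.upper()


inputs = ['a', 'b', 'c', 'd']
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # Key each future by the index of its input, as the patched code does.
    future_to_m = {
        executor.submit(fake_infer, s): i
        for i, s in enumerate(inputs)
    }
    # Pre-size the result list; a failed future simply leaves '' in place.
    results = [''] * len(inputs)
    for future in concurrent.futures.as_completed(future_to_m):
        m = future_to_m[future]
        results[m] = future.result()

# Output i is aligned with input i regardless of completion order,
# which append() under as_completed() could not guarantee.
assert results == ['A', 'B', 'C', 'D']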
@@ -184,39 +192,31 @@ class BailingAPI(BaseAPIModel):
                 message['role'] = item['role']
                 messages.append(message)
         request = {
-            'model':
-            self._model,
-            'messages':
-            messages,
-            'max_seq_len':
-            max(
-                max_out_len if max_out_len else 4096,
-                self.max_seq_len if self.max_seq_len else 4096,
-            ),
+            'model': self._model,
+            'messages': messages,
+            'max_tokens': max_out_len,
         }
         request.update(self.generation_kwargs)
-        try:
-            retry_num = 0
-            while retry_num < self.retry:
-                response = self._infer_result(request, sess)
-                if response.status_code == 200:
-                    break  # success
-                elif response.status_code == 426:
-                    retry_num += 1  # retry
-                elif response.status_code in [429, 500, 504]:
-                    time.sleep(BAILING_RETRY_DELAY)
-                    retry_num += 1  # retry
-                else:
-                    raise ValueError(f'Status code = {response.status_code}')
-            else:
-                raise ValueError(
-                    f'Exceed the maximal retry times. Last status code '
-                    f'= {response.status_code}')
-        except Exception as e:
-            self.logger.error(f'Fail to inference request={request}; '
-                              f'model_name={self.path}; error={e}, '
-                              f'stack:{traceback.format_exc()}')
-            raise e
+        retry_num = 0
+        while retry_num < self.retry:
+            try:
+                response = self._infer_result(request, sess)
+            except ConnectionError:
+                time.sleep(random.randint(10, 30))
+                retry_num += 1  # retry
+                continue
+            if response.status_code == 200:
+                break  # success
+            elif response.status_code == 426:
+                retry_num += 1  # retry
+            elif response.status_code in [302, 429, 500, 504]:
+                time.sleep(random.randint(10, 30))
+                retry_num += 1  # retry
+            else:
+                raise ValueError(f'Status code = {response.status_code}')
+        else:
+            # Exceed the maximal retry times.
+            return ''
         return response

     # @retry(stop_max_attempt_number=3, wait_fixed=16000)  # ms
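Two Python idioms carry this hunk: a randomized random.randint(10, 30) back-off replaces the fixed BAILING_RETRY_DELAY so concurrent workers do not retry in lockstep, and the loop's else: clause runs only when the while condition goes false without a break, i.e. when all retries are exhausted. A standalone sketch of the same pattern under those assumptions (post_with_retry and its `post` callable are illustrative names, not from this commit):

import random
import time

from requests.exceptions import ConnectionError


def post_with_retry(post, request, max_retry=3):
    """Sketch of the patched retry loop; `post` is any callable that
    performs the HTTP request and returns a requests.Response."""
    retry_num = 0
    while retry_num < max_retry:
        try:
            response = post(request)
        except ConnectionError:
            # Transient network failure: back off 10-30 s, then retry.
            time.sleep(random.randint(10, 30))
            retry_num += 1
            continue
        if response.status_code == 200:
            break  # success
        elif response.status_code == 426:
            retry_num += 1  # retry immediately
        elif response.status_code in [302, 429, 500, 504]:
            # Retryable statuses: randomized back-off spreads workers apart.
            time.sleep(random.randint(10, 30))
            retry_num += 1
        else:
            raise ValueError(f'Status code = {response.status_code}')
    else:
        # while/else: reached only when the loop ends without break,
        # i.e. retries are exhausted; the caller treats this as empty.
        return None
    return response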