[Fix] Fix BailingAPI model (#1707)

* [fix] keep results in sequence when generating for multiple samples

* resolve the lint problems

* change the parameter name

* add another error code for retry

* output the log for invalid response

* format correction

* update

* update

* update

* update

* add two model python files

* update the default parameter

* use random for delay

* update the api example of bailing

* remove the unnecessary parameter
Yi Ding 2024-11-26 19:24:47 +08:00 committed by GitHub
parent ef695e28e5
commit bcb707dbfc
6 changed files with 60 additions and 62 deletions


@@ -15,13 +15,19 @@ datasets = [
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Lite-1116',
         token='xxxxxx',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
-        generation_kwargs={},
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
+        generation_kwargs={
+            'temperature': 0.01,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+            'logprobs': 1,
+        },
     ),
 ]
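For reference, after this change the example resolves to a config like the sketch below. This is a consolidated view of the post-change state rather than code from the diff itself; the `os.environ.get` lookup for the token is illustrative and simply mirrors the comment about the BAILING_API_KEY environment variable, and the import assumes the usual `from opencompass.models import BailingAPI`.

```python
import os

from opencompass.models import BailingAPI

models = [
    dict(
        type=BailingAPI,
        path='Bailing-Lite-1116',
        # illustrative: read the key from the environment instead of hardcoding it
        token=os.environ.get('BAILING_API_KEY', ''),
        url='https://bailingchat.alipay.com/chat/completions',
        query_per_second=1,
        max_out_len=11264,
        batch_size=1,
        generation_kwargs={
            'temperature': 0.01,
            'top_p': 1.0,
            'top_k': -1,
            'n': 1,
            'logprobs': 1,
        },
    ),
]
```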


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Pro-0920',
+        path='Bailing-Lite-1116',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Pro-0920',
+        path='Bailing-Pro-1120',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Lite-1116',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]


@@ -10,21 +10,19 @@ api_meta_template = dict(
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Pro-1120',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]


@ -1,13 +1,14 @@
import concurrent
import concurrent.futures
import os
import random
import socket
import time
import traceback
from typing import Dict, List, Optional, Union
import requests
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from urllib3.connection import HTTPConnection
try:
@@ -21,8 +22,6 @@ from .base_api import BaseAPIModel
 PromptType = Union[PromptList, str]
-BAILING_RETRY_DELAY: int = 30
 class HTTPAdapterWithSocketOptions(HTTPAdapter):
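`HTTPAdapterWithSocketOptions` is only named in this hunk, so for context here is a generic sketch of the usual requests-adapter pattern such a class follows (passing socket options like TCP keepalive down to urllib3). The details of the real class in this file may differ.

```python
import socket

import requests
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPConnection


class HTTPAdapterWithSocketOptions(HTTPAdapter):
    """Pass custom socket options (e.g. TCP keepalive) down to urllib3."""

    def __init__(self, *args, **kwargs):
        # pop before super().__init__(), which already builds the pool manager
        self._socket_options = kwargs.pop('socket_options', None)
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        if self._socket_options is not None:
            kwargs['socket_options'] = self._socket_options
        super().init_poolmanager(*args, **kwargs)


# usage: keep long-lived API connections alive across slow requests
adapter = HTTPAdapterWithSocketOptions(
    socket_options=HTTPConnection.default_socket_options +
    [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)])
session = requests.Session()
session.mount('https://', adapter)
```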
@@ -104,7 +103,7 @@ class BailingAPI(BaseAPIModel):
     def generate(
         self,
         inputs: Union[List[str], PromptList],
-        max_out_len: int = 4096,
+        max_out_len: int = 11264,
     ) -> List[str]:
         """Generate results given a list of inputs.
@@ -128,7 +127,7 @@ class BailingAPI(BaseAPIModel):
                 ): i
                 for i, input in enumerate(inputs)
             }
-            results = []
+            results = [''] * len(inputs)
             for future in concurrent.futures.as_completed(future_to_m):
                 m = future_to_m[future]  # noqa F841
                 resp = future.result()
@@ -136,16 +135,25 @@
                 try:
                     result = resp.json()
                 except Exception as e:  # noqa F841
-                    results.append('')
+                    self.logger.error(f'Fail to inference; '
+                                      f'model_name={self.path}; '
+                                      f'error={e}, '
+                                      f'request={inputs[m]}')
                 else:
                     if (result.get('choices')
                             and result['choices'][0].get('message') and
                             result['choices'][0]['message'].get('content')
                             is not None):
-                        results.append(
-                            result['choices'][0]['message']['content'])
+                        results[m] = \
+                            result['choices'][0]['message']['content']
+                    else:
+                        self.logger.error(f'Receive invalid result. '
+                                          f'result={result}; '
+                                          f'request={inputs[m]}')
             else:
-                results.append('')
+                self.logger.error(f'Receive invalid response. '
+                                  f'response={resp}; '
+                                  f'request={inputs[m]}')
         self.flush()
         return results
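The switch from `results.append(...)` to a pre-sized list indexed by `m` is what fixes the ordering issue from the commit message: `concurrent.futures.as_completed` yields futures in completion order, not submission order, so appending scrambles the mapping between inputs and outputs. A standalone sketch of the pattern (the helper name and executor settings are illustrative, not code from this commit):

```python
import concurrent.futures


def ordered_results(inputs, infer):
    """Collect one result per input, in input order, even though futures
    complete in arbitrary order."""
    results = [''] * len(inputs)  # pre-size: slot i belongs to inputs[i]
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        future_to_index = {
            executor.submit(infer, item): i for i, item in enumerate(inputs)
        }
        for future in concurrent.futures.as_completed(future_to_index):
            i = future_to_index[future]
            try:
                results[i] = future.result()  # write back to the matching slot
            except Exception:
                results[i] = ''  # keep the placeholder on failure
    return results


# completion order may be 'ccc', 'a', 'bb'; output order is still a, bb, ccc
print(ordered_results(['a', 'bb', 'ccc'], lambda s: s.upper()))
```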
@@ -184,39 +192,31 @@
             message['role'] = item['role']
             messages.append(message)
         request = {
-            'model':
-            self._model,
-            'messages':
-            messages,
-            'max_seq_len':
-            max(
-                max_out_len if max_out_len else 4096,
-                self.max_seq_len if self.max_seq_len else 4096,
-            ),
+            'model': self._model,
+            'messages': messages,
+            'max_tokens': max_out_len,
         }
         request.update(self.generation_kwargs)
-        try:
-            retry_num = 0
-            while retry_num < self.retry:
+        retry_num = 0
+        while retry_num < self.retry:
+            try:
                 response = self._infer_result(request, sess)
-                if response.status_code == 200:
-                    break  # success
-                elif response.status_code == 426:
-                    retry_num += 1  # retry
-                elif response.status_code in [429, 500, 504]:
-                    time.sleep(BAILING_RETRY_DELAY)
-                    retry_num += 1  # retry
-                else:
-                    raise ValueError(f'Status code = {response.status_code}')
+            except ConnectionError:
+                time.sleep(random.randint(10, 30))
+                retry_num += 1  # retry
+                continue
+            if response.status_code == 200:
+                break  # success
+            elif response.status_code == 426:
+                retry_num += 1  # retry
+            elif response.status_code in [302, 429, 500, 504]:
+                time.sleep(random.randint(10, 30))
+                retry_num += 1  # retry
             else:
-                raise ValueError(
-                    f'Exceed the maximal retry times. Last status code '
-                    f'= {response.status_code}')
-        except Exception as e:
-            self.logger.error(f'Fail to inference request={request}; '
-                              f'model_name={self.path}; error={e}, '
-                              f'stack:{traceback.format_exc()}')
-            raise e
+                raise ValueError(f'Status code = {response.status_code}')
+        else:
+            # Exceed the maximal retry times.
+            return ''
         return response

     # @retry(stop_max_attempt_number=3, wait_fixed=16000)  # ms
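Taken together, the retry changes drop the fixed `BAILING_RETRY_DELAY` and the outer try/except in favour of a jittered 10-30 s delay, an extra 302 retry code, and per-request `ConnectionError` handling. A self-contained sketch of the same policy against a plain `requests.Session` (the function name and `max_retry` default are illustrative, not from this commit):

```python
import random
import time

import requests
from requests.exceptions import ConnectionError


def post_with_retry(session: requests.Session, url: str, payload: dict,
                    max_retry: int = 3):
    """Retry policy sketch: retry immediately on 426, back off with a random
    10-30 s delay on 302/429/500/504 or connection errors, fail fast otherwise."""
    retry_num = 0
    while retry_num < max_retry:
        try:
            response = session.post(url, json=payload)
        except ConnectionError:
            time.sleep(random.randint(10, 30))  # jittered delay, not a fixed constant
            retry_num += 1
            continue
        if response.status_code == 200:
            return response
        elif response.status_code == 426:
            retry_num += 1
        elif response.status_code in (302, 429, 500, 504):
            time.sleep(random.randint(10, 30))
            retry_num += 1
        else:
            raise ValueError(f'Status code = {response.status_code}')
    return None  # retries exhausted; callers treat this as an empty result
```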