OpenCompass/opencompass/models/bailing_api_oc.py
Yi Ding bcb707dbfc
[Fix] Fix BailingAPI model (#1707)
* [fix] sequence under the multiple samples

* resolve the lint problems

* change the parameter name

* add another error code for retry

* output the log for invalid response

* format correction

* update

* update

* update

* update

* add two model python files

* update the default parameter

* use random for delay

* update the api example of bailing

* remove the unnecessary parameter
2024-11-26 19:24:47 +08:00

232 lines
8.1 KiB
Python

import concurrent
import concurrent.futures
import os
import random
import socket
import time
from typing import Dict, List, Optional, Union
import requests
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from urllib3.connection import HTTPConnection
try:
from retrying import retry
except ImportError:
retry = None
from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel
# A single prompt: either an OpenCompass ``PromptList`` or a plain string.
PromptType = Union[PromptList, str]
class HTTPAdapterWithSocketOptions(HTTPAdapter):
    """HTTP adapter that turns on TCP keep-alive for its connections.

    The keep-alive options are appended to urllib3's default socket options
    and passed to the pool manager as ``socket_options``.
    """

    # ``(socket constant name, value)`` pairs for keep-alive tuning.  The
    # constants are resolved by name at construction time because not every
    # platform's ``socket`` module defines all of them; referencing a missing
    # one directly would raise ``AttributeError``.
    _KEEPALIVE_TUNING = (
        ('TCP_KEEPIDLE', 75),
        ('TCP_KEEPINTVL', 30),
        ('TCP_KEEPCNT', 120),
    )

    def __init__(self, *args, **kwargs):
        """Build the socket option list, then initialise the base adapter.

        All positional and keyword arguments are forwarded unchanged to
        ``HTTPAdapter.__init__``.
        """
        options = list(HTTPConnection.default_socket_options)
        options.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
        for name, value in self._KEEPALIVE_TUNING:
            option = getattr(socket, name, None)
            if option is not None:
                options.append((socket.SOL_TCP, option, value))
        # Assigned before the parent constructor runs, because
        # ``init_poolmanager`` below reads it (requests' ``HTTPAdapter``
        # creates its pool manager during construction).
        self._socket_options = options
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Create the pool manager with the keep-alive socket options."""
        if self._socket_options is not None:
            kwargs['socket_options'] = self._socket_options
        super().init_poolmanager(*args, **kwargs)
class BailingAPI(BaseAPIModel):
    """Model wrapper around Bailing Service.

    Args:
        path (str): Model name, sent as the ``model`` field of each request.
        token (str): API token used for the ``Authorization`` header. When
            empty, the ``BAILING_API_KEY`` environment variable is used.
        url (str): Endpoint to POST to. When empty,
            ``https://bailingchat.alipay.com/chat/completions`` is used.
        meta_template (Dict, optional): Forwarded to the base class.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Maximum number of attempts for one request.
            Defaults to 3.
        generation_kwargs (Dict, optional): Forwarded to the base class;
            ``self.generation_kwargs`` is merged into every request payload.
            Defaults to an empty dict.
        max_seq_len (int): Forwarded to the base class. Defaults to 4096.

    The number of sessions / worker threads is read from the
    ``BAILING_API_PARALLEL_NUM`` environment variable (1 when unset or
    empty).

    Raises:
        RuntimeError: If no token is given and ``BAILING_API_KEY`` is unset.
    """

    def __init__(
        self,
        path: str,
        token: str,
        url: str,
        meta_template: Optional[Dict] = None,
        query_per_second: int = 1,
        retry: int = 3,
        generation_kwargs: Optional[Dict] = None,
        max_seq_len=4096,
    ):
        # Use a fresh dict per instance: a ``{}`` default in the signature
        # would be one object shared by every instance built without it.
        if generation_kwargs is None:
            generation_kwargs = {}
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
            generation_kwargs=generation_kwargs,
        )

        self.logger.info(f'Bailing API Model Init path: {path} url={url}')

        # An explicit token wins; otherwise fall back to the environment.
        token = token or os.environ.get('BAILING_API_KEY')
        if not token:
            raise RuntimeError('There is not valid token.')
        self._headers = {
            'Authorization': f'Bearer {token}',
            'Content-Type': 'application/json',
        }

        self._url = (url if url else
                     'https://bailingchat.alipay.com/chat/completions')
        self._model = path

        parallel_num = os.environ.get('BAILING_API_PARALLEL_NUM')
        self._num = int(parallel_num) if parallel_num else 1

        # One session per worker slot; ``generate`` assigns input ``i`` to
        # session ``i % self._num``.
        self._sessions = []
        try:
            for _ in range(self._num):
                adapter = HTTPAdapterWithSocketOptions()
                sess = requests.Session()
                sess.mount('http://', adapter)
                sess.mount('https://', adapter)
                self._sessions.append(sess)
        except Exception as e:
            self.logger.error(f'Fail to setup the session. {e}')
            raise  # bare raise keeps the original traceback

    def generate(
        self,
        inputs: Union[List[str], PromptList],
        max_out_len: int = 11264,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Requests are issued from a thread pool of ``self._num`` workers and
        the answers are written back at the index of their input, so the
        output order matches ``inputs`` regardless of completion order.

        Args:
            inputs (Union[List[str], PromptList]):
                A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings. An entry stays ``''``
            when its request failed or its response could not be used; the
            reason is logged.

        Raises:
            Exception: Whatever ``_generate`` raised in a worker (e.g.
                ``ValueError`` for an unexpected status code) is re-raised
                here by ``future.result()``.
        """
        results = [''] * len(inputs)
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self._num) as executor:
            future_to_index = {
                executor.submit(
                    self._generate,
                    self._sessions[i % self._num],
                    prompt,
                    max_out_len,
                ): i
                for i, prompt in enumerate(inputs)
            }
            for future in concurrent.futures.as_completed(future_to_index):
                index = future_to_index[future]
                resp = future.result()

                # ``resp`` is '' when every attempt failed.
                if not resp or resp.status_code != 200:
                    self.logger.error(f'Receive invalid response. '
                                      f'response={resp}; '
                                      f'request={inputs[index]}')
                    continue

                try:
                    result = resp.json()
                except Exception as e:
                    self.logger.error(f'Fail to inference; '
                                      f'model_name={self.path}; '
                                      f'error={e}, '
                                      f'request={inputs[index]}')
                    continue

                content = self._extract_content(result)
                if content is None:
                    self.logger.error(f'Receive invalid result. '
                                      f'result={result}; '
                                      f'request={inputs[index]}')
                else:
                    results[index] = content
        self.flush()
        return results

    @staticmethod
    def _extract_content(result) -> Optional[str]:
        """Return ``choices[0].message.content`` from a decoded response.

        Returns ``None`` when the payload is not a dict, or when
        ``choices``, the first choice's ``message`` or its ``content`` is
        missing / empty / ``None`` respectively.
        """
        if not isinstance(result, dict):
            return None
        choices = result.get('choices')
        if not choices:
            return None
        message = choices[0].get('message')
        if not message:
            return None
        return message.get('content')

    def _generate(
        self,
        sess,
        input: Union[str, PromptList],
        max_out_len: int,
    ) -> Union['requests.Response', str]:
        """Send one prompt to the service, retrying transient failures.

        Args:
            sess: The session used to send the request.
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): Sent as ``max_tokens``.

        Returns:
            The response object once a request returns HTTP 200, or ``''``
            when ``self.retry`` attempts were used up without success.

        Raises:
            ValueError: If the service answers with a status code other
                than 200, 302, 426, 429, 500 or 504.
        """
        role_map = {'HUMAN': 'user', 'BOT': 'assistant', 'SYSTEM': 'system'}
        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                content = item['prompt']
                if not content:
                    continue  # empty turns are not sent
                # Unknown roles are passed through unchanged.
                role = role_map.get(item['role'], item['role'])
                messages.append({'content': content, 'role': role})

        request = {
            'model': self._model,
            'messages': messages,
            'max_tokens': max_out_len,
        }
        # Caller-supplied generation options may override the fields above.
        request.update(self.generation_kwargs)

        for _ in range(self.retry):
            try:
                response = self._infer_result(request, sess)
            except (ConnectionError, requests.exceptions.Timeout):
                # A read timeout is as transient as a dropped connection;
                # letting it escape would abort the whole batch in
                # ``generate``.
                time.sleep(random.randint(10, 30))
                continue

            status = response.status_code
            if status == 200:
                return response
            if status == 426:
                continue  # retry immediately
            if status in (302, 429, 500, 504):
                # Randomised delay before the next attempt.
                time.sleep(random.randint(10, 30))
                continue
            raise ValueError(f'Status code = {status}')

        # Exceed the maximal retry times.
        return ''

    def _infer_result(self, request, sess):
        """POST ``request`` as JSON to the configured URL.

        Args:
            request (dict): The JSON payload.
            sess: The session used to send the request.

        Returns:
            The response object returned by the session.
        """
        return sess.post(
            self._url,
            json=request,
            headers=self._headers,
            timeout=500,
        )