mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* [fix] sequence under the multiple samples * resolve the lint problems * change the parameter name * add another error code for retry * output the log for invalid response * format correction * update * update * update * update * add two model python files * update the default parameter * use random for delay * update the api example of bailing * remove the unnecessary parameter
232 lines
8.1 KiB
Python
232 lines
8.1 KiB
Python
import concurrent
|
|
import concurrent.futures
|
|
import os
|
|
import random
|
|
import socket
|
|
import time
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
import requests
|
|
from requests.adapters import HTTPAdapter
|
|
from requests.exceptions import ConnectionError
|
|
from urllib3.connection import HTTPConnection
|
|
|
|
try:
|
|
from retrying import retry
|
|
except ImportError:
|
|
retry = None
|
|
|
|
from opencompass.utils.prompt import PromptList
|
|
|
|
from .base_api import BaseAPIModel
|
|
|
|
# A single prompt is either an OpenCompass ``PromptList`` or a plain string.
PromptType = Union[PromptList, str]
|
|
|
|
|
|
class HTTPAdapterWithSocketOptions(HTTPAdapter):
    """``requests`` transport adapter that turns on TCP keep-alive.

    On top of urllib3's default socket options, every pooled connection
    gets ``SO_KEEPALIVE`` enabled and, where the platform exposes the
    corresponding constants, ``TCP_KEEPIDLE=75``, ``TCP_KEEPINTVL=30`` and
    ``TCP_KEEPCNT=120``.

    Args:
        *args: Forwarded unchanged to ``HTTPAdapter.__init__``.
        **kwargs: Forwarded unchanged to ``HTTPAdapter.__init__``.
    """

    # Fine-tuning options are not defined by the ``socket`` module on every
    # platform, so they are looked up by name and skipped when missing.
    _KEEPALIVE_TUNING = (
        ('TCP_KEEPIDLE', 75),
        ('TCP_KEEPINTVL', 30),
        ('TCP_KEEPCNT', 120),
    )

    def __init__(self, *args, **kwargs):
        tcp_level = getattr(socket, 'SOL_TCP', socket.IPPROTO_TCP)
        keepalive_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
        for option_name, value in self._KEEPALIVE_TUNING:
            option = getattr(socket, option_name, None)
            if option is not None:
                keepalive_options.append((tcp_level, option, value))

        # The options must exist before the parent constructor runs,
        # because ``HTTPAdapter.__init__`` calls ``init_poolmanager``.
        self._socket_options = (HTTPConnection.default_socket_options +
                                keepalive_options)
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Create the pool manager with the keep-alive socket options.

        Any ``socket_options`` supplied by the caller is replaced by the
        list prepared in ``__init__``; all other arguments are forwarded
        unchanged to ``HTTPAdapter.init_poolmanager``.
        """
        kwargs['socket_options'] = self._socket_options
        super().init_poolmanager(*args, **kwargs)
|
|
|
|
|
|
class BailingAPI(BaseAPIModel):
    """Model wrapper around Bailing Service.

    Prompts are sent as chat-completion requests through a pool of
    keep-alive ``requests`` sessions. The pool size, which is also the
    number of worker threads used by :meth:`generate`, is read from the
    ``BAILING_API_PARALLEL_NUM`` environment variable and defaults to 1.

    Args:
        path (str): Model name; sent as the ``model`` field of each request.
        token (str): API token used as the bearer credential. When empty,
            the ``BAILING_API_KEY`` environment variable is used instead.
        url (str): Endpoint URL. When empty,
            ``https://bailingchat.alipay.com/chat/completions`` is used.
        meta_template (Dict, optional): Forwarded to ``BaseAPIModel``.
            Defaults to None.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Maximum number of attempts per request before giving
            up. Defaults to 3.
        generation_kwargs (Dict, optional): Extra fields merged into every
            request payload. Defaults to None, meaning no extra fields.
        max_seq_len (int): Forwarded to ``BaseAPIModel``. Defaults to 4096.

    Raises:
        RuntimeError: If no token is given and ``BAILING_API_KEY`` is unset.
        ValueError: If ``BAILING_API_PARALLEL_NUM`` is not a positive
            integer.
    """

    # OpenCompass role names mapped to the role names sent to the service.
    # Unknown roles are passed through unchanged.
    _ROLE_MAP = {
        'HUMAN': 'user',
        'BOT': 'assistant',
        'SYSTEM': 'system',
    }

    # Status codes that are retried after a random back-off.
    _BACKOFF_STATUS_CODES = (302, 429, 500, 504)

    def __init__(
        self,
        path: str,
        token: str,
        url: str,
        meta_template: Optional[Dict] = None,
        query_per_second: int = 1,
        retry: int = 3,
        generation_kwargs: Optional[Dict] = None,
        max_seq_len: int = 4096,
    ):
        # A fresh dict per instance avoids sharing one mutable default
        # between every model created without explicit generation_kwargs.
        if generation_kwargs is None:
            generation_kwargs = {}
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
            generation_kwargs=generation_kwargs,
        )

        self.logger.info(f'Bailing API Model Init path: {path} url={url}')
        token = token or os.environ.get('BAILING_API_KEY')
        if not token:
            raise RuntimeError('There is not valid token.')
        self._headers = {
            'Authorization': f'Bearer {token}',
            'Content-Type': 'application/json',
        }
        self._url = (url if url else
                     'https://bailingchat.alipay.com/chat/completions')
        self._model = path

        parallel_num = os.environ.get('BAILING_API_PARALLEL_NUM')
        self._num = int(parallel_num) if parallel_num else 1
        if self._num < 1:
            # Fail here with a clear message instead of later inside
            # ``generate`` when the thread pool is created.
            raise ValueError(
                'BAILING_API_PARALLEL_NUM must be a positive integer, '
                f'got {parallel_num!r}.')

        self._sessions = []
        try:
            for _ in range(self._num):
                adapter = HTTPAdapterWithSocketOptions()
                sess = requests.Session()
                sess.mount('http://', adapter)
                sess.mount('https://', adapter)
                self._sessions.append(sess)
        except Exception as e:
            self.logger.error(f'Fail to setup the session. {e}')
            raise

    def generate(
        self,
        inputs: Union[List[str], PromptList],
        max_out_len: int = 11264,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Requests are dispatched concurrently; each result is stored at the
        index of its input, so the output order matches ``inputs``.

        Args:
            inputs (Union[List[str], PromptList]):
                A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings. An entry is ``''`` when
            its request failed or the reply could not be interpreted; the
            failure is logged.

        Raises:
            ValueError: Propagated from ``_generate`` when the service
                answers with a status code that is not retried.
        """
        results = [''] * len(inputs)
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self._num) as executor:
            future_to_index = {
                executor.submit(
                    self._generate,
                    self._sessions[index % self._num],
                    prompt,
                    max_out_len,
                ): index
                for index, prompt in enumerate(inputs)
            }
            for future in concurrent.futures.as_completed(future_to_index):
                index = future_to_index[future]
                resp = future.result()
                if not (resp and resp.status_code == 200):
                    self.logger.error(f'Receive invalid response. '
                                      f'response={resp}; '
                                      f'request={inputs[index]}')
                    continue
                try:
                    result = resp.json()
                except Exception as e:
                    self.logger.error(f'Fail to inference; '
                                      f'model_name={self.path}; '
                                      f'error={e}, '
                                      f'request={inputs[index]}')
                    continue
                content = self._extract_content(result)
                if content is None:
                    self.logger.error(f'Receive invalid result. '
                                      f'result={result}; '
                                      f'request={inputs[index]}')
                else:
                    results[index] = content
        self.flush()
        return results

    @staticmethod
    def _extract_content(result) -> Optional[str]:
        """Return ``choices[0].message.content`` of a decoded reply.

        Returns None when the payload does not have that shape or the
        content itself is None. An empty string is a valid content.
        """
        try:
            return result['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            return None

    def _generate(
        self,
        sess,
        input: Union[str, PromptList],
        max_out_len: int,
    ) -> Union[requests.Response, str]:
        """Send one prompt to the service, retrying on transient failures.

        Args:
            sess: The ``requests`` session used to send the request.
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output; sent as
                ``max_tokens``.

        Returns:
            The HTTP response with status code 200, or ``''`` when all
            ``self.retry`` attempts failed.

        Raises:
            ValueError: If the service answers with a status code other
                than 200, 426, 302, 429, 500 or 504.
        """
        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                content = item['prompt']
                if not content:
                    # Turns with an empty prompt are not sent.
                    continue
                role = item['role']
                messages.append({
                    'content': content,
                    'role': self._ROLE_MAP.get(role, role),
                })

        request = {
            'model': self._model,
            'messages': messages,
            'max_tokens': max_out_len,
        }
        # User-supplied generation options may override the fields above.
        request.update(self.generation_kwargs)

        last_attempt = self.retry - 1
        for attempt in range(self.retry):
            try:
                response = self._infer_result(request, sess)
            except (ConnectionError, requests.exceptions.Timeout):
                # Read timeouts are not ConnectionError subclasses; treat
                # them as transient too instead of aborting the batch.
                needs_backoff = True
            else:
                status_code = response.status_code
                if status_code == 200:
                    return response
                if status_code == 426:
                    # Retried immediately, without a back-off.
                    needs_backoff = False
                elif status_code in self._BACKOFF_STATUS_CODES:
                    needs_backoff = True
                else:
                    raise ValueError(f'Status code = {status_code}')
            # Waiting is pointless when no further attempt will follow.
            if needs_backoff and attempt < last_attempt:
                time.sleep(random.randint(10, 30))

        # Exceed the maximal retry times.
        return ''

    def _infer_result(self, request, sess):
        """POST ``request`` as JSON to the service and return the response.

        Retrying is handled by the caller (``_generate``); this method
        performs a single HTTP call with a 500-second timeout.
        """
        return sess.request(
            'POST',
            self._url,
            json=request,
            headers=self._headers,
            timeout=500,
        )
|