OpenCompass/opencompass/models/bailing_api_oc.py


import concurrent
import concurrent.futures
import os
import socket
import traceback
from typing import Dict, List, Optional, Union
import requests
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPConnection
try:
from retrying import retry
except ImportError:
retry = None
from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
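# The adapter below enables TCP keepalive on every pooled connection
# (start probing after 75s idle, probe every 30s, up to 120 probes) so that
# long-running requests to the service are not dropped by idle timeouts.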
class HTTPAdapterWithSocketOptions(HTTPAdapter):
def __init__(self, *args, **kwargs):
self._socket_options = HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
(socket.SOL_TCP, socket.TCP_KEEPIDLE, 75),
(socket.SOL_TCP, socket.TCP_KEEPINTVL, 30),
(socket.SOL_TCP, socket.TCP_KEEPCNT, 120),
]
super(HTTPAdapterWithSocketOptions, self).__init__(*args, **kwargs)
def init_poolmanager(self, *args, **kwargs):
if self._socket_options is not None:
kwargs['socket_options'] = self._socket_options
super(HTTPAdapterWithSocketOptions,
self).init_poolmanager(*args, **kwargs)
class BailingAPI(BaseAPIModel):
"""Model wrapper around Bailing Service.
Args:
        path (str): Model name sent in the request payload.
        token (str): Bearer token; falls back to the BAILING_API_KEY
            environment variable when empty.
        url (str): Service endpoint; defaults to the Bailing chat
            completions endpoint when empty.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        generation_kwargs: Extra parameters merged into every request.
        retry (int): Number of retries if the API call fails. Defaults to 3.
"""
def __init__(
self,
path: str,
token: str,
url: str,
meta_template: Optional[Dict] = None,
query_per_second: int = 1,
retry: int = 3,
generation_kwargs: Dict = {},
max_seq_len=4096,
):
super().__init__(
path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry,
generation_kwargs=generation_kwargs,
)
self.logger.info(f'Bailing API Model Init path: {path} url={url}')
if not token:
token = os.environ.get('BAILING_API_KEY')
if token:
self._headers = {'Authorization': f'Bearer {token}'}
else:
                raise RuntimeError('There is no valid token.')
else:
self._headers = {'Authorization': f'Bearer {token}'}
self._headers['Content-Type'] = 'application/json'
self._url = (url if url else
'https://bailingchat.alipay.com/chat/completions')
self._model = path
self._sessions = []
self._num = (int(os.environ.get('BAILING_API_PARALLEL_NUM'))
if os.environ.get('BAILING_API_PARALLEL_NUM') else 1)
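        # Build one keepalive-enabled session per worker thread; generate()
        # later picks sessions round-robin by input index.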
try:
for _ in range(self._num):
adapter = HTTPAdapterWithSocketOptions()
sess = requests.Session()
sess.mount('http://', adapter)
sess.mount('https://', adapter)
self._sessions.append(sess)
except Exception as e:
self.logger.error(f'Fail to setup the session. {e}')
raise e
def generate(
self,
inputs: Union[List[str], PromptList],
max_out_len: int = 4096,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (Union[List[str], PromptList]):
A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass' API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
with concurrent.futures.ThreadPoolExecutor(
max_workers=self._num, ) as executor:
future_to_m = {
executor.submit(
self._generate,
self._sessions[i % self._num],
input,
max_out_len,
): i
for i, input in enumerate(inputs)
}
            # ``as_completed`` yields futures in completion order, so write
            # each result back by its submitted index to keep the output
            # aligned with ``inputs``.
            results = [''] * len(inputs)
            for future in concurrent.futures.as_completed(future_to_m):
                m = future_to_m[future]
                resp = future.result()
                if resp and resp.status_code == 200:
                    try:
                        result = resp.json()
                    except Exception:
                        continue
                    if (result.get('choices')
                            and result['choices'][0].get('message') and
                            result['choices'][0]['message'].get('content')
                            is not None):
                        results[m] = (
                            result['choices'][0]['message']['content'])
self.flush()
return results
def _generate(
self,
sess,
input: Union[str, PromptList],
max_out_len: int,
    ) -> requests.Response:
"""Generate results given an input.
Args:
inputs (str or PromptList): A string or PromptDict.
The PromptDict should be organized in OpenCompass' API format.
max_out_len (int): The maximum length of the output.
Returns:
str: The generated string.
"""
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
else:
messages = []
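            # Map OpenCompass roles onto OpenAI-style chat roles:
            # HUMAN -> user, BOT -> assistant, SYSTEM -> system.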
for item in input:
content = item['prompt']
if not content:
continue
message = {'content': content}
if item['role'] == 'HUMAN':
message['role'] = 'user'
elif item['role'] == 'BOT':
message['role'] = 'assistant'
elif item['role'] == 'SYSTEM':
message['role'] = 'system'
else:
message['role'] = item['role']
messages.append(message)
        request = {
            'model': self._model,
            'messages': messages,
            'max_seq_len': max(
                max_out_len if max_out_len else 4096,
                self.max_seq_len if self.max_seq_len else 4096,
            ),
        }
request.update(self.generation_kwargs)
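        # Only HTTP 426 is treated as retryable; any other non-200 status
        # aborts immediately.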
try:
retry_num = 0
while retry_num < self.retry:
response = self._infer_result(request, sess)
if response.status_code == 200:
break # success
elif response.status_code == 426:
retry_num += 1 # retry
else:
raise ValueError(f'Status code = {response.status_code}')
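            # ``while ... else``: this branch runs only when the loop ends
            # without ``break``, i.e. every retry returned a retryable error.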
            else:
                raise ValueError(
                    f'Exceeded the maximum retry count. Last status code '
                    f'= {response.status_code}')
except Exception as e:
self.logger.error(f'Fail to inference request={request}; '
f'model_name={self.path}; error={e}, '
f'stack:{traceback.format_exc()}')
raise e
return response
# @retry(stop_max_attempt_number=3, wait_fixed=16000) # ms
def _infer_result(self, request, sess):
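        # requests interprets ``timeout`` in seconds, so each connect/read
        # may block for up to 500s before a timeout error is raised.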
response = sess.request(
'POST',
self._url,
json=request,
headers=self._headers,
timeout=500,
)
return response
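# A minimal smoke test, kept behind ``__main__`` so importing the module is
# unaffected. It assumes a reachable Bailing-compatible endpoint and a valid
# key exported as BAILING_API_KEY; the model name below is a placeholder.
if __name__ == '__main__':
    bailing = BailingAPI(
        path='Bailing-Lite',  # placeholder model name, not a confirmed deployment
        token='',  # empty -> falls back to the BAILING_API_KEY env variable
        url='',  # empty -> falls back to the default endpoint defined above
    )
    print(bailing.generate(['Hello! Please introduce yourself.'],
                           max_out_len=128))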