mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
216 lines
7.3 KiB
Python
216 lines
7.3 KiB
Python
![]() |
import concurrent
|
||
|
import concurrent.futures
|
||
|
import os
|
||
|
import socket
|
||
|
import traceback
|
||
|
from typing import Dict, List, Optional, Union
|
||
|
|
||
|
import requests
|
||
|
from requests.adapters import HTTPAdapter
|
||
|
from retrying import retry
|
||
|
from urllib3.connection import HTTPConnection
|
||
|
|
||
|
from opencompass.utils.prompt import PromptList
|
||
|
|
||
|
from .base_api import BaseAPIModel
|
||
|
|
||
|
# A prompt is either a plain string or an OpenCompass ``PromptList``.
PromptType = Union[PromptList, str]
|
||
|
|
||
|
|
||
|
class HTTPAdapterWithSocketOptions(HTTPAdapter):
    """``requests`` transport adapter that enables TCP keep-alive.

    Bailing requests can run for a long time (the client uses a 500 s
    timeout); keep-alive probes prevent idle connections from being
    silently dropped in between.

    NOTE(review): ``TCP_KEEPIDLE`` / ``TCP_KEEPINTVL`` / ``TCP_KEEPCNT``
    are platform-specific (Linux) socket constants — confirm the target
    platform before relying on them.
    """

    def __init__(self, *args, **kwargs):
        # Must be assigned *before* super().__init__(): HTTPAdapter's
        # constructor calls init_poolmanager(), which reads this attribute.
        self._socket_options = HTTPConnection.default_socket_options + [
            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
            (socket.SOL_TCP, socket.TCP_KEEPIDLE, 75),
            (socket.SOL_TCP, socket.TCP_KEEPINTVL, 30),
            (socket.SOL_TCP, socket.TCP_KEEPCNT, 120),
        ]
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        # Forward the keep-alive options to urllib3's connection pool.
        if self._socket_options is not None:
            kwargs["socket_options"] = self._socket_options
        super().init_poolmanager(*args, **kwargs)
|
||
|
|
||
|
|
||
|
class BailingAPI(BaseAPIModel):
    """Model wrapper around the Bailing chat-completion service.

    Args:
        path (str): Model name sent to the service.
        token (str): Bearer token. Falls back to the ``BAILING_API_KEY``
            environment variable when empty.
        url (str): Service endpoint. Falls back to the public Bailing
            endpoint when empty.
        meta_template (Dict, optional): Meta prompt template. Defaults to
            None.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 3.
        generation_kwargs (Dict, optional): Extra parameters merged into
            every request body. Defaults to None (empty).
        max_seq_len (int): Maximum sequence length. Defaults to 4096.
    """

    def __init__(
        self,
        path: str,
        token: str,
        url: str,
        meta_template: Optional[Dict] = None,
        query_per_second: int = 1,
        retry: int = 3,
        generation_kwargs: Optional[Dict] = None,
        max_seq_len: int = 4096,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
            # ``or {}`` avoids the shared-mutable-default pitfall while
            # behaving identically for callers that pass their own dict.
            generation_kwargs=generation_kwargs or {},
        )

        self.logger.info(f"Bailing API Model Init path: {path} url={url}")
        if not token:
            token = os.environ.get("BAILING_API_KEY")
        if not token:
            raise RuntimeError("There is no valid token.")
        self._headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        self._url = url if url else "https://bailingchat.alipay.com/chat/completions"
        self._model = path
        self._sessions = []
        # Number of parallel sessions/workers; overridable via env var.
        self._num = int(os.environ.get("BAILING_API_PARALLEL_NUM") or 1)
        try:
            for _ in range(self._num):
                adapter = HTTPAdapterWithSocketOptions()
                sess = requests.Session()
                sess.mount("http://", adapter)
                sess.mount("https://", adapter)
                self._sessions.append(sess)
        except Exception as e:
            self.logger.error(f"Fail to setup the session. {e}")
            raise

    def generate(
        self,
        inputs: Union[List[str], PromptList],
        max_out_len: int = 4096,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (Union[List[str], PromptList]): A list of strings or
                PromptDicts. The PromptDict should be organized in
                OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: Generated strings, one per input, in input order.
                An empty string marks a failed or content-less response.
        """
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._num,
        ) as executor:
            # Map each future back to its input index so results can be
            # restored to input order (as_completed yields in completion
            # order, not submission order).
            future_to_m = {
                executor.submit(
                    self._generate,
                    self._sessions[i % self._num],
                    input,
                    max_out_len,
                ): i
                for i, input in enumerate(inputs)
            }
            # Pre-fill with "" so every input gets exactly one slot even
            # when its request fails or returns no content.
            results = [""] * len(future_to_m)
            for future in concurrent.futures.as_completed(future_to_m):
                m = future_to_m[future]
                resp = future.result()
                if resp and resp.status_code == 200:
                    try:
                        result = resp.json()
                    except Exception:
                        continue  # keep the "" placeholder
                    if (
                        result.get("choices")
                        and result["choices"][0].get("message")
                        and result["choices"][0]["message"].get("content")
                    ):
                        results[m] = result["choices"][0]["message"]["content"]
        self.flush()
        return results

    def _generate(
        self,
        sess,
        input: Union[str, PromptList],
        max_out_len: int,
    ) -> requests.Response:
        """Issue one chat-completion request and return the raw response.

        Args:
            sess: The ``requests.Session`` to issue the request on.
            input (str or PromptList): A string or PromptDict. The
                PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            requests.Response: The successful (HTTP 200) response.

        Raises:
            ValueError: On a non-retryable status code, or when the retry
                budget for status 426 is exhausted.
        """
        if isinstance(input, str):
            messages = [{"role": "user", "content": input}]
        else:
            messages = []
            for item in input:
                content = item["prompt"]
                if not content:
                    continue
                message = {"content": content}
                if item["role"] == "HUMAN":
                    message["role"] = "user"
                elif item["role"] == "BOT":
                    message["role"] = "assistant"
                elif item["role"] == "SYSTEM":
                    message["role"] = "system"
                else:
                    # Pass unknown roles through unchanged.
                    message["role"] = item["role"]
                messages.append(message)
        request = {
            "model": self._model,
            "messages": messages,
            "max_seq_len": max(
                max_out_len if max_out_len else 4096,
                self.max_seq_len if self.max_seq_len else 4096,
            ),
        }
        request.update(self.generation_kwargs)
        try:
            retry_num = 0
            while retry_num < self.retry:
                response = self._infer_result(request, sess)
                if response.status_code == 200:
                    break  # success
                elif response.status_code == 426:
                    retry_num += 1  # retry
                else:
                    raise ValueError(f"Status code = {response.status_code}")
            else:
                # while-loop exhausted without break: out of retries.
                raise ValueError(
                    f"Exceed the maximal retry times. Last status code "
                    f"= {response.status_code}"
                )
        except Exception as e:
            self.logger.error(
                f"Fail to inference request={request}; "
                f"model_name={self.path}; error={e}, "
                f"stack:{traceback.format_exc()}"
            )
            raise
        return response

    # retrying adds another transport-level retry layer (3 attempts,
    # 16 s apart) on top of the status-code loop in _generate.
    @retry(stop_max_attempt_number=3, wait_fixed=16000)  # ms
    def _infer_result(self, request, sess):
        """POST the request body and return the raw HTTP response."""
        response = sess.request(
            "POST",
            self._url,
            json=request,
            headers=self._headers,
            timeout=500,
        )
        return response
|