Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support AntFinix LLM
parent 97010dc4ce
commit 50bcffc4f7
examples/eval_antfinix_api.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import \
        gsm8k_datasets
    from opencompass.configs.datasets.demo.demo_math_chat_gen import \
        math_datasets
    from opencompass.configs.models.antfinix_api.antfinix_20250418 import \
        models as antfinix

datasets = gsm8k_datasets + math_datasets
models = antfinix
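
For reference, a minimal sketch (an illustration, not part of this commit) of sanity-checking the example config above. It assumes opencompass is installed and the script is run from the repository root, and it relies on mmengine's Config loader, the same mechanism read_base() builds on:

from mmengine.config import Config

# Load examples/eval_antfinix_api.py and report what it wires together:
# the two demo datasets plus the AntFinix model entry.
cfg = Config.fromfile('examples/eval_antfinix_api.py')
print(f'{len(cfg.datasets)} dataset configs, {len(cfg.models)} model config(s)')
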
opencompass/configs/models/antfinix_api/antfinix_20250418.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from opencompass.models import AntFinixAPI

models = [
    dict(
        path='035A54D2-9A48-021A-8ED7-C6758F3344AF',
        key='',  # set your key here or in environment variable ANTFINIX_API_KEY
        url='https://fin-evaluator-gw.antgroup.com/api/v1/finEvaluator/evaluate',
        type=AntFinixAPI,
        max_out_len=32 * 1024,
        batch_size=1,
        generation_kwargs={
            'temperature': 1.0,
            'logprobs': 0,
            'top_p': 1.0,
            'top_k': -1,
            'n': 1,
        },
    ),
]
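
As an aside, a minimal variant of the config above (an illustrative sketch, not part of this commit) that fills in the key from the ANTFINIX_API_KEY environment variable mentioned in the inline comment, instead of hard-coding it:

import os

from opencompass.models import AntFinixAPI

models = [
    dict(
        path='035A54D2-9A48-021A-8ED7-C6758F3344AF',
        # Read the key from the environment; fall back to an empty string.
        key=os.environ.get('ANTFINIX_API_KEY', ''),
        url='https://fin-evaluator-gw.antgroup.com/api/v1/finEvaluator/evaluate',
        type=AntFinixAPI,
        max_out_len=32 * 1024,
        batch_size=1,
    ),
]
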
opencompass/models/__init__.py
@@ -49,3 +49,4 @@ from .yayi_api import Yayi  # noqa: F401
 from .yi_api import YiAPI  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
 from .zhipuai_v2_api import ZhiPuV2AI  # noqa: F401
+from .antfinix_api import AntFinixAPI  # noqa: F401
opencompass/models/antfinix_api.py (new file, 240 lines)
@@ -0,0 +1,240 @@
import base64
import concurrent
import concurrent.futures
import hashlib
import hmac
import json  # used to serialize the request body when signing (see _sign)
import os
import random
import socket
import time
from typing import Dict, List, Optional, Union

import requests
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from urllib3.connection import HTTPConnection

try:
    from retrying import retry
except ImportError:
    retry = None

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class HTTPAdapterWithSocketOptions(HTTPAdapter):
    """HTTPAdapter that turns on TCP keepalive for pooled connections."""

    def __init__(self, *args, **kwargs):
        self._socket_options = HTTPConnection.default_socket_options + [
            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
            (socket.SOL_TCP, socket.TCP_KEEPIDLE, 75),
            (socket.SOL_TCP, socket.TCP_KEEPINTVL, 30),
            (socket.SOL_TCP, socket.TCP_KEEPCNT, 120),
        ]
        super(HTTPAdapterWithSocketOptions, self).__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        if self._socket_options is not None:
            kwargs['socket_options'] = self._socket_options
        super(HTTPAdapterWithSocketOptions,
              self).init_poolmanager(*args, **kwargs)


class AntFinixAPI(BaseAPIModel):
    """Model wrapper around the AntFinix service.

    Args:
        path (str): The model code passed to the AntFinix service.
        key (str): AntFinix API key used to sign requests.
        url (str): URL of the AntFinix evaluation endpoint.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 3.
        generation_kwargs (Dict): Extra generation parameters forwarded to
            the service (e.g. temperature, top_p).
    """

    def __init__(
        self,
        path: str,
        key: str,
        url: str,
        meta_template: Optional[Dict] = None,
        query_per_second: int = 1,
        retry: int = 3,
        generation_kwargs: Dict = {},
        max_seq_len=32 * 1024,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
            generation_kwargs=generation_kwargs,
        )

        self.logger.info(f'AntFinix API Model Init path: {path} url={url}')

        self._key = key
        self._url = (url if url else
                     'https://fin-evaluator-gw.antgroup.com/api/v1/finEvaluator/evaluate')
        self._model = path
        self._sessions = []
        # Number of parallel sessions, configurable via the
        # ANTFINIX_API_PARALLEL_NUM environment variable (defaults to 1).
        self._num = (int(os.environ.get('ANTFINIX_API_PARALLEL_NUM'))
                     if os.environ.get('ANTFINIX_API_PARALLEL_NUM') else 1)
        try:
            for _ in range(self._num):
                adapter = HTTPAdapterWithSocketOptions()
                sess = requests.Session()
                sess.mount('http://', adapter)
                sess.mount('https://', adapter)
                self._sessions.append(sess)
        except Exception as e:
            self.logger.error(f'Fail to setup the session. {e}')
            raise e

    def generate(
        self,
        inputs: Union[List[str], PromptList],
        max_out_len: int = 32 * 1024,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (Union[List[str], PromptList]): A list of strings or
                PromptDicts. The PromptDict should be organized in
                OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self._num) as executor:
            future_to_m = {
                executor.submit(
                    self._generate,
                    self._sessions[i % self._num],
                    input,
                    max_out_len,
                ): i
                for i, input in enumerate(inputs)
            }
            results = [''] * len(inputs)
            for future in concurrent.futures.as_completed(future_to_m):
                m = future_to_m[future]
                resp = future.result()
                if resp and resp.status_code == 200:
                    try:
                        result = resp.json()
                    except Exception as e:
                        self.logger.error(f'Fail to inference; '
                                          f'model_name={self.path}; '
                                          f'error={e}, '
                                          f'request={inputs[m]}')
                    else:
                        if result.get('resultObj'):
                            results[m] = result.get('resultObj')
                        else:
                            self.logger.error(f'Receive invalid result. '
                                              f'result={result}; '
                                              f'request={inputs[m]}')
                else:
                    self.logger.error(f'Receive invalid response. '
                                      f'response={resp}; '
                                      f'request={inputs[m]}')
        self.flush()
        return results

    def _generate(
        self,
        sess,
        input: Union[str, PromptList],
        max_out_len: int,
    ) -> str:
        """Generate results given an input.

        Args:
            input (str or PromptList): A string or PromptDict. The
                PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            The raw response from the AntFinix service, or an empty string
            if all retries failed; ``generate`` extracts the text from it.
        """
        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                content = item['prompt']
                if not content:
                    continue
                message = {'content': content}
                if item['role'] == 'HUMAN':
                    message['role'] = 'user'
                elif item['role'] == 'BOT':
                    message['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    message['role'] = 'system'
                else:
                    message['role'] = item['role']
                messages.append(message)
        data = {
            '__entry_point__': 'openai.chat.completion',
            'model': 'auto',
            'messages': messages,
            'max_tokens': max_out_len,
        }
        data.update(self.generation_kwargs)
        # Keep the timestamp as a string so it can be embedded in the
        # signature payload and sent as an HTTP header value.
        current_time = str(time.time())
        signature = self._sign(data, current_time)
        headers = {
            'Content-Type': 'application/json',
            'x-fin-e-gw-signature-appid': 'opencompass',
            'x-fin-e-gw-signature-timestamp': current_time,
            'x-fin-e-gw-signature': signature,
        }
        request = {
            'source': 'opencompass',
            'input': data,
            'modelCode': self._model,
        }
        retry_num = 0
        while retry_num < self.retry:
            try:
                response = self._infer_result(request, headers, sess)
            except ConnectionError:
                time.sleep(random.randint(10, 30))
                retry_num += 1  # retry
                continue
            if response.status_code == 200:
                break  # success
            elif response.status_code == 426:
                retry_num += 1  # retry
            elif response.status_code in [302, 429, 500, 504]:
                time.sleep(random.randint(10, 30))
                retry_num += 1  # retry
            else:
                raise ValueError(f'Status code = {response.status_code}')
        else:
            # Exceed the maximal retry times.
            return ''
        return response

    def _sign(self, data, current_time):
        # Sign the request with HMAC-SHA256 keyed by the API key. The body
        # is serialized to JSON before signing; this canonical form is an
        # assumption about what the gateway expects.
        data_str = ('postBodyForSign=' + json.dumps(data) + '^_^' +
                    'opencompass' + '^_^' + current_time)
        data_hmac = hmac.new(self._key.encode('utf-8'),
                             data_str.encode('utf-8'), hashlib.sha256)
        signature = base64.b64encode(data_hmac.digest())
        # Return a str so the value can be used directly as a header.
        return signature.decode('utf-8')

    def _infer_result(self, request, headers, sess):
        response = sess.request(
            'POST',
            self._url,
            json=request,
            headers=headers,
            timeout=500,
        )
        return response
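
To close, a hypothetical smoke test of the wrapper added above (an illustrative sketch, not part of this commit). The model code and endpoint are copied from antfinix_20250418.py, and the key is assumed to be available in ANTFINIX_API_KEY:

import os

from opencompass.models import AntFinixAPI

model = AntFinixAPI(
    path='035A54D2-9A48-021A-8ED7-C6758F3344AF',
    key=os.environ.get('ANTFINIX_API_KEY', ''),
    url='https://fin-evaluator-gw.antgroup.com/api/v1/finEvaluator/evaluate',
    generation_kwargs={'temperature': 1.0, 'top_p': 1.0, 'n': 1},
)

# generate() fans the prompts out over the wrapper's session pool and
# returns one string per input prompt (empty string on failure).
print(model.generate(['What is 12 * 7?'], max_out_len=256))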