From 50bcffc4f77007e91da43c7b983a9d30fc185299 Mon Sep 17 00:00:00 2001
From: xsq2060
Date: Fri, 25 Apr 2025 00:22:20 +0800
Subject: [PATCH] [Feature] Support AntFinix LLM

---
 examples/eval_antfinix_api.py                 |  12 +
 .../models/antfinix_api/antfinix_20250418.py  |  19 ++
 opencompass/models/__init__.py                |   1 +
 opencompass/models/antfinix_api.py            | 297 ++++++++++++++++++
 4 files changed, 329 insertions(+)
 create mode 100644 examples/eval_antfinix_api.py
 create mode 100644 opencompass/configs/models/antfinix_api/antfinix_20250418.py
 create mode 100644 opencompass/models/antfinix_api.py

diff --git a/examples/eval_antfinix_api.py b/examples/eval_antfinix_api.py
new file mode 100644
index 00000000..6b1fd643
--- /dev/null
+++ b/examples/eval_antfinix_api.py
@@ -0,0 +1,12 @@
+from mmengine.config import read_base
+
+with read_base():
+    from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import \
+        gsm8k_datasets
+    from opencompass.configs.datasets.demo.demo_math_chat_gen import \
+        math_datasets
+    from opencompass.configs.models.antfinix_api.antfinix_20250418 import \
+        models as antfinix
+
+datasets = gsm8k_datasets + math_datasets
+models = antfinix
diff --git a/opencompass/configs/models/antfinix_api/antfinix_20250418.py b/opencompass/configs/models/antfinix_api/antfinix_20250418.py
new file mode 100644
index 00000000..9d047939
--- /dev/null
+++ b/opencompass/configs/models/antfinix_api/antfinix_20250418.py
@@ -0,0 +1,19 @@
+from opencompass.models import AntFinixAPI
+
+models = [
+    dict(
+        path='035A54D2-9A48-021A-8ED7-C6758F3344AF',
+        key='',  # set your key here or in environment variable ANTFINIX_API_KEY
+        url='https://fin-evaluator-gw.antgroup.com/api/v1/finEvaluator/evaluate',
+        type=AntFinixAPI,
+        max_out_len=32 * 1024,
+        batch_size=1,
+        generation_kwargs={
+            'temperature': 1.0,
+            'logprobs': 0,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+        },
+    ),
+]
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 580402d4..0439d3a7 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -49,3 +49,4 @@ from .yayi_api import Yayi  # noqa: F401
 from .yi_api import YiAPI  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
 from .zhipuai_v2_api import ZhiPuV2AI  # noqa: F401
+from .antfinix_api import AntFinixAPI  # noqa: F401
diff --git a/opencompass/models/antfinix_api.py b/opencompass/models/antfinix_api.py
new file mode 100644
index 00000000..bd7e8b57
--- /dev/null
+++ b/opencompass/models/antfinix_api.py
@@ -0,0 +1,297 @@
+import base64
+import concurrent
+import concurrent.futures
+import hashlib
+import hmac
+import json
+import os
+import random
+import socket
+import time
+from typing import Dict, List, Optional, Union
+
+import requests
+from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, Timeout
+from urllib3.connection import HTTPConnection
+
+try:
+    from retrying import retry
+except ImportError:
+    retry = None
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+DEFAULT_URL = ('https://fin-evaluator-gw.antgroup.com'
+               '/api/v1/finEvaluator/evaluate')
+APP_ID = 'opencompass'
+# OpenCompass prompt roles -> chat-completion roles.
+ROLE_MAP = {'HUMAN': 'user', 'BOT': 'assistant', 'SYSTEM': 'system'}
+
+
+def _keepalive_socket_options():
+    """Build the TCP keep-alive socket options this platform supports.
+
+    ``TCP_KEEPIDLE``, ``TCP_KEEPINTVL`` and ``TCP_KEEPCNT`` are not exposed
+    by the ``socket`` module on every platform, so each one is only added
+    when it is available instead of failing with ``AttributeError``.
+    """
+    options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
+    tcp_level = getattr(socket, 'SOL_TCP', socket.IPPROTO_TCP)
+    for name, value in (('TCP_KEEPIDLE', 75), ('TCP_KEEPINTVL', 30),
+                        ('TCP_KEEPCNT', 120)):
+        optname = getattr(socket, name, None)
+        if optname is not None:
+            options.append((tcp_level, optname, value))
+    return options
+
+
+class HTTPAdapterWithSocketOptions(HTTPAdapter):
+    """``HTTPAdapter`` whose pooled connections enable TCP keep-alive."""
+
+    def __init__(self, *args, **kwargs):
+        # Set before ``super().__init__``: the parent constructor already
+        # calls ``init_poolmanager``, which reads this attribute.
+        self._socket_options = (HTTPConnection.default_socket_options +
+                                _keepalive_socket_options())
+        super().__init__(*args, **kwargs)
+
+    def init_poolmanager(self, *args, **kwargs):
+        if self._socket_options is not None:
+            kwargs['socket_options'] = self._socket_options
+        super().init_poolmanager(*args, **kwargs)
+
+
+class AntFinixAPI(BaseAPIModel):
+    """Model wrapper around AntFinix Service.
+
+    Args:
+        path (str): Model code, sent to the service as ``modelCode``.
+        key (str): Secret used to sign requests. When empty, the
+            environment variable ``ANTFINIX_API_KEY`` is used instead.
+        url (str): Endpoint of the service. When empty, ``DEFAULT_URL``
+            is used.
+        meta_template (Dict, optional): Forwarded to ``BaseAPIModel``.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        retry (int): Number of retries if the API call fails.
+            Defaults to 3.
+        generation_kwargs (Dict, optional): Extra parameters merged into
+            the payload of every request.
+        max_seq_len (int): Forwarded to ``BaseAPIModel``.
+            Defaults to 32768.
+
+    The number of parallel sessions is read from the environment variable
+    ``ANTFINIX_API_PARALLEL_NUM`` and defaults to 1.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        key: str,
+        url: str,
+        meta_template: Optional[Dict] = None,
+        query_per_second: int = 1,
+        retry: int = 3,
+        generation_kwargs: Optional[Dict] = None,
+        max_seq_len=32 * 1024,
+    ):
+        super().__init__(
+            path=path,
+            max_seq_len=max_seq_len,
+            query_per_second=query_per_second,
+            meta_template=meta_template,
+            retry=retry,
+            # Copy so that instances never share one mutable default.
+            generation_kwargs=dict(generation_kwargs or {}),
+        )
+
+        self.logger.info(f'AntFinix API Model Init path: {path} url={url}')
+
+        self._key = key or os.environ.get('ANTFINIX_API_KEY', '')
+        if not self._key:
+            self.logger.warning('AntFinix API key is empty; set `key` or '
+                                'the ANTFINIX_API_KEY environment variable.')
+        self._url = url or DEFAULT_URL
+        self._model = path
+        self._num = self._parallel_num()
+        self._sessions = []
+        try:
+            for _ in range(self._num):
+                adapter = HTTPAdapterWithSocketOptions()
+                sess = requests.Session()
+                sess.mount('http://', adapter)
+                sess.mount('https://', adapter)
+                self._sessions.append(sess)
+        except Exception as e:
+            self.logger.error(f'Fail to setup the session. {e}')
+            raise
+
+    @staticmethod
+    def _parallel_num() -> int:
+        """Read the worker count from ``ANTFINIX_API_PARALLEL_NUM``.
+
+        Unset or empty means 1. Values below 1 are clamped to 1 so that
+        the thread pool and the session round-robin always have a worker.
+        """
+        raw = os.environ.get('ANTFINIX_API_PARALLEL_NUM')
+        return max(1, int(raw)) if raw else 1
+
+    def generate(
+        self,
+        inputs: Union[List[str], PromptList],
+        max_out_len: int = 32 * 1024,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (Union[List[str], PromptList]):
+                A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass' API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings in the order of
+                ``inputs``; a failed request yields an empty string.
+        """
+        results = [''] * len(inputs)
+        with concurrent.futures.ThreadPoolExecutor(
+                max_workers=self._num) as executor:
+            future_to_idx = {
+                executor.submit(
+                    self._generate,
+                    self._sessions[i % self._num],
+                    prompt,
+                    max_out_len,
+                ): i
+                for i, prompt in enumerate(inputs)
+            }
+            for future in concurrent.futures.as_completed(future_to_idx):
+                idx = future_to_idx[future]
+                resp = future.result()
+                # ``resp`` is None once all retries are exhausted.
+                if resp is None or resp.status_code != 200:
+                    self.logger.error(f'Receive invalid response. '
+                                      f'response={resp}; '
+                                      f'request={inputs[idx]}')
+                    continue
+                try:
+                    result = resp.json()
+                except Exception as e:
+                    self.logger.error(f'Fail to inference; '
+                                      f'model_name={self.path}; '
+                                      f'error={e}, '
+                                      f'request={inputs[idx]}')
+                    continue
+                output = (result.get('resultObj')
+                          if isinstance(result, dict) else None)
+                if output:
+                    results[idx] = output
+                else:
+                    self.logger.error(f'Receive invalid result. '
+                                      f'result={result}; '
+                                      f'request={inputs[idx]}')
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        sess,
+        input: Union[str, PromptList],
+        max_out_len: int,
+    ) -> Optional[requests.Response]:
+        """Send one prompt to the service, retrying transient failures.
+
+        Args:
+            sess (requests.Session): Session used to send the request.
+            input (str or PromptList): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass' API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            Optional[requests.Response]: The HTTP 200 response, or None
+                when all ``self.retry`` attempts failed.
+
+        Raises:
+            ValueError: If the status code is neither 200 nor one of the
+                retryable codes (302, 426, 429, 500, 504).
+        """
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = [{
+                'content': item['prompt'],
+                'role': ROLE_MAP.get(item['role'], item['role']),
+            } for item in input if item['prompt']]
+        data = {
+            '__entry_point__': 'openai.chat.completion',
+            'model': 'auto',
+            'messages': messages,
+            'max_tokens': max_out_len,
+        }
+        data.update(self.generation_kwargs)
+        # ``requests`` rejects non-str/bytes header values, so the timestamp
+        # is rendered once and the same text is signed and sent.
+        timestamp = str(time.time())
+        headers = {
+            'Content-Type': 'application/json',
+            'x-fin-e-gw-signature-appid': APP_ID,
+            'x-fin-e-gw-signature-timestamp': timestamp,
+            'x-fin-e-gw-signature': self._sign(data, timestamp),
+        }
+        request = {
+            'source': 'opencompass',
+            'input': data,
+            'modelCode': self._model,
+        }
+        for _ in range(self.retry):
+            try:
+                response = self._infer_result(request, headers, sess)
+            except (ConnectionError, Timeout):
+                time.sleep(random.randint(10, 30))
+                continue
+            if response.status_code == 200:
+                return response
+            if response.status_code == 426:
+                continue  # retry immediately
+            if response.status_code in (302, 429, 500, 504):
+                time.sleep(random.randint(10, 30))
+                continue
+            raise ValueError(f'Status code = {response.status_code}')
+        # Exceed the maximal retry times.
+        return None
+
+    def _sign(self, data, current_time) -> str:
+        """Return the base64 HMAC-SHA256 signature of a request payload.
+
+        Args:
+            data (dict or str): Payload to sign. A non-string payload is
+                serialized with ``json.dumps`` first.
+            current_time: Timestamp sent with the request; it is converted
+                with ``str`` before signing.
+
+        Returns:
+            str: Base64 text of the digest, usable as a header value.
+        """
+        # NOTE(review): the gateway presumably rebuilds this string itself;
+        # confirm that its JSON serialization matches ``json.dumps``.
+        body = data if isinstance(data, str) else json.dumps(data)
+        data_str = '^_^'.join(
+            (f'postBodyForSign={body}', APP_ID, str(current_time)))
+        data_hmac = hmac.new(self._key.encode('utf-8'),
+                             data_str.encode('utf-8'), hashlib.sha256)
+        return base64.b64encode(data_hmac.digest()).decode('ascii')
+
+    def _infer_result(self, request, headers, sess):
+        """POST ``request`` as JSON to the service and return the response."""
+        return sess.request(
+            'POST',
+            self._url,
+            json=request,
+            headers=headers,
+            timeout=500,
+        )