diff --git a/configs/api_examples/eval_api_bailing.py b/configs/api_examples/eval_api_bailing.py
new file mode 100644
index 00000000..15101b09
--- /dev/null
+++ b/configs/api_examples/eval_api_bailing.py
@@ -0,0 +1,38 @@
+from mmengine.config import read_base
+
+from opencompass.models import BailingAPI
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    from opencompass.configs.datasets.ceval.ceval_gen import ceval_datasets
+    from opencompass.configs.summarizers.medium import summarizer
+
+datasets = [
+    *ceval_datasets,
+]
+
+models = [
+    dict(
+        path="Bailing-Lite-0830",
+        token="",  # set your key here or in environment variable BAILING_API_KEY
+        url="https://bailingchat.alipay.com/chat/completions",
+        type=BailingAPI,
+        generation_kwargs={},
+        query_per_second=1,
+        max_seq_len=4096,
+    ),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalAPIRunner,
+        max_num_workers=2,
+        concurrent_users=2,
+        task=dict(type=OpenICLInferTask),
+    ),
+)
+
+work_dir = "outputs/api_bailing/"
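# A minimal smoke test for the config above (illustrative, not part of the
# patch). Assumes this patch is installed and BAILING_API_KEY is exported;
# the prompt string is made up. BailingAPI and generate() are defined in
# opencompass/models/bailing_api_oc.py later in this patch.
import os

from opencompass.models import BailingAPI

model = BailingAPI(
    path="Bailing-Lite-0830",
    token=os.environ.get("BAILING_API_KEY", ""),
    url="https://bailingchat.alipay.com/chat/completions",
    query_per_second=1,
    max_seq_len=4096,
)
print(model.generate(["Say hello in one sentence."], max_out_len=64))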
path="Bailing-Lite-0830", + token="", # set your key here or in environment variable BAILING_API_KEY + url="https://bailingchat.alipay.com/chat/completions", + type=BailingAPI, + meta_template=api_meta_template, + query_per_second=1, + max_seq_len=4096, + batch_size=1, + generation_kwargs={ + "temperature": 0.4, + "top_p": 1.0, + "top_k": -1, + "n": 1, + "logprobs": 1, + "use_beam_search": False, + }, + ), +] + diff --git a/opencompass/configs/models/bailing_api/bailing-pro-0920.py b/opencompass/configs/models/bailing_api/bailing-pro-0920.py new file mode 100644 index 00000000..35814bf7 --- /dev/null +++ b/opencompass/configs/models/bailing_api/bailing-pro-0920.py @@ -0,0 +1,31 @@ +from opencompass.models import BailingAPI + +api_meta_template = dict( + round=[ + dict(role="HUMAN", api_role="HUMAN"), + dict(role="BOT", api_role="BOT", generate=False), + ], + reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")], +) + +models = [ + dict( + path="Bailing-Pro-0920", + token="", # set your key here or in environment variable BAILING_API_KEY + url="https://bailingchat.alipay.com/chat/completions", + type=BailingAPI, + meta_template=api_meta_template, + query_per_second=1, + max_seq_len=4096, + batch_size=1, + generation_kwargs={ + "temperature": 0.4, + "top_p": 1.0, + "top_k": -1, + "n": 1, + "logprobs": 1, + "use_beam_search": False, + }, + ), +] + diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index 403eb5d6..0beb963a 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -3,6 +3,7 @@ from .ai360_api import AI360GPT # noqa: F401 from .alaya import AlayaLM # noqa: F401 from .baichuan_api import BaiChuan # noqa: F401 from .baidu_api import ERNIEBot # noqa: F401 +from .bailing_api_oc import BailingAPI # noqa: F401 from .base import BaseModel, LMTemplateParser # noqa: F401 from .base_api import APITemplateParser, BaseAPIModel # noqa: F401 from .bytedance_api import ByteDance # noqa: F401 @@ -41,8 +42,7 @@ from .sensetime_api import SenseTime # noqa: F401 from .stepfun_api import StepFun # noqa: F401 from .turbomind import TurboMindModel # noqa: F401 from .turbomind_tis import TurboMindTisModel # noqa: F401 -from .turbomind_with_tf_above_v4_33 import \ - TurboMindModelwithChatTemplate # noqa: F401 +from .turbomind_with_tf_above_v4_33 import TurboMindModelwithChatTemplate # noqa: F401 from .unigpt_api import UniGPT # noqa: F401 from .vllm import VLLM # noqa: F401 from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate # noqa: F401 diff --git a/opencompass/models/bailing_api_oc.py b/opencompass/models/bailing_api_oc.py new file mode 100644 index 00000000..6ff75e0d --- /dev/null +++ b/opencompass/models/bailing_api_oc.py @@ -0,0 +1,215 @@ +import concurrent +import concurrent.futures +import os +import socket +import traceback +from typing import Dict, List, Optional, Union + +import requests +from requests.adapters import HTTPAdapter +from retrying import retry +from urllib3.connection import HTTPConnection + +from opencompass.utils.prompt import PromptList + +from .base_api import BaseAPIModel + +PromptType = Union[PromptList, str] + + +class HTTPAdapterWithSocketOptions(HTTPAdapter): + def __init__(self, *args, **kwargs): + self._socket_options = HTTPConnection.default_socket_options + [ + (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), + (socket.SOL_TCP, socket.TCP_KEEPIDLE, 75), + (socket.SOL_TCP, socket.TCP_KEEPINTVL, 30), + (socket.SOL_TCP, socket.TCP_KEEPCNT, 120), + ] + super(HTTPAdapterWithSocketOptions, self).__init__(*args, 
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 403eb5d6..0beb963a 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -3,6 +3,7 @@ from .ai360_api import AI360GPT  # noqa: F401
 from .alaya import AlayaLM  # noqa: F401
 from .baichuan_api import BaiChuan  # noqa: F401
 from .baidu_api import ERNIEBot  # noqa: F401
+from .bailing_api_oc import BailingAPI  # noqa: F401
 from .base import BaseModel, LMTemplateParser  # noqa: F401
 from .base_api import APITemplateParser, BaseAPIModel  # noqa: F401
 from .bytedance_api import ByteDance  # noqa: F401
@@ -41,8 +42,7 @@ from .sensetime_api import SenseTime  # noqa: F401
 from .stepfun_api import StepFun  # noqa: F401
 from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
-from .turbomind_with_tf_above_v4_33 import \
-    TurboMindModelwithChatTemplate  # noqa: F401
+from .turbomind_with_tf_above_v4_33 import TurboMindModelwithChatTemplate  # noqa: F401
 from .unigpt_api import UniGPT  # noqa: F401
 from .vllm import VLLM  # noqa: F401
 from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate  # noqa: F401
diff --git a/opencompass/models/bailing_api_oc.py b/opencompass/models/bailing_api_oc.py
new file mode 100644
index 00000000..6ff75e0d
--- /dev/null
+++ b/opencompass/models/bailing_api_oc.py
@@ -0,0 +1,224 @@
+import concurrent.futures
+import os
+import socket
+import traceback
+from typing import Dict, List, Optional, Union
+
+import requests
+from requests.adapters import HTTPAdapter
+from retrying import retry
+from urllib3.connection import HTTPConnection
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class HTTPAdapterWithSocketOptions(HTTPAdapter):
+    """HTTPAdapter that enables TCP keep-alive on its connections."""
+
+    def __init__(self, *args, **kwargs):
+        # TCP_KEEPIDLE, TCP_KEEPINTVL and TCP_KEEPCNT are Linux-specific
+        # socket constants; other platforms may not define them.
+        self._socket_options = HTTPConnection.default_socket_options + [
+            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
+            (socket.SOL_TCP, socket.TCP_KEEPIDLE, 75),
+            (socket.SOL_TCP, socket.TCP_KEEPINTVL, 30),
+            (socket.SOL_TCP, socket.TCP_KEEPCNT, 120),
+        ]
+        super().__init__(*args, **kwargs)
+
+    def init_poolmanager(self, *args, **kwargs):
+        if self._socket_options is not None:
+            kwargs["socket_options"] = self._socket_options
+        super().init_poolmanager(*args, **kwargs)
+
+
+class BailingAPI(BaseAPIModel):
+    """Model wrapper around the Bailing service.
+
+    Args:
+        path (str): Name of the Bailing model to call.
+        token (str): API token. Falls back to the BAILING_API_KEY
+            environment variable when empty.
+        url (str): Service endpoint. Defaults to the Bailing
+            chat-completions endpoint when empty.
+        meta_template (Dict, optional): The model's meta prompt template.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        retry (int): Number of retries if the API call fails. Defaults to 3.
+        generation_kwargs (Dict): Extra parameters passed to the service.
+        max_seq_len (int): Maximum sequence length. Defaults to 4096.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        token: str,
+        url: str,
+        meta_template: Optional[Dict] = None,
+        query_per_second: int = 1,
+        retry: int = 3,
+        generation_kwargs: Dict = {},
+        max_seq_len: int = 4096,
+    ):
+        super().__init__(
+            path=path,
+            max_seq_len=max_seq_len,
+            query_per_second=query_per_second,
+            meta_template=meta_template,
+            retry=retry,
+            generation_kwargs=generation_kwargs,
+        )
+
+        self.logger.info(f"Bailing API Model Init path: {path} url={url}")
+        if not token:
+            token = os.environ.get("BAILING_API_KEY")
+        if not token:
+            raise RuntimeError("No valid token; set it in the config or in "
+                               "the BAILING_API_KEY environment variable.")
+        self._headers = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+        }
+        self._url = url or "https://bailingchat.alipay.com/chat/completions"
+        self._model = path
+        self._sessions = []
+        self._num = int(os.environ.get("BAILING_API_PARALLEL_NUM") or 1)
+        try:
+            for _ in range(self._num):
+                adapter = HTTPAdapterWithSocketOptions()
+                sess = requests.Session()
+                sess.mount("http://", adapter)
+                sess.mount("https://", adapter)
+                self._sessions.append(sess)
+        except Exception as e:
+            self.logger.error(f"Failed to set up the session. {e}")
+            raise
+
+    def generate(
+        self,
+        inputs: Union[List[str], PromptList],
+        max_out_len: int = 4096,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (Union[List[str], PromptList]): A list of strings or
+                PromptDicts. The PromptDict should be organized in
+                OpenCompass' API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings, in the same order as
+                the inputs.
+        """
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=self._num,
+        ) as executor:
+            future_to_index = {
+                executor.submit(
+                    self._generate,
+                    self._sessions[i % self._num],
+                    input,
+                    max_out_len,
+                ): i
+                for i, input in enumerate(inputs)
+            }
+            # Collect by original index so the output order matches the
+            # input order regardless of completion order.
+            results = [""] * len(inputs)
+            for future in concurrent.futures.as_completed(future_to_index):
+                index = future_to_index[future]
+                resp = future.result()
+                if resp and resp.status_code == 200:
+                    try:
+                        result = resp.json()
+                    except Exception:
+                        continue  # keep the empty-string placeholder
+                    if (
+                        result.get("choices")
+                        and result["choices"][0].get("message")
+                        and result["choices"][0]["message"].get("content")
+                    ):
+                        results[index] = (
+                            result["choices"][0]["message"]["content"])
+        self.flush()
+        return results
+ """ + if isinstance(input, str): + messages = [{"role": "user", "content": input}] + else: + messages = [] + for item in input: + content = item["prompt"] + if not content: + continue + message = {"content": content} + if item["role"] == "HUMAN": + message["role"] = "user" + elif item["role"] == "BOT": + message["role"] = "assistant" + elif item["role"] == "SYSTEM": + message["role"] = "system" + else: + message["role"] = item["role"] + messages.append(message) + request = { + "model": self._model, + "messages": messages, + "max_seq_len": max( + max_out_len if max_out_len else 4096, + self.max_seq_len if self.max_seq_len else 4096, + ), + } + request.update(self.generation_kwargs) + try: + retry_num = 0 + while retry_num < self.retry: + response = self._infer_result(request, sess) + if response.status_code == 200: + break # success + elif response.status_code == 426: + retry_num += 1 # retry + else: + raise ValueError(f"Status code = {response.status_code}") + else: + raise ValueError( + f"Exceed the maximal retry times. Last status code = {response.status_code}" + ) + except Exception as e: + self.logger.error( + f"Fail to inference request={request}; model_name={self.path}; error={e}, stack:{traceback.format_exc()}" + ) + raise e + return response + + @retry(stop_max_attempt_number=3, wait_fixed=16000) # ms + def _infer_result(self, request, sess): + response = sess.request( + "POST", + self._url, + json=request, + headers=self._headers, + timeout=500, + ) + return response diff --git a/requirements/runtime.txt b/requirements/runtime.txt index dc638911..e7229e88 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -23,6 +23,7 @@ python-Levenshtein rank_bm25==0.2.2 rapidfuzz requests>=2.31.0 +retrying rich rouge -e git+https://github.com/Isaac-JL-Chen/rouge_chinese.git@master#egg=rouge_chinese