diff --git a/configs/api_examples/eval_api_nanbeige.py b/configs/api_examples/eval_api_nanbeige.py
new file mode 100644
index 00000000..1d941067
--- /dev/null
+++ b/configs/api_examples/eval_api_nanbeige.py
@@ -0,0 +1,36 @@
+from mmengine.config import read_base
+from opencompass.models import Nanbeige
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+
+with read_base():
+    from ..summarizers.medium import summarizer
+    from ..datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+    *ceval_datasets,
+]
+
+models = [
+    dict(
+        abbr='nanbeige-plus',
+        type=Nanbeige,
+        path='nanbeige-plus',
+        key='xxxxxx',  # replace with your Nanbeige API key
+        query_per_second=1,
+        max_out_len=2048,
+        batch_size=8),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalAPIRunner,
+        max_num_workers=2,
+        concurrent_users=2,
+        task=dict(type=OpenICLInferTask)),
+)
+
+work_dir = './output/nanbeige-plus'
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 331c66eb..41c3e127 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -18,6 +18,7 @@ from .minimax_api import MiniMax  # noqa: F401
 from .mixtral import Mixtral  # noqa: F401
 from .modelscope import ModelScope, ModelScopeCausalLM  # noqa: F401, F403
 from .moonshot_api import MoonShot  # noqa: F401
+from .nanbeige_api import Nanbeige  # noqa: F401
 from .openai_api import OpenAI  # noqa: F401
 from .pangu_api import PanGu  # noqa: F401
 from .qwen_api import Qwen  # noqa: F401
diff --git a/opencompass/models/nanbeige_api.py b/opencompass/models/nanbeige_api.py
new file mode 100644
index 00000000..d39c34bf
--- /dev/null
+++ b/opencompass/models/nanbeige_api.py
@@ -0,0 +1,151 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class Nanbeige(BaseAPIModel):
+    """Model wrapper around Nanbeige.
+
+    Documentation:
+
+    Args:
+        path (str): Model name, e.g. `nanbeige-plus`.
+        key (str): API key for the Nanbeige service.
+        url (str): URL of the API endpoint, defaults to the official one.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 2.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case any meta instructions must be
+            injected or wrapped.
+        retry (int): Number of retries if the API call fails. Defaults to 3.
+    """
+
+    def __init__(self,
+                 path: str,
+                 key: str,
+                 url: str = None,
+                 query_per_second: int = 2,
+                 max_seq_len: int = 2048,
+                 meta_template: Optional[Dict] = None,
+                 retry: int = 3):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        self.headers = {
+            'Authorization': 'Bearer ' + key,
+            'Content-Type': 'application/json',
+        }
+        self.model = path
+        self.url = url if url is not None \
+            else 'http://stardustlm.zhipin.com/api/gpt/open/chat/send/sync'
+
+    def generate(
+        self,
+        inputs: List[PromptType],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[PromptType]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: PromptType,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (PromptType): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'sender_type': 'USER', 'text': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'text': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['sender_type'] = 'USER'
+                elif item['role'] == 'BOT':
+                    msg['sender_type'] = 'BOT'
+
+                messages.append(msg)
+
+        data = {
+            'model': self.model,
+            'messages': messages,
+        }
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            try:
+                raw_response = requests.request('POST',
+                                                url=self.url,
+                                                headers=self.headers,
+                                                json=data)
+            except requests.ConnectionError:
+                raw_response = None
+            self.release()
+
+            if raw_response is None:
+                print('Connection error, reconnect.')
+                # On a connection error, frequent requests will cause
+                # continuously unstable networking, so wait here
+                # to slow down the requests.
+                self.wait()
+                continue
+
+            if raw_response.status_code != 200:
+                print('Request failed:', raw_response)
+                print('Error details:', raw_response.text)
+                max_num_retries += 1
+                continue
+
+            response = raw_response.json()
+            if response['stardustCode'] == 0:
+                return response['reply']
+
+            # exceeded the concurrency limit
+            if response['stardustCode'] == 20035:
+                print(response)
+                time.sleep(2)
+                continue
+
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(raw_response.text)
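
Usage sketch: with OpenCompass installed, the new config can be launched through the repository's standard run.py entry point, writing results under the work_dir set above:

    python run.py configs/api_examples/eval_api_nanbeige.py

The wrapper can also be exercised on its own; this is a minimal sketch, assuming a valid Nanbeige API key in place of the 'xxxxxx' placeholder:

    from opencompass.models import Nanbeige

    # Query the default Nanbeige endpoint with a single string prompt.
    model = Nanbeige(path='nanbeige-plus', key='xxxxxx', query_per_second=1)
    print(model.generate(['Hello!'], max_out_len=128))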