From d92574826629c06d90f80f334f3d3068c7a2dbb7 Mon Sep 17 00:00:00 2001
From: Songyang Zhang
Date: Tue, 21 Nov 2023 20:25:47 +0800
Subject: [PATCH] [Feature] Support 360API and FixKRetriever for CSQA dataset
 (#601)

* [Feature] Support 360API and FixKRetriever for CSQA dataset

* Update API

* Update API

* [Feature] Support 360API and FixKRetriever for CSQA dataset

* Update API

* Update API

* rm mathbench

* fix_lint

* Update opencompass/models/bytedance_api.py

Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>

* update

* update

* update

---------

Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
---
 .../commonsenseqa/commonsenseqa_gen_1da2d0.py |  55 +++++
 .../commonsenseqa/commonsenseqa_ppl_e51e32.py |  42 ++++
 configs/eval_api_360.py                       |  36 ++++
 configs/eval_api_baichuan.py                  |  39 ++++
 configs/eval_api_pangu.py                     |  43 ++++
 configs/eval_api_sensetime.py                 |  35 ++++
 opencompass/models/__init__.py                |   6 +
 opencompass/models/ai360_api.py               | 168 +++++++++++++++
 opencompass/models/baichuan_api.py            | 164 +++++++++++++++
 opencompass/models/baidu_api.py               | 196 ++++++++++++++++++
 opencompass/models/base_api.py                |  33 +++
 opencompass/models/bytedance_api.py           | 172 +++++++++++++++
 opencompass/models/minimax_api.py             |  32 ---
 opencompass/models/pangu_api.py               | 182 ++++++++++++++++
 opencompass/models/sensetime_api.py           | 136 ++++++++++++
 opencompass/models/xunfei_api.py              |  32 ---
 opencompass/models/zhipuai_api.py             |  32 ---
 requirements/api.txt                          |   4 +-
 18 files changed, 1310 insertions(+), 97 deletions(-)
 create mode 100644 configs/datasets/commonsenseqa/commonsenseqa_gen_1da2d0.py
 create mode 100644 configs/datasets/commonsenseqa/commonsenseqa_ppl_e51e32.py
 create mode 100644 configs/eval_api_360.py
 create mode 100644 configs/eval_api_baichuan.py
 create mode 100644 configs/eval_api_pangu.py
 create mode 100644 configs/eval_api_sensetime.py
 create mode 100644 opencompass/models/ai360_api.py
 create mode 100644 opencompass/models/baichuan_api.py
 create mode 100644 opencompass/models/baidu_api.py
 create mode 100644 opencompass/models/bytedance_api.py
 create mode 100644 opencompass/models/pangu_api.py
 create mode 100644 opencompass/models/sensetime_api.py

diff --git a/configs/datasets/commonsenseqa/commonsenseqa_gen_1da2d0.py b/configs/datasets/commonsenseqa/commonsenseqa_gen_1da2d0.py
new file mode 100644
index 00000000..d046c40f
--- /dev/null
+++ b/configs/datasets/commonsenseqa/commonsenseqa_gen_1da2d0.py
@@ -0,0 +1,55 @@
+# Use FixKRetriever to avoid the hang caused by the HuggingFace-based retriever
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import FixKRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import commonsenseqaDataset
+from opencompass.utils.text_postprocessors import first_capital_postprocess
+
+commonsenseqa_reader_cfg = dict(
+    input_columns=["question", "A", "B", "C", "D", "E"],
+    output_column="answerKey",
+    test_split="validation")
+
+_ice_template = dict(
+    type=PromptTemplate,
+    template=dict(
+        begin="</E>",
+        round=[
+            dict(
+                role="HUMAN",
+                prompt=
+                "{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nE. {E}\nAnswer:",
+            ),
+            dict(
+                role="BOT",
+                prompt="{answerKey}",
+            ),
+        ],
+    ),
+    ice_token="</E>",
+)
+
+commonsenseqa_infer_cfg = dict(
+    ice_template=_ice_template,
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4, 5, 6, 7]),
+    inferencer=dict(type=GenInferencer),
+)
+
+commonsenseqa_eval_cfg = dict(
+    evaluator=dict(type=AccEvaluator),
+    pred_postprocessor=dict(type=first_capital_postprocess),
+)
+
+commonsenseqa_datasets = [
+    dict(
+        abbr='commonsense_qa',
+        type=commonsenseqaDataset,
+        path='./data/commonsenseqa',
+        reader_cfg=commonsenseqa_reader_cfg,
+        infer_cfg=commonsenseqa_infer_cfg,
+        eval_cfg=commonsenseqa_eval_cfg,
+    )
+]
+
+del _ice_template
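
One note on the retriever choice: FixKRetriever reuses the train-split items listed in `fix_id_list` as the in-context examples for every test question, and the rendered demos are substituted for the `</E>` ice token. A rough sketch of the resulting prompt assembly (illustrative only; `build_csqa_prompt` is not part of this PR or of OpenICL):

    def build_csqa_prompt(train, item, demo_ids=(0, 1, 2, 3, 4, 5, 6, 7)):
        # Each demo is the HUMAN round filled in, followed by the BOT answer.
        def render(ex, with_answer):
            text = (f"{ex['question']}\nA. {ex['A']}\nB. {ex['B']}\n"
                    f"C. {ex['C']}\nD. {ex['D']}\nE. {ex['E']}\nAnswer:")
            return text + f" {ex['answerKey']}" if with_answer else text

        demos = [render(train[i], True) for i in demo_ids]
        return "\n".join(demos + [render(item, False)])

Because the demo set is fixed, no embedding model has to be downloaded at retrieval time, which is what caused the hang.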
diff --git a/configs/datasets/commonsenseqa/commonsenseqa_ppl_e51e32.py b/configs/datasets/commonsenseqa/commonsenseqa_ppl_e51e32.py
new file mode 100644
index 00000000..fd1a77e6
--- /dev/null
+++ b/configs/datasets/commonsenseqa/commonsenseqa_ppl_e51e32.py
@@ -0,0 +1,42 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import FixKRetriever
+from opencompass.openicl.icl_inferencer import PPLInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import commonsenseqaDataset
+
+commonsenseqa_reader_cfg = dict(
+    input_columns=['question', 'A', 'B', 'C', 'D', 'E'],
+    output_column='answerKey',
+    test_split='validation')
+
+_ice_template = dict(
+    type=PromptTemplate,
+    template={
+        ans: dict(
+            begin='</E>',
+            round=[
+                dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
+                dict(role="BOT", prompt=ans_token),
+            ])
+        for ans, ans_token in [["A", "{A}"], ["B", "{B}"],
+                               ["C", "{C}"], ["D", "{D}"],
+                               ["E", "{E}"]]
+    },
+    ice_token='</E>')
+
+commonsenseqa_infer_cfg = dict(
+    ice_template=_ice_template,
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4, 5, 6, 7]),
+    inferencer=dict(type=PPLInferencer))
+
+commonsenseqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+
+commonsenseqa_datasets = [
+    dict(
+        abbr='commonsense_qa',
+        type=commonsenseqaDataset,
+        path='./data/commonsenseqa',
+        reader_cfg=commonsenseqa_reader_cfg,
+        infer_cfg=commonsenseqa_infer_cfg,
+        eval_cfg=commonsenseqa_eval_cfg)
+]
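
The ppl variant renders one candidate prompt per answer label and lets the model score each continuation; the label whose completion fits best (lowest perplexity) is the prediction. Schematically (illustrative only, assuming a model that exposes the `get_ppl` API declared in base_api):

    def ppl_predict(model, item, labels=("A", "B", "C", "D", "E")):
        # One rendered prompt per label; the BOT side is the option text.
        prompts = [f"Question: {item['question']}\nAnswer: {item[label]}"
                   for label in labels]
        ppls = model.get_ppl(prompts)  # lower perplexity = better fit
        return min(zip(labels, ppls), key=lambda pair: pair[1])[0]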
diff --git a/configs/eval_api_360.py b/configs/eval_api_360.py
new file mode 100644
index 00000000..aef689a9
--- /dev/null
+++ b/configs/eval_api_360.py
@@ -0,0 +1,36 @@
+from mmengine.config import read_base
+from opencompass.models import AI360GPT
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    from .summarizers.medium import summarizer
+    from .datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+    *ceval_datasets,
+]
+
+models = [
+    dict(
+        abbr='360GPT_S2_V9',
+        type=AI360GPT,
+        path='360GPT_S2_V9',
+        key="xxxxxxxxxxxx",
+        query_per_second=1,
+        max_out_len=2048,
+        max_seq_len=2048,
+        batch_size=8),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalAPIRunner,
+        max_num_workers=2,
+        concurrent_users=2,
+        task=dict(type=OpenICLInferTask)),
+)
+
+work_dir = "./output/360GPT_S2_V9"
\ No newline at end of file

diff --git a/configs/eval_api_baichuan.py b/configs/eval_api_baichuan.py
new file mode 100644
index 00000000..1c845ea0
--- /dev/null
+++ b/configs/eval_api_baichuan.py
@@ -0,0 +1,39 @@
+from mmengine.config import read_base
+from opencompass.models import BaiChuan
+
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    from .summarizers.medium import summarizer
+    from .datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+    *ceval_datasets,
+]
+
+models = [
+    dict(
+        abbr='Baichuan2-53B',
+        type=BaiChuan,
+        path='Baichuan2-53B',
+        api_key='xxxxxx',
+        secret_key="xxxxx",
+        url="xxxxx",
+        query_per_second=1,
+        max_out_len=2048,
+        max_seq_len=2048,
+        batch_size=8),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalAPIRunner,
+        max_num_workers=2,
+        concurrent_users=2,
+        task=dict(type=OpenICLInferTask)),
+)
+
+work_dir = "outputs/api_baichuan53b/"
\ No newline at end of file

diff --git a/configs/eval_api_pangu.py b/configs/eval_api_pangu.py
new file mode 100644
index 00000000..3bfc15eb
--- /dev/null
+++ b/configs/eval_api_pangu.py
@@ -0,0 +1,43 @@
+from mmengine.config import read_base
+from opencompass.models import PanGu
+
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    from .summarizers.medium import summarizer
+    from .datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+    *ceval_datasets,
+]
+
+models = [
+    dict(
+        abbr='pangu',
+        type=PanGu,
+        path='pangu',
+        access_key="xxxxxx",
+        secret_key="xxxxxx",
+        url="xxxxxx",
+        # url of the token server, used to generate tokens, e.g. "https://xxxxxx.myhuaweicloud.com/v3/auth/tokens"
+        token_url="xxxxxx",
+        # scope-project-name, used to generate tokens
+        project_name="xxxxxx",
+        query_per_second=1,
+        max_out_len=2048,
+        max_seq_len=2048,
+        batch_size=8),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalAPIRunner,
+        max_num_workers=2,
+        concurrent_users=2,
+        task=dict(type=OpenICLInferTask)),
+)
+
+work_dir = "outputs/api_pangu/"
\ No newline at end of file

diff --git a/configs/eval_api_sensetime.py b/configs/eval_api_sensetime.py
new file mode 100644
index 00000000..6f243e00
--- /dev/null
+++ b/configs/eval_api_sensetime.py
@@ -0,0 +1,35 @@
+from mmengine.config import read_base
+from opencompass.models import SenseTime
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    from .summarizers.medium import summarizer
+    from .datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+    *ceval_datasets,
+]
+
+models = [
+    dict(
+        abbr='nova-ptc-xl-v1',
+        type=SenseTime,
+        path='nova-ptc-xl-v1',
+        key='xxxxxxxxxxxxxx',
+        url='xxxxxxxxxxx',
+        query_per_second=1,
+        max_out_len=2048,
+        max_seq_len=2048,
+        batch_size=8),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalAPIRunner,
+        max_num_workers=2,
+        concurrent_users=2,
+        task=dict(type=OpenICLInferTask)),
+)

diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 380927b5..52920dd1 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -1,6 +1,10 @@
+from .ai360_api import AI360GPT  # noqa: F401
 from .alaya import AlayaLM  # noqa: F401
+from .baichuan_api import BaiChuan  # noqa: F401
+from .baidu_api import ERNIEBot  # noqa: F401
 from .base import BaseModel, LMTemplateParser  # noqa
 from .base_api import APITemplateParser, BaseAPIModel  # noqa
+from .bytedance_api import ByteDance  # noqa: F401
 from .claude_api import Claude  # noqa: F401
 from .glm import GLM130B  # noqa: F401, F403
 from .huggingface import HuggingFace  # noqa: F401, F403
@@ -11,5 +15,7 @@ from .lightllm_api import LightllmAPI  # noqa: F401
 from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
 from .minimax_api import MiniMax  # noqa: F401
 from .openai_api import OpenAI  # noqa: F401
+from .pangu_api import PanGu  # noqa: F401
+from .sensetime_api import SenseTime  # noqa: F401
 from .xunfei_api import XunFei  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
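
All four configs follow the same pattern and are launched with the usual entrypoint, e.g. `python run.py configs/eval_api_360.py`. `NaivePartitioner` emits one inference task per dataset, and `LocalAPIRunner` runs them locally with `concurrent_users` capping the number of simultaneous API calls (presumably by installing the `tokens` semaphore that the new acquire/release helpers in base_api consume; see the note after that diff).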
diff --git a/opencompass/models/ai360_api.py b/opencompass/models/ai360_api.py
new file mode 100644
index 00000000..542d654a
--- /dev/null
+++ b/opencompass/models/ai360_api.py
@@ -0,0 +1,168 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class AI360GPT(BaseAPIModel):
+    """Model wrapper around 360 GPT.
+
+    Documentation: https://ai.360.com/platform/docs/overview
+
+    Args:
+        path (str): Model name
+        key (str): Provided API key
+        url (str): Provided URL
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 2.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,  # model name, e.g.: 360GPT_S2_V9
+        key: str,
+        url: str = 'https://api.360.cn/v1/chat/completions',
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        self.headers = {
+            'Authorization': f'Bearer {key}',
+            'Content-Type': 'application/json',
+        }
+        self.model = path
+        self.url = url
+
+    def generate(
+        self,
+        inputs: List[str or PromptList],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: str or PromptList,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (str or PromptList): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'assistant'
+                elif item['role'] == 'SYSTEM':
+                    msg['role'] = 'system'
+                messages.append(msg)
+
+        data = {
+            'model': self.model,
+            'messages': messages,
+            'stream': False,
+            'temperature': 0.9,
+            'max_tokens': 2048,
+            'top_p': 0.5,
+            'top_k': 0,
+            'repetition_penalty': 1.05,
+            # "num_beams": 1,
+            # "user": "OpenCompass"
+        }
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            # payload = json.dumps(data)
+            raw_response = requests.request('POST',
+                                            url=self.url,
+                                            headers=self.headers,
+                                            json=data)
+            response = raw_response.json()
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if connection error, frequent requests will cause
+                # continuous unstable network, therefore wait here
+                # to slow down the request
+                self.wait()
+                continue
+            if raw_response.status_code == 200:
+                try:
+                    msg = response['choices'][0]['message']['content'].strip()
+                    return msg
+
+                except KeyError:
+                    if 'error' in response:
+                        # TPM (tokens per minute) limit
+                        if response['error']['code'] == '1005':
+                            time.sleep(1)
+                            continue
+
+                        self.logger.error('Found error in response: %s',
+                                          str(response['error']))
+
+            # sensitive content, prompt overlength, network error
+            # or illegal prompt
+            if (raw_response.status_code == 400
+                    or raw_response.status_code == 401
+                    or raw_response.status_code == 402
+                    or raw_response.status_code == 429
+                    or raw_response.status_code == 500):
+                print(raw_response.text)
+                # return ''
+                continue
+            print(raw_response)
+            max_num_retries += 1
+
+        raise RuntimeError(raw_response.text)
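
The wrapper can be smoke-tested outside the runner. A minimal sketch, assuming a valid API key (the one below is a placeholder):

    from opencompass.models import AI360GPT

    model = AI360GPT(path='360GPT_S2_V9', key='xxxxxxxxxxxx')
    print(model.generate(['Briefly introduce yourself.'], max_out_len=128))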
diff --git a/opencompass/models/baichuan_api.py b/opencompass/models/baichuan_api.py
new file mode 100644
index 00000000..1396dc25
--- /dev/null
+++ b/opencompass/models/baichuan_api.py
@@ -0,0 +1,164 @@
+import hashlib
+import json
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class BaiChuan(BaseAPIModel):
+    """Model wrapper around Baichuan.
+
+    Documentation: https://platform.baichuan-ai.com/docs/api
+
+    Args:
+        path (str): The name of Baichuan model.
+            e.g. `Baichuan2-53B`
+        api_key (str): Provided API key
+        secret_key (str): Provided secret key, used to sign requests
+        url (str): Provided URL
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        api_key: str,
+        secret_key: str,
+        url: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+
+        self.api_key = api_key
+        self.secret_key = secret_key
+        self.url = url
+        self.model = path
+
+    def generate(
+        self,
+        inputs: List[str or PromptList],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: str or PromptList,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (str or PromptList): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'assistant'
+
+                messages.append(msg)
+
+        data = {'model': self.model, 'messages': messages}
+
+        def calculate_md5(input_string):
+            md5 = hashlib.md5()
+            md5.update(input_string.encode('utf-8'))
+            encrypted = md5.hexdigest()
+            return encrypted
+
+        json_data = json.dumps(data)
+        time_stamp = int(time.time())
+        signature = calculate_md5(self.secret_key + json_data +
+                                  str(time_stamp))
+
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer ' + self.api_key,
+            'X-BC-Request-Id': 'your requestId',
+            'X-BC-Timestamp': str(time_stamp),
+            'X-BC-Signature': signature,
+            'X-BC-Sign-Algo': 'MD5',
+        }
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            raw_response = requests.request('POST',
+                                            url=self.url,
+                                            headers=headers,
+                                            json=data)
+            response = raw_response.json()
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if connection error, frequent requests will cause
+                # continuous unstable network, therefore wait here
+                # to slow down the request
+                self.wait()
+                continue
+            if raw_response.status_code == 200 and response['code'] == 0:
+                # msg = json.load(response.text)
+                # response
+                msg = response['data']['messages'][0]['content']
+                return msg
+
+            if response['code'] != 0:
+                print(response)
+                return ''
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(response)
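
Request signing here follows Baichuan's documented scheme: an MD5 digest over secret_key + JSON body + timestamp, sent alongside the bearer api_key. The same signature can be reproduced standalone (the key is a placeholder):

    import hashlib
    import json
    import time

    secret_key = 'xxxxx'
    body = json.dumps({'model': 'Baichuan2-53B',
                       'messages': [{'role': 'user', 'content': 'Hi'}]})
    ts = str(int(time.time()))
    signature = hashlib.md5((secret_key + body + ts).encode('utf-8')).hexdigest()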
diff --git a/opencompass/models/baidu_api.py b/opencompass/models/baidu_api.py
new file mode 100644
index 00000000..085cf2b4
--- /dev/null
+++ b/opencompass/models/baidu_api.py
@@ -0,0 +1,196 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class ERNIEBot(BaseAPIModel):
+    """Model wrapper around ERNIE-Bot.
+
+    Documentation: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
+
+    Args:
+        path (str): The name of ERNIE-Bot model.
+            e.g. `erniebot`
+        key (str): Provided API key, used to obtain the access_token.
+        secretkey (str): Provided secret key, used to obtain the
+            access_token.
+        url (str): Request URL; the access_token is appended to it.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        key: str,
+        secretkey: str,
+        url: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        self.headers = {'Content-Type': 'application/json'}
+        self.secretkey = secretkey
+        self.key = key
+        self.url = url
+        self.model = path
+
+    def _generate_access_token(self):
+        try:
+            BAIDU_APIKEY = self.key
+            BAIDU_SECRETKEY = self.secretkey
+            url = f'https://aip.baidubce.com/oauth/2.0/token?' \
+                  f'client_id={BAIDU_APIKEY}&client_secret={BAIDU_SECRETKEY}' \
+                  f'&grant_type=client_credentials'
+            headers = {
+                'Content-Type': 'application/json',
+                'Accept': 'application/json'
+            }
+            response = requests.request('POST', url, headers=headers)
+            resp_dict = response.json()
+            if response.status_code == 200:
+                access_token = resp_dict.get('access_token')
+                refresh_token = resp_dict.get('refresh_token')
+                if 'error' in resp_dict:
+                    raise ValueError(f'Failed to obtain certificate.'
+                                     f'{resp_dict.get("error")}')
+                else:
+                    return access_token, refresh_token
+            else:
+                error = resp_dict.get('error')
+                raise ValueError(
+                    f'Failed to request the certificate: {error}.')
+        except Exception as ex:
+            raise ex
+
+    def generate(
+        self,
+        inputs: List[str or PromptList],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: str or PromptList,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (str or PromptList): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+        """
+        {
+            "messages": [
+                {"role":"user","content":"Please introduce yourself"},
+                {"role":"assistant","content":"I am an AI language model developed by Baidu"},
+                {"role":"user","content": "I am in Shanghai; where can I go at the weekend?"},
+                {"role":"assistant","content": "Shanghai is a lively city with a rich cultural scene"},
+                {"role":"user","content": "What will the weather be like there at the weekend?"}
+            ]
+        }
+
+        """
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'assistant'
+
+                messages.append(msg)
+        data = {'messages': messages}
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            access_token, _ = self._generate_access_token()
+            raw_response = requests.request('POST',
+                                            url=self.url + access_token,
+                                            headers=self.headers,
+                                            json=data)
+            response = raw_response.json()
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if connection error, frequent requests will cause
+                # continuous unstable network, therefore wait here
+                # to slow down the request
+                self.wait()
+                continue
+            if raw_response.status_code == 200:
+                try:
+                    msg = response['result']
+                    return msg
+                except KeyError:
+                    print(response)
+                    self.logger.error(str(response['error_code']))
+                    time.sleep(1)
+                    continue
+
+            if (response['error_code'] == 110 or response['error_code'] == 100
+                    or response['error_code'] == 111
+                    or response['error_code'] == 200
+                    or response['error_code'] == 1000
+                    or response['error_code'] == 1001
+                    or response['error_code'] == 1002
+                    or response['error_code'] == 21002
+                    or response['error_code'] == 216100
+                    or response['error_code'] == 336001
+                    or response['error_code'] == 336003
+                    or response['error_code'] == 336000):
+                print(response['error_msg'])
+                return ''
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(response['error_msg'])
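
Authentication for ERNIE-Bot is two-step: mint an access_token via the OAuth client_credentials grant, then append it to the configured chat URL (which ends in `?access_token=`). The equivalent standalone flow, with key, secret, and the elided chat URL as placeholders:

    import requests

    token_url = ('https://aip.baidubce.com/oauth/2.0/token'
                 '?client_id=KEY&client_secret=SECRET'
                 '&grant_type=client_credentials')
    access_token = requests.post(token_url).json()['access_token']

    chat_url = 'https://aip.baidubce.com/<model-endpoint>?access_token='
    resp = requests.post(chat_url + access_token,
                         json={'messages': [{'role': 'user', 'content': 'Hi'}]})
    print(resp.json().get('result'))

Note the wrapper re-mints the token on every retry iteration rather than caching it; that is wasteful but harmless at these request rates.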
diff --git a/opencompass/models/base_api.py b/opencompass/models/base_api.py
index 7c8f0b31..04ce9323 100644
--- a/opencompass/models/base_api.py
+++ b/opencompass/models/base_api.py
@@ -1,4 +1,5 @@
 import re
+import sys
 import threading
 import warnings
 from abc import abstractmethod
@@ -64,6 +65,38 @@ class BaseAPIModel(BaseModel):
                            ' gen-based evaluation yet, try ppl-based '
                            'instead.')
 
+    def flush(self):
+        """Flush stdout and stderr when concurrent resources exist.
+
+        When multiprocessing is used with standard I/O redirected to
+        files, the internal buffers must be flushed so that output can
+        be examined afterwards and logs are not lost if the system
+        breaks.
+        """
+        if hasattr(self, 'tokens'):
+            sys.stdout.flush()
+            sys.stderr.flush()
+
+    def acquire(self):
+        """Acquire concurrent resources if they exist.
+
+        This behavior will fall back to wait with query_per_second if there are
+        no concurrent resources.
+        """
+        if hasattr(self, 'tokens'):
+            self.tokens.acquire()
+        else:
+            self.wait()
+
+    def release(self):
+        """Release concurrent resources if acquired.
+
+        This behavior will fall back to do nothing if there are no concurrent
+        resources.
+        """
+        if hasattr(self, 'tokens'):
+            self.tokens.release()
+
     @abstractmethod
     def get_ppl(self,
                 inputs: List[PromptType],
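
The `tokens` attribute is not created in this class; it is presumably attached by LocalAPIRunner when `concurrent_users` is set, in which case acquire/release gate the number of in-flight calls instead of the per-call QPS wait. A minimal sketch of that contract using a plain semaphore (`do_request` is a hypothetical stand-in for one API call):

    import threading

    tokens = threading.Semaphore(2)  # stand-in for the runner-installed attribute

    def call_api(payload):
        tokens.acquire()                 # model.acquire(): cap in-flight requests
        try:
            return do_request(payload)   # hypothetical single HTTP call
        finally:
            tokens.release()             # model.release()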
diff --git a/opencompass/models/bytedance_api.py b/opencompass/models/bytedance_api.py
new file mode 100644
index 00000000..4092de48
--- /dev/null
+++ b/opencompass/models/bytedance_api.py
@@ -0,0 +1,172 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+try:
+    from volcengine.maas import ChatRole, MaasException, MaasService
+except ImportError:
+    ChatRole, MaasException, MaasService = None, None, None
+
+PromptType = Union[PromptList, str]
+
+
+class ByteDance(BaseAPIModel):
+    """Model wrapper around ByteDance.
+
+    Args:
+        path (str): The name of ByteDance model.
+            e.g. `skylark`
+        accesskey (str): Provided access key, used together with the
+            secret key to sign requests.
+        secretkey (str): Provided secret key.
+        url (str): MaaS service endpoint.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        accesskey: str,
+        secretkey: str,
+        url: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        if not ChatRole:
+            print('Please install related packages via'
+                  ' `pip install volcengine`')
+
+        self.accesskey = accesskey
+        self.secretkey = secretkey
+        self.url = url
+        self.model = path
+
+    def generate(
+        self,
+        inputs: List[str or PromptList],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: str or PromptList,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (str or PromptList): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+
+        messages
+        [
+            {
+                "role": ChatRole.USER,
+                "content": "Why is the sky so blue?"
+            }, {
+                "role": ChatRole.ASSISTANT,
+                "content": "Because of you"
+            }, {
+                "role": ChatRole.USER,
+                "content": "Why do flowers smell so sweet?"
+            },
+        ]
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': ChatRole.USER, 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = ChatRole.USER
+                elif item['role'] == 'BOT':
+                    msg['role'] = ChatRole.ASSISTANT
+
+                messages.append(msg)
+
+        maas = MaasService(self.url, 'cn-beijing')
+        maas.set_ak(self.accesskey)
+        maas.set_sk(self.secretkey)
+
+        req = {
+            'model': {
+                'name': 'skylark-pro-public',
+            },
+            'messages': messages
+        }
+
+        def _chat(maas, req):
+            try:
+                resp = maas.chat(req)
+                return resp
+            except MaasException as e:
+                print(e)
+                return e
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            response = _chat(maas, req)
+
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if connection error, frequent requests will cause
+                # continuous unstable network, therefore wait here
+                # to slow down the request
+                self.wait()
+                continue
+            if not isinstance(response, MaasException):
+                # response
+                msg = response.choice.message.content
+                return msg
+
+            if isinstance(response, MaasException):
+                print(response)
+                return ''
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(response)
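
For reference, the underlying volcengine call looks like this when used directly; the endpoint host comes from the wrapper's `url` argument, and the keys below are placeholders (note the request body currently pins the model name to `skylark-pro-public` rather than using `path`):

    from volcengine.maas import ChatRole, MaasService

    maas = MaasService('<maas-endpoint-host>', 'cn-beijing')
    maas.set_ak('<access-key>')
    maas.set_sk('<secret-key>')
    req = {
        'model': {'name': 'skylark-pro-public'},
        'messages': [{'role': ChatRole.USER, 'content': 'Hello'}],
    }
    resp = maas.chat(req)  # raises MaasException on failure
    print(resp.choice.message.content)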
diff --git a/opencompass/models/minimax_api.py b/opencompass/models/minimax_api.py
index 813d9162..f42500ac 100644
--- a/opencompass/models/minimax_api.py
+++ b/opencompass/models/minimax_api.py
@@ -1,4 +1,3 @@
-import sys
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union
 
@@ -81,37 +80,6 @@ class MiniMax(BaseAPIModel):
         self.flush()
         return results
 
-    def flush(self):
-        """Flush stdout and stderr when concurrent resources exists.
-
-        When use multiproessing with standard io rediected to files, need to
-        flush internal information for examination or log loss when system
-        breaks.
-        """
-        if hasattr(self, 'tokens'):
-            sys.stdout.flush()
-            sys.stderr.flush()
-
-    def acquire(self):
-        """Acquire concurrent resources if exists.
-
-        This behavior will fall back to wait with query_per_second if there are
-        no concurrent resources.
-        """
-        if hasattr(self, 'tokens'):
-            self.tokens.acquire()
-        else:
-            self.wait()
-
-    def release(self):
-        """Release concurrent resources if acquired.
-
-        This behavior will fall back to do nothing if there are no concurrent
-        resources.
-        """
-        if hasattr(self, 'tokens'):
-            self.tokens.release()
-
     def _generate(
         self,
         input: str or PromptList,
diff --git a/opencompass/models/pangu_api.py b/opencompass/models/pangu_api.py
new file mode 100644
index 00000000..773d6b26
--- /dev/null
+++ b/opencompass/models/pangu_api.py
@@ -0,0 +1,182 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class PanGu(BaseAPIModel):
+    """Model wrapper around PanGu.
+
+    Args:
+        path (str): The name of Pangu model.
+            e.g. `pangu`
+        access_key (str): Provided access key
+        secret_key (str): Provided secret key, used to obtain the auth token
+        url (str): Provided URL for requests
+        token_url (str): URL of the token server
+        project_name (str): Project name used to generate the token
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        access_key: str,
+        secret_key: str,
+        url: str,
+        token_url: str,
+        project_name: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+
+        self.access_key = access_key
+        self.secret_key = secret_key
+        self.url = url
+        self.token_url = token_url
+        self.project_name = project_name
+        self.model = path
+
+    def generate(
+        self,
+        inputs: List[str or PromptList],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _get_token(self):
+        url = self.token_url
+        payload = {
+            'auth': {
+                'identity': {
+                    'methods': ['hw_ak_sk'],
+                    'hw_ak_sk': {
+                        'access': {
+                            'key': self.access_key
+                        },
+                        'secret': {
+                            'key': self.secret_key
+                        }
+                    }
+                },
+                'scope': {
+                    'project': {
+                        'name': self.project_name
+                    }
+                }
+            }
+        }
+        headers = {'Content-Type': 'application/json'}
+
+        response = requests.request('POST', url, headers=headers, json=payload)
+        return response
+
+    def _generate(
+        self,
+        input: str or PromptList,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (str or PromptList): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'system'
+
+                messages.append(msg)
+
+        data = {'messages': messages, 'stream': False}
+
+        token_response = self._get_token()
+        if token_response.status_code == 201:
+            token = token_response.headers['X-Subject-Token']
+            print('Token generated successfully.')
+        else:
+            msg = 'Token generation failed.'
+            print(msg)
+            return ''
+
+        headers = {'Content-Type': 'application/json', 'X-Auth-Token': token}
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            raw_response = requests.request('POST',
+                                            url=self.url,
+                                            headers=headers,
+                                            json=data)
+            response = raw_response.json()
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if connection error, frequent requests will cause
+                # continuous unstable network, therefore wait here
+                # to slow down the request
+                self.wait()
+                continue
+            if raw_response.status_code == 200:
+                # msg = json.load(response.text)
+                # response
+                msg = response['choices'][0]['message']['content']
+                return msg
+
+            if raw_response.status_code != 200:
+                print(response['error_msg'])
+                return ''
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(response['error_msg'])
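
The Huawei IAM handshake is the only unusual part: a successful token request answers 201 Created and carries the token in the `X-Subject-Token` response header rather than in the body. Standalone, with `token_url` and `payload` built exactly as in `_get_token` above:

    import requests

    resp = requests.post(token_url,
                         headers={'Content-Type': 'application/json'},
                         json=payload)
    assert resp.status_code == 201  # IAM returns 201 Created on success
    token = resp.headers['X-Subject-Token']
    headers = {'Content-Type': 'application/json', 'X-Auth-Token': token}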
diff --git a/opencompass/models/sensetime_api.py b/opencompass/models/sensetime_api.py
new file mode 100644
index 00000000..b7560110
--- /dev/null
+++ b/opencompass/models/sensetime_api.py
@@ -0,0 +1,136 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+import requests
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class SenseTime(BaseAPIModel):
+    """Model wrapper around SenseTime.
+
+    Args:
+        path (str): The name of SenseTime model.
+            e.g. `nova-ptc-xl-v1`
+        key (str): Authorization key.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        key: str,
+        url: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        self.headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {key}'
+        }
+        self.url = url
+        self.model = path
+
+    def generate(
+        self,
+        inputs: List[str or PromptList],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: str or PromptList,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (str or PromptList): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'assistant'
+
+                messages.append(msg)
+
+        data = {'messages': messages, 'model': self.model}
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            raw_response = requests.request('POST',
+                                            url=self.url,
+                                            headers=self.headers,
+                                            json=data)
+            response = raw_response.json()
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if connection error, frequent requests will cause
+                # continuous unstable network, therefore wait here
+                # to slow down the request
+                self.wait()
+                continue
+            if raw_response.status_code == 200:
+                msg = response['data']['choices'][0]['message']
+                return msg
+
+            if raw_response.status_code != 200:
+                print(raw_response.text)
+                time.sleep(1)
+                continue
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(raw_response.text)
diff --git a/opencompass/models/xunfei_api.py b/opencompass/models/xunfei_api.py
index 72f3815f..0e1de20e 100644
--- a/opencompass/models/xunfei_api.py
+++ b/opencompass/models/xunfei_api.py
@@ -1,5 +1,4 @@
 import json
-import sys
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union
 
@@ -120,37 +119,6 @@ class XunFei(BaseAPIModel):
         self.flush()
         return results
 
-    def flush(self):
-        """Flush stdout and stderr when concurrent resources exists.
-
-        When use multiproessing with standard io rediected to files, need to
-        flush internal information for examination or log loss when system
-        breaks.
-        """
-        if hasattr(self, 'tokens'):
-            sys.stdout.flush()
-            sys.stderr.flush()
-
-    def acquire(self):
-        """Acquire concurrent resources if exists.
-
-        This behavior will fall back to wait with query_per_second if there are
-        no concurrent resources.
-        """
-        if hasattr(self, 'tokens'):
-            self.tokens.acquire()
-        else:
-            self.wait()
-
-    def release(self):
-        """Release concurrent resources if acquired.
-
-        This behavior will fall back to do nothing if there are no concurrent
-        resources.
-        """
-        if hasattr(self, 'tokens'):
-            self.tokens.release()
-
     def _generate(
         self,
         input: str or PromptList,
diff --git a/opencompass/models/zhipuai_api.py b/opencompass/models/zhipuai_api.py
index 7360b12f..dd7c4d70 100644
--- a/opencompass/models/zhipuai_api.py
+++ b/opencompass/models/zhipuai_api.py
@@ -1,4 +1,3 @@
-import sys
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union
 
@@ -66,37 +65,6 @@ class ZhiPuAI(BaseAPIModel):
         self.flush()
         return results
 
-    def flush(self):
-        """Flush stdout and stderr when concurrent resources exists.
-
-        When use multiproessing with standard io rediected to files, need to
-        flush internal information for examination or log loss when system
-        breaks.
-        """
-        if hasattr(self, 'tokens'):
-            sys.stdout.flush()
-            sys.stderr.flush()
-
-    def acquire(self):
-        """Acquire concurrent resources if exists.
-
-        This behavior will fall back to wait with query_per_second if there are
-        no concurrent resources.
-        """
-        if hasattr(self, 'tokens'):
-            self.tokens.acquire()
-        else:
-            self.wait()
-
-    def release(self):
-        """Release concurrent resources if acquired.
-
-        This behavior will fall back to do nothing if there are no concurrent
-        resources.
-        """
-        if hasattr(self, 'tokens'):
-            self.tokens.release()
-
     def _generate(
         self,
         input: str or PromptList,

diff --git a/requirements/api.txt b/requirements/api.txt
index a9a20933..b199a92c 100644
--- a/requirements/api.txt
+++ b/requirements/api.txt
@@ -1,2 +1,4 @@
+sseclient-py==1.7.2
+volcengine  # bytedance
 websocket-client
-zhipuai
+zhipuai  # zhipu
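
The new dependencies install with `pip install -r requirements/api.txt`. Only the ByteDance wrapper actually requires `volcengine`, and it imports the package lazily, printing an install hint instead of failing when the package is missing; the other wrappers rely solely on `requests`.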