diff --git a/configs/api_examples/eval_api_qwen.py b/configs/api_examples/eval_api_qwen.py
new file mode 100644
index 00000000..7df987e0
--- /dev/null
+++ b/configs/api_examples/eval_api_qwen.py
@@ -0,0 +1,40 @@
+from mmengine.config import read_base
+from opencompass.models import Qwen
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners.local_api import LocalAPIRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    from ..summarizers.medium import summarizer
+    from ..datasets.ceval.ceval_gen import ceval_datasets
+
+datasets = [
+    *ceval_datasets,
+]
+
+models = [
+    dict(
+        abbr='qwen-max',
+        type=Qwen,
+        path='qwen-max',
+        key='xxxxxxxxxxxxxxxx',  # please give your key
+        generation_kwargs={
+            'enable_search': False,
+        },
+        query_per_second=1,
+        max_out_len=2048,
+        max_seq_len=2048,
+        batch_size=8
+    ),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalAPIRunner,
+        max_num_workers=1,
+        concurrent_users=1,
+        task=dict(type=OpenICLInferTask)),
+)
+
+work_dir = "outputs/api_qwen/"
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 92f0ce1a..b742e125
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -19,6 +19,7 @@ from .modelscope import ModelScope, ModelScopeCausalLM  # noqa: F401, F403
 from .moonshot_api import MoonShot  # noqa: F401
 from .openai_api import OpenAI  # noqa: F401
 from .pangu_api import PanGu  # noqa: F401
+from .qwen_api import Qwen  # noqa: F401
 from .sensetime_api import SenseTime  # noqa: F401
 from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
diff --git a/opencompass/models/qwen_api.py b/opencompass/models/qwen_api.py
new file mode 100644
index 00000000..89c4aa9d
--- /dev/null
+++ b/opencompass/models/qwen_api.py
@@ -0,0 +1,154 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+class Qwen(BaseAPIModel):
+    """Model wrapper around Qwen.
+
+    Documentation:
+        https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions/
+
+    Args:
+        path (str): The name of the Qwen model, e.g. `qwen-max`.
+        key (str): Authorization key.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case meta instructions need to be
+            injected or wrapped.
+        retry (int): Number of retries if the API call fails. Defaults to 5.
+    """
+
+    def __init__(self,
+                 path: str,
+                 key: str,
+                 query_per_second: int = 1,
+                 max_seq_len: int = 2048,
+                 meta_template: Optional[Dict] = None,
+                 retry: int = 5,
+                 generation_kwargs: Dict = {}):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry,
+                         generation_kwargs=generation_kwargs)
+        import dashscope
+        dashscope.api_key = key
+        self.dashscope = dashscope
+
+    def generate(
+        self,
+        inputs: List[PromptType],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[PromptType]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def _generate(
+        self,
+        input: PromptType,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (PromptType): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+        """
+        {
+            "messages": [
+                {"role": "user", "content": "Please introduce yourself."},
+                {"role": "assistant", "content": "I am Tongyi Qianwen."},
+                {"role": "user", "content": "I am in Shanghai. Where can I go at the weekend?"},
+                {"role": "assistant", "content": "Shanghai is a vibrant city with a rich cultural atmosphere."},
+                {"role": "user", "content": "What will the weather be like there at the weekend?"}
+            ]
+        }
+
+        """
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'assistant'
+                elif item['role'] == 'SYSTEM':
+                    msg['role'] = 'system'
+
+                messages.append(msg)
+        data = {'messages': messages}
+        data.update(self.generation_kwargs)
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            response = self.dashscope.Generation.call(
+                model=self.path,
+                **data,
+            )
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if a connection error occurs, frequent requests will cause
+                # continued network instability, so wait here to slow down
+                # the requests
+                self.wait()
+                continue
+
+            if response.status_code == 200:
+                try:
+                    msg = response.output.text
+                    return msg
+                except KeyError:
+                    print(response)
+                    self.logger.error(str(response.status_code))
+                    time.sleep(1)
+                    continue
+
+            if ('Range of input length should be ' in response.message
+                    or  # input too long
+                    'Input data may contain inappropriate content.'
+                    in response.message):  # bad input
+                print(response.message)
+                return ''
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(response.message)
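
For reference, below is a minimal standalone sketch of the DashScope call that Qwen._generate wraps. It assumes the dashscope SDK is installed (pip install dashscope); the key is a hypothetical placeholder, and with the default 'text' result format the reply is read from response.output.text, matching the code above.

import dashscope

dashscope.api_key = 'xxxxxxxxxxxxxxxx'  # hypothetical placeholder, not a real key

# Single-turn call shaped like the payload Qwen._generate builds.
response = dashscope.Generation.call(
    model='qwen-max',
    messages=[{'role': 'user', 'content': 'Hello'}],
)
if response.status_code == 200:
    print(response.output.text)  # default 'text' result format
else:
    print(response.status_code, response.message)  # e.g. rate-limit or bad-input errors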
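
Assuming a standard OpenCompass checkout, the new config can then be launched from the repository root in the usual way, e.g.:

python run.py configs/api_examples/eval_api_qwen.py

Keeping max_num_workers=1, concurrent_users=1, and query_per_second=1 makes the LocalAPIRunner issue requests conservatively, which is a sensible starting point under per-key API rate limits.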