From cdca59ff49bd543c7292c887a0bcbb6c04a42404 Mon Sep 17 00:00:00 2001 From: Songyang Zhang Date: Sun, 28 Jan 2024 14:52:43 +0800 Subject: [PATCH] [Fix] Update Zhipu API and Fix issue min_out_len issue of API models (#847) * Update zhipu api and fix min_out_len issue of API class * Update example * Update example --- configs/api_examples/eval_api_zhipu_v2.py | 67 +++++++++ opencompass/models/__init__.py | 1 + opencompass/models/zhipuai_v2_api.py | 172 ++++++++++++++++++++++ opencompass/runners/local_api.py | 1 + 4 files changed, 241 insertions(+) create mode 100644 configs/api_examples/eval_api_zhipu_v2.py create mode 100644 opencompass/models/zhipuai_v2_api.py diff --git a/configs/api_examples/eval_api_zhipu_v2.py b/configs/api_examples/eval_api_zhipu_v2.py new file mode 100644 index 00000000..acfa8af3 --- /dev/null +++ b/configs/api_examples/eval_api_zhipu_v2.py @@ -0,0 +1,67 @@ +from mmengine.config import read_base +from opencompass.models import ZhiPuV2AI +from opencompass.partitioners import NaivePartitioner +from opencompass.runners.local_api import LocalAPIRunner +from opencompass.tasks import OpenICLInferTask + +with read_base(): + # from .datasets.collections.chat_medium import datasets + from ..summarizers.medium import summarizer + from ..datasets.ceval.ceval_gen import ceval_datasets + +datasets = [ + *ceval_datasets, +] + +# needs a special postprocessor for all +# except 'gsm8k' and 'strategyqa' +from opencompass.utils import general_eval_wrapper_postprocess +for _dataset in datasets: + if _dataset['abbr'] not in ['gsm8k', 'strategyqa']: + if hasattr(_dataset['eval_cfg'], 'pred_postprocessor'): + _dataset['eval_cfg']['pred_postprocessor']['postprocess'] = _dataset['eval_cfg']['pred_postprocessor']['type'] + _dataset['eval_cfg']['pred_postprocessor']['type'] = general_eval_wrapper_postprocess + else: + _dataset['eval_cfg']['pred_postprocessor'] = {'type': general_eval_wrapper_postprocess} + + +api_meta_template = dict( + round=[ + dict(role='HUMAN', api_role='HUMAN'), + dict(role='BOT', api_role='BOT', generate=True), + ], +) + +models = [ + dict( + abbr='glm4_notools', + type=ZhiPuV2AI, + path='glm-4', + key='xxxxxx', + generation_kwargs={ + 'tools': [ + { + 'type': 'web_search', + 'web_search': { + 'enable': False # turn off the search + } + } + ] + }, + meta_template=api_meta_template, + query_per_second=1, + max_out_len=2048, + max_seq_len=2048, + batch_size=8) +] + +infer = dict( + partitioner=dict(type=NaivePartitioner), + runner=dict( + type=LocalAPIRunner, + max_num_workers=2, + concurrent_users=2, + task=dict(type=OpenICLInferTask)), +) + +work_dir = "outputs/api_zhipu_v2/" \ No newline at end of file diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index 41c3e127..6790d652 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -28,3 +28,4 @@ from .turbomind_tis import TurboMindTisModel # noqa: F401 from .vllm import VLLM # noqa: F401 from .xunfei_api import XunFei # noqa: F401 from .zhipuai_api import ZhiPuAI # noqa: F401 +from .zhipuai_v2_api import ZhiPuV2AI # noqa: F401 diff --git a/opencompass/models/zhipuai_v2_api.py b/opencompass/models/zhipuai_v2_api.py new file mode 100644 index 00000000..c1f79d98 --- /dev/null +++ b/opencompass/models/zhipuai_v2_api.py @@ -0,0 +1,172 @@ +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Union + +from httpx import ProxyError + +from opencompass.utils.prompt import PromptList + +from .base_api import BaseAPIModel + +try: + from zhipuai.core._errors import APIStatusError, APITimeoutError +except ImportError: + APIStatusError = None + APITimeoutError = None + +PromptType = Union[PromptList, str] + + +class ZhiPuV2AI(BaseAPIModel): + """Model wrapper around ZhiPuAI. + + Args: + path (str): The name of OpenAI's model. + key (str): Authorization key. + query_per_second (int): The maximum queries allowed per second + between two consecutive calls of the API. Defaults to 1. + max_seq_len (int): Unused here. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + retry (int): Number of retires if the API call fails. Defaults to 2. + """ + + def __init__(self, + path: str, + key: str, + query_per_second: int = 2, + max_seq_len: int = 2048, + meta_template: Optional[Dict] = None, + retry: int = 2, + generation_kwargs: Dict = { + 'tools': [{ + 'type': 'web_search', + 'enable': False + }] + }): + super().__init__(path=path, + max_seq_len=max_seq_len, + query_per_second=query_per_second, + meta_template=meta_template, + retry=retry, + generation_kwargs=generation_kwargs) + from zhipuai import ZhipuAI + + # self.zhipuai = zhipuai + self.client = ZhipuAI(api_key=key) + self.model = path + + def generate( + self, + inputs: List[str or PromptList], + max_out_len: int = 512, + ) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str or PromptList]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + with ThreadPoolExecutor() as executor: + results = list( + executor.map(self._generate, inputs, + [max_out_len] * len(inputs))) + self.flush() + return results + + def _generate( + self, + input: str or PromptList, + max_out_len: int = 512, + ) -> str: + """Generate results given an input. + + Args: + inputs (str or PromptList): A string or PromptDict. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + str: The generated string. + """ + assert isinstance(input, (str, PromptList)) + + if isinstance(input, str): + messages = [{'role': 'user', 'content': input}] + else: + messages = [] + for item in input: + msg = {'content': item['prompt']} + if item['role'] == 'HUMAN': + msg['role'] = 'user' + elif item['role'] == 'BOT': + msg['role'] = 'assistant' + elif item['role'] == 'SYSTEM': + msg['role'] = 'system' + messages.append(msg) + + data = {'model': self.model, 'messages': messages} + data.update(self.generation_kwargs) + + max_num_retries = 0 + while max_num_retries < self.retry: + self.acquire() + + try: + response = self.client.chat.completions.create(**data) + except APIStatusError as err: + err_message = str(err.response.json()['error']['message']) + status_code = str(err.status_code) + err_code = str(err.response.json()['error']['code']) + print('Error message:{}'.format(err_message)) + print('Statues code:{}'.format(status_code)) + print('Error code:{}'.format(err_code)) + + if err_code == '1301': + return 'Sensitive content' + elif err_code == '1302': + print('Reach rate limit') + time.sleep(1) + continue + except ProxyError as err: + print('Proxy Error, try again. {}'.format(err)) + time.sleep(3) + continue + except APITimeoutError as err: + print('APITimeoutError {}'.format(err)) + time.sleep(3) + continue + + self.release() + + if response is None: + print('Connection error, reconnect.') + # if connect error, frequent requests will casuse + # continuous unstable network, therefore wait here + # to slow down the request + self.wait() + max_num_retries += 1 + continue + + # if response['code'] == 200 and response['success']: + # msg = response['data']['choices'][0]['content'] + else: + msg = response.choices[0].message.content + return msg + # sensitive content, prompt overlength, network error + # or illegal prompt + if (response['code'] == 1301 or response['code'] == 1261 + or response['code'] == 1234 or response['code'] == 1214): + print(response['msg']) + return '' + print(response) + max_num_retries += 1 + + raise RuntimeError(response['msg']) diff --git a/opencompass/runners/local_api.py b/opencompass/runners/local_api.py index 8ec3df55..253075f0 100644 --- a/opencompass/runners/local_api.py +++ b/opencompass/runners/local_api.py @@ -28,6 +28,7 @@ def monkey_run(self, tokens: SyncManager.Semaphore): self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}') for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs): self.max_out_len = model_cfg.get('max_out_len', None) + self.min_out_len = model_cfg.get('min_out_len', None) self.batch_size = model_cfg.get('batch_size', None) self.model = build_model_from_cfg(model_cfg) # add global tokens for concurrents