import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union

import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


@MODELS.register_module()
class JiutianApi(BaseAPIModel):
    """Model wrapper around Jiutian API's models.

    Args:
        path (str): The name of the model. Defaults to 'cmri_base'.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not
            exceed this value. Defaults to 4096.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        appcode (str): Bearer auth token sent in the Authorization header.
        url (str): The chat-completions endpoint url.
        stream (bool): Whether a streamed response is expected. The response
            parser only understands the streamed (SSE) format, so this
            should stay True. Defaults to True.
        max_tokens (int): ``max_tokens`` value sent in the request payload.
            Defaults to 1024.
        model_id (str): The id of the model sent in the request payload.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 path: str = 'cmri_base',
                 max_seq_len: int = 4096,
                 query_per_second: int = 1,
                 retry: int = 2,
                 appcode: str = '',
                 url: str = None,
                 stream: bool = True,
                 max_tokens: int = 1024,
                 model_id: str = '',
                 temperature: Optional[float] = None):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry)
        # NOTE: the original imported tiktoken here and stored it on
        # ``self.tiktoken`` without ever using it; the dead dependency
        # has been removed.
        self.temperature = temperature
        self.url = url
        self.path = path
        self.stream = stream
        self.max_tokens = max_tokens
        self.model_id = model_id
        self.appcode = appcode

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def _generate(self, input: PromptType, max_out_len: int) -> str:
        """Generate a completion for a single input.

        Args:
            input (PromptType): A string or PromptList. The PromptList
                should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output. Currently
                unused; the request always sends ``self.max_tokens``.

        Returns:
            str: The generated string.

        Raises:
            RuntimeError: If all ``self.retry`` attempts fail.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)
            # BUG FIX: the original reset ``messages = []`` right here,
            # which discarded every PromptList conversation and sent an
            # empty message list to the API.

        num_retries = 0
        while num_retries < self.retry:
            # BUG FIX: count every attempt up front. The original only
            # incremented at the loop bottom, which every failure path
            # skipped via ``continue`` — an infinite retry loop.
            num_retries += 1
            self.wait()
            header = {
                'Content-Type': 'application/json',
                'Authorization': f'Bearer {self.appcode}',
            }
            data = {
                'model': self.model_id,
                'messages': messages,
                'max_tokens': self.max_tokens,
                # The parser only understands the streamed (SSE) format,
                # so streaming is always requested regardless of
                # ``self.stream``.
                'stream': True,
            }

            try:
                raw_response = requests.post(self.url,
                                             headers=header,
                                             json=data,
                                             stream=True)
            except Exception as err:
                self.logger.error(f'Request Error: {err}')
                time.sleep(2)
                continue

            try:
                response = self.parse_event_data(raw_response)
            except Exception as err:
                self.logger.error(f'Response Error: {err}')
                response = None
            self.release()

            if response is None:
                self.logger.error('Connection error, reconnect.')
                self.wait()
                continue
            if isinstance(response, str):
                # parse_event_data signals failure with a str marker.
                self.logger.error('Get stream result error, retry.')
                self.wait()
                continue
            # parse_event_data only returns a dict when 'full_text' is
            # present and truthy, so .get is a safe belt-and-suspenders
            # replacement for the original bare ``except: return ''``.
            msg = response.get('full_text', '')
            self.logger.debug(f'Generated: {msg}')
            return msg

        raise RuntimeError('max error in max_num_retries')

    def parse_event_data(self, resp) -> Union[Dict, str]:
        """Parse a streamed (SSE) chat-completions response.

        Args:
            resp (requests.Response): The streamed response object.

        Returns:
            Dict: The decoded event payload containing ``full_text``,
                on success.
            str: An empty string if the stream ends without a usable
                payload, or an error marker string on read failure.
        """

        def _deal_data(data: str) -> Tuple[bool, Dict]:
            # Payload lines look like ``data:{...json...}``; strip the
            # SSE field name before decoding.
            if data.startswith('data'):
                data = data.split('data:')[-1]
            try:
                d_data = json.loads(data)
                if 'full_text' in d_data and d_data['full_text']:
                    self.logger.debug(f'client, request response={data}')
                    return True, d_data
            except Exception as e:
                self.logger.error(
                    f'client, request response={data}, error={e}')
            return False, {}

        try:
            if resp.encoding is None:
                resp.encoding = 'utf-8'
            for chunk in resp.iter_lines(decode_unicode=True):
                # Skip SSE event-name and keep-alive lines.
                if chunk.startswith(('event', 'ping')):
                    continue
                flag, data = _deal_data(chunk)
                if flag:
                    return data
            return ''
        except Exception as e:
            self.logger.error(f'client, get stream response error={e}')
            return 'get parse_event_data error'