Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

Merge 39a35f10cd into 07930b854a
Commit b1eb11da82
configs/models/jiutian/jiutian_139.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from opencompass.models import JiutianApi

models = [
    dict(
        abbr='JIUTIAN-13.9B',
        type=JiutianApi,
        path='jiutian-cm',
        appcode='',
        url='https://jiutian.10086.cn/kunlun/ingress/api/h3t-f9c8f9/fae3164b494b4d97b7011c839013c912/ai-7f03963dae10471bb42b6a763a875a68/service-d4cc837d3fe34656a7c0eebd6cec8311/v1/chat/completions',
        max_seq_len=8192,
        max_out_len=4096,
        batch_size=1,
        max_tokens=512,
        model_id='jiutian-cm'
    )
]
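The config above maps one-to-one onto the constructor of `JiutianApi` (added below). As a minimal sketch of driving the model directly, assuming a valid appcode; the prompt string and appcode value are placeholders, not part of this change:

from opencompass.models import JiutianApi

model = JiutianApi(
    path='jiutian-cm',
    appcode='<your-appcode>',  # placeholder, not a real credential
    url='...',                 # the full endpoint URL from the config above
    max_tokens=512,
    model_id='jiutian-cm',
)
print(model.generate(['Introduce yourself.'], max_out_len=512))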
opencompass/models/__init__.py
@@ -49,3 +49,5 @@ from .yayi_api import Yayi  # noqa: F401
 from .yi_api import YiAPI  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
 from .zhipuai_v2_api import ZhiPuV2AI  # noqa: F401
+
+from .jiutian_api import JiutianApi  # noqa: F401
opencompass/models/jiutian_api.py (new file, 203 lines)
@@ -0,0 +1,203 @@
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]

@MODELS.register_module()
class JiutianApi(BaseAPIModel):
    """Model wrapper around Jiutian API's models.

    Args:
        path (str): The name of the model.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not
            exceed this value. Defaults to 4096.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        url (str): The base url of the chat completion endpoint.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
        model_id (str): The id of the model to request.
        appcode (str): The auth token, sent as a Bearer credential.
    """
    is_api: bool = True

    def __init__(self,
                 path: str = 'cmri_base',
                 max_seq_len: int = 4096,
                 query_per_second: int = 1,
                 retry: int = 2,
                 appcode: str = '',
                 url: Optional[str] = None,
                 stream: bool = True,
                 max_tokens: int = 1024,
                 model_id: str = '',
                 temperature: Optional[float] = None):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry)
        import tiktoken
        self.tiktoken = tiktoken
        self.temperature = temperature
        self.url = url
        self.path = path
        self.stream = stream
        self.max_tokens = max_tokens
        self.model_id = model_id
        self.appcode = appcode

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def _generate(self, input: PromptType, max_out_len: int) -> str:
        """Generate results given an input.

        Args:
            input (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {
                'Content-Type': 'application/json',
                'Authorization': 'Bearer %s' % self.appcode
            }
            # Note: the request caps generation with self.max_tokens;
            # max_out_len is kept for interface parity. Streaming is always
            # on because parse_event_data expects an event stream.
            data = {
                'model': self.model_id,
                'messages': messages,
                'max_tokens': self.max_tokens,
                'stream': True
            }

            try:
                raw_response = requests.request('POST',
                                                url=self.url,
                                                headers=header,
                                                json=data,
                                                stream=True)
            except Exception as err:
                self.logger.error('Request Error:{}'.format(err))
                time.sleep(2)
                # Count the failed attempt so the retry loop stays bounded.
                max_num_retries += 1
                continue

            try:
                response = self.parse_event_data(raw_response)
            except Exception as err:
                self.logger.error('Response Error:{}'.format(err))
                response = None
            self.release()

            if response is None:
                self.logger.error('Connection error, reconnect.')
                self.wait()
                max_num_retries += 1
                continue
            if isinstance(response, str):
                self.logger.error('Get stream result error, retry.')
                self.wait()
                max_num_retries += 1
                continue
            try:
                msg = response['full_text']
                self.logger.debug(f'Generated: {msg}')
                return msg
            except KeyError:
                return ''

        raise RuntimeError('max error in max_num_retries')

    def parse_event_data(self, resp) -> Dict:
        """Parse the server-sent event stream and return the final payload."""

        def _deal_data(data: str):
            # Strip the 'data:' SSE prefix before decoding the JSON payload.
            if data.startswith('data'):
                data = data.split('data:')[-1]
            try:
                d_data = json.loads(data)
                if 'full_text' in d_data and d_data['full_text']:
                    self.logger.debug(f'client, request response={data}')
                    return True, d_data
            except Exception as e:
                self.logger.error(
                    f'client, request response={data}, error={e}')

            return False, {}

        try:
            if resp.encoding is None:
                resp.encoding = 'utf-8'
            for chunk in resp.iter_lines(decode_unicode=True):
                # Skip SSE control lines; only 'data:' lines carry payloads.
                if chunk.startswith(('event', 'ping')):
                    continue
                flag, data = _deal_data(chunk)
                if flag:
                    return data
            return ''
        except Exception as e:
            self.logger.error(f'client, get stream response error={e}')
            return 'get parse_event_data error'
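Finally, a minimal sketch of the event stream `parse_event_data` handles, inferred from the parsing logic above rather than from Jiutian's API documentation; only the `full_text` field is actually consumed, and the fake response class is invented for the example:

# Stand-in for requests' streaming Response object.
class FakeStreamResponse:
    encoding = 'utf-8'

    def iter_lines(self, decode_unicode=True):
        yield 'event: message'                       # control line, skipped
        yield 'data: {"full_text": ""}'              # empty full_text, skipped
        yield 'data: {"full_text": "Hello there."}'  # parsed and returned

# JiutianApi.parse_event_data(model, FakeStreamResponse())
# would return {'full_text': 'Hello there.'}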