open-compass/opencompass (https://github.com/open-compass/opencompass.git)
Merge 39a35f10cd into 07930b854a
This commit is contained in: commit b1eb11da82
configs/models/jiutian/jiutian_139.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from opencompass.models import JiutianApi

models = [
    dict(
        abbr='JIUTIAN-13.9B',
        type=JiutianApi,
        path='jiutian-cm',
        appcode='',  # auth token, left empty in the committed config
        url='https://jiutian.10086.cn/kunlun/ingress/api/h3t-f9c8f9/fae3164b494b4d97b7011c839013c912/ai-7f03963dae10471bb42b6a763a875a68/service-d4cc837d3fe34656a7c0eebd6cec8311/v1/chat/completions',
        max_seq_len=8192,
        max_out_len=4096,
        batch_size=1,
        max_tokens=512,
        model_id='jiutian-cm',
    )
]
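For context, OpenCompass configs like the one above are plain Python: `models` is a list of dicts, and the framework later pops the `type` key and instantiates that class with most of the remaining keys as constructor arguments (keys such as `abbr`, `max_out_len` and `batch_size` are consumed by the runner rather than the model class). A minimal sketch of that pattern, with a placeholder appcode and a shortened url standing in for the real values:

# Sketch of how a `type`-keyed config dict becomes a model instance.
# This mirrors the registry-based build OpenCompass performs internally;
# the manual pop() here is illustrative, not the real build API.
from opencompass.models import JiutianApi

cfg = dict(
    type=JiutianApi,
    path='jiutian-cm',
    appcode='<your appcode>',  # placeholder auth token
    url='https://jiutian.10086.cn/.../v1/chat/completions',  # shortened
    max_seq_len=8192,
    max_tokens=512,
    model_id='jiutian-cm',
)

model_cls = cfg.pop('type')
model = model_cls(**cfg)  # JiutianApi(path=..., appcode=..., url=..., ...)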
opencompass/models/__init__.py
@@ -49,3 +49,5 @@ from .yayi_api import Yayi  # noqa: F401
from .yi_api import YiAPI  # noqa: F401
from .zhipuai_api import ZhiPuAI  # noqa: F401
from .zhipuai_v2_api import ZhiPuV2AI  # noqa: F401
from .jiutian_api import JiutianApi  # noqa: F401
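The import added above, together with the `@MODELS.register_module()` decorator in the new module below, is what makes the class discoverable by name. Assuming `opencompass.registry.MODELS` behaves like the mmengine-style registry it wraps, the lookup this enables is roughly:

# Sketch of the registry lookup enabled by the import above; assumes
# an mmengine-style Registry, where register_module() runs at import time.
from opencompass.models import JiutianApi  # triggers register_module()
from opencompass.registry import MODELS

assert MODELS.get('JiutianApi') is JiutianApi  # lookup by class name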
opencompass/models/jiutian_api.py (new file, 203 lines)
@@ -0,0 +1,203 @@
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


@MODELS.register_module()
class JiutianApi(BaseAPIModel):
    """Model wrapper around Jiutian API's models.

    Args:
        path (str): The name of the model.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not
            exceed this value. Defaults to 4096.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        url (str): The base url of the API endpoint.
        stream (bool): Whether to request a streamed response.
            Defaults to True.
        max_tokens (int): Maximum number of tokens the API may generate.
            Defaults to 1024.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
        model_id (str): The id of the model.
        appcode (str): The auth token, sent as a Bearer credential.
    """

    is_api: bool = True

    def __init__(self,
                 path: str = 'cmri_base',
                 max_seq_len: int = 4096,
                 query_per_second: int = 1,
                 retry: int = 2,
                 appcode: str = '',
                 url: Optional[str] = None,
                 stream: bool = True,
                 max_tokens: int = 1024,
                 model_id: str = '',
                 temperature: Optional[float] = None):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry)
        import tiktoken
        self.tiktoken = tiktoken
        self.temperature = temperature
        self.url = url
        self.path = path
        self.stream = stream
        self.max_tokens = max_tokens
        self.model_id = model_id
        self.appcode = appcode

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptLists.
                The PromptList should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def _generate(self, input: PromptType, max_out_len: int) -> str:
        """Generate results given an input.

        Args:
            input (PromptType): A string or PromptList. The PromptList
                should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            # Map OpenCompass roles to the OpenAI-style roles the API expects.
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {
                'Content-Type': 'application/json',
                'Authorization': 'Bearer %s' % self.appcode
            }
            data = {
                'model': self.model_id,
                'messages': messages,
                'max_tokens': self.max_tokens,
                # parse_event_data expects an event stream, so streaming is
                # always requested here.
                'stream': True
            }

            try:
                raw_response = requests.request('POST',
                                                url=self.url,
                                                headers=header,
                                                json=data,
                                                stream=True)
            except Exception as err:
                self.logger.error('Request Error: {}'.format(err))
                time.sleep(2)
                max_num_retries += 1
                continue

            try:
                response = self.parse_event_data(raw_response)
            except Exception as err:
                self.logger.error('Response Error: {}'.format(err))
                response = None
            self.release()

            if response is None:
                self.logger.error('Connection error, reconnect.')
                self.wait()
                max_num_retries += 1
                continue
            if isinstance(response, str):
                self.logger.error('Get stream result error, retry.')
                self.wait()
                max_num_retries += 1
                continue
            try:
                msg = response['full_text']
                self.logger.debug(f'Generated: {msg}')
                return msg
            except KeyError:
                return ''

        raise RuntimeError('Calling Jiutian API failed after retrying '
                           f'{max_num_retries} times.')

def parse_event_data(self, resp) -> Dict:
|
||||
"""
|
||||
解析事件数据
|
||||
:return:
|
||||
"""
|
||||
|
||||
def _deal_data(data: str):
|
||||
if data.startswith("data"):
|
||||
data = data.split("data:")[-1]
|
||||
try:
|
||||
d_data = json.loads(data)
|
||||
if "full_text" in d_data and d_data["full_text"]:
|
||||
self.logger.debug(f"client, request response={data}")
|
||||
return True, d_data
|
||||
except Exception as e:
|
||||
self.logger.error(f"client, request response={data}, error={e}")
|
||||
|
||||
return False, {}
|
||||
|
||||
try:
|
||||
if resp.encoding is None:
|
||||
resp.encoding = 'utf-8'
|
||||
for chunk in resp.iter_lines(decode_unicode=True):
|
||||
if chunk.startswith(("event", "ping")):
|
||||
continue
|
||||
flag, data = _deal_data(chunk)
|
||||
if flag:
|
||||
return data
|
||||
return ''
|
||||
except Exception as e:
|
||||
self.logger.error(f"client, get stram response error={e}")
|
||||
return "get parse_event_data error"
|
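For reference, `parse_event_data` above assumes server-sent-event lines of the form `data: {...}` and returns at the first payload whose `full_text` field is non-empty. A small self-contained sketch of that contract follows; the payload fields other than `full_text` are assumptions for illustration, not documented Jiutian API schema:

import json

# Assumed wire format, inferred from parse_event_data above; the real
# Jiutian API stream may carry more fields than full_text.
sample_stream = [
    'event: message',                               # skipped ("event" prefix)
    'data: {"delta": "Hel"}',                       # no full_text yet
    'data: {"delta": "lo", "full_text": "Hello"}',  # final payload
    'ping',                                         # skipped ("ping" prefix)
]

def deal_data(line: str):
    # Mirrors the nested _deal_data helper: strip the "data:" prefix,
    # decode JSON, accept only payloads with a non-empty full_text.
    if line.startswith('data'):
        line = line.split('data:')[-1]
    try:
        payload = json.loads(line)
        if payload.get('full_text'):
            return True, payload
    except Exception:
        pass
    return False, {}

for line in sample_stream:
    if line.startswith(('event', 'ping')):
        continue
    ok, payload = deal_data(line)
    if ok:
        print(payload['full_text'])  # -> Hello
        break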