mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

Create jiutian_api.py

add jiutian_api client

This commit is contained in:
parent 45651dfb34
commit eb277591f7

203 lines added: opencompass/models/jiutian_api.py (new file)

@@ -0,0 +1,203 @@
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


@MODELS.register_module()
class JiutianApi(BaseAPIModel):
    """Model wrapper around Jiutian API's models.

    Args:
        path (str): The name of the model.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not
            exceed this value. Defaults to 4096.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        url (str): The base url of the API endpoint.
        stream (bool): Whether to request a streamed response.
            Defaults to True.
        max_tokens (int): Maximum number of tokens to generate per request.
            Defaults to 1024.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
        model_id (str): The id of the model to call.
        appcode (str): Auth token, sent as a bearer token in the
            Authorization header.
    """

    is_api: bool = True

    def __init__(self,
                 path: str = 'cmri_base',
                 max_seq_len: int = 4096,
                 query_per_second: int = 1,
                 retry: int = 2,
                 appcode: str = '',
                 url: Optional[str] = None,
                 stream: bool = True,
                 max_tokens: int = 1024,
                 model_id: str = '',
                 temperature: Optional[float] = None):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry)
        # Imported lazily so the dependency is only required when this
        # wrapper is actually used.
        import tiktoken
        self.tiktoken = tiktoken
        self.temperature = temperature
        self.url = url
        self.path = path
        self.stream = stream
        self.max_tokens = max_tokens
        self.model_id = model_id
        self.appcode = appcode
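
    # `generate` fans the prompts out over a thread pool; each worker runs
    # `_generate`, which throttles itself through the wait()/release()
    # query-per-second gate inherited from BaseAPIModel.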
    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results
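
    # OpenCompass prompt dicts use HUMAN/BOT/SYSTEM roles; `_generate` maps
    # them onto the user/assistant/system roles the endpoint is assumed to
    # expect before building the request payload.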
    def _generate(self, input: PromptType, max_out_len: int) -> str:
        """Generate results given an input.

        Args:
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{
                'role': 'user',
                'content': input
            }]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)
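
        # Each attempt POSTs the chat payload with the appcode as a bearer
        # token and parses the streamed reply; failed attempts back off and
        # retry, up to `self.retry` attempts in total.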
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {
                'Content-Type': 'application/json',
                'Authorization': 'Bearer %s' % self.appcode
            }
            data = {
                'model': self.model_id,
                'messages': messages,
                'max_tokens': self.max_tokens,
                # Streaming is always requested here, since
                # parse_event_data expects a streamed reply.
                'stream': True
            }

            try:
                raw_response = requests.request('POST',
                                                url=self.url,
                                                headers=header,
                                                json=data,
                                                stream=True)
            except Exception as err:
                self.logger.error('Request Error:{}'.format(err))
                time.sleep(2)
                max_num_retries += 1
                continue

            try:
                response = self.parse_event_data(raw_response)
            except Exception as err:
                self.logger.error('Response Error:{}'.format(err))
                response = None
            self.release()

            if response is None:
                self.logger.error('Connection error, reconnect.')
                self.wait()
                max_num_retries += 1
                continue
            if isinstance(response, str):
                self.logger.error('Get stream result error, retry.')
                self.wait()
                max_num_retries += 1
                continue
            try:
                msg = response['full_text']
                self.logger.debug(f'Generated: {msg}')
                return msg
            except KeyError:
                return ''

        raise RuntimeError('Calling Jiutian API failed after '
                           f'{self.retry} retries.')

    def parse_event_data(self, resp) -> Union[Dict, str]:
        """Parse event data from the streamed response.

        Returns:
            Union[Dict, str]: The parsed payload containing ``full_text``,
            or a string when the stream yields no usable payload.
        """
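
        # The endpoint is assumed to answer with a server-sent event stream:
        # "event"/"ping" lines are control frames, while "data: {...}" lines
        # carry JSON payloads, the last of which holds the accumulated
        # completion in its "full_text" field.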
        def _deal_data(data: str):
            if data.startswith('data'):
                data = data.split('data:')[-1]
                try:
                    d_data = json.loads(data)
                    if 'full_text' in d_data and d_data['full_text']:
                        self.logger.debug(f'client, request response={data}')
                        return True, d_data
                except Exception as e:
                    self.logger.error(
                        f'client, request response={data}, error={e}')

            return False, {}

        try:
            if resp.encoding is None:
                resp.encoding = 'utf-8'
            for chunk in resp.iter_lines(decode_unicode=True):
                if chunk.startswith(('event', 'ping')):
                    continue
                flag, data = _deal_data(chunk)
                if flag:
                    return data
            return ''
        except Exception as e:
            self.logger.error(f'client, get stream response error={e}')
            return 'get parse_event_data error'
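
Usage note: below is a minimal config sketch for wiring this wrapper into
an OpenCompass run. The url, model_id and appcode values are placeholders,
and it assumes JiutianApi is exported from opencompass/models/__init__.py
so that `type=JiutianApi` resolves:

    from opencompass.models import JiutianApi

    models = [
        dict(
            abbr='jiutian',
            type=JiutianApi,
            path='cmri_base',
            url='https://example.com/v1/chat/completions',  # placeholder
            model_id='your-model-id',  # placeholder
            appcode='YOUR_APPCODE',  # placeholder
            max_seq_len=4096,
            max_out_len=512,
            batch_size=8,
            retry=2,
        )
    ]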