import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


@MODELS.register_module()
class JiutianApi(BaseAPIModel):
    """Model wrapper around Jiutian API's models.

    Args:
        path (str): The name of the model.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 4096.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        appcode (str): The auth token, sent as a Bearer credential.
        url (str): The base URL of the API endpoint.
        stream (bool): Whether to request a streamed response. Defaults to
            True.
        max_tokens (int): The maximum number of tokens the API may generate
            per request. Defaults to 1024.
        model_id (str): The id of the model.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 path: str = 'cmri_base',
                 max_seq_len: int = 4096,
                 query_per_second: int = 1,
                 retry: int = 2,
                 appcode: str = '',
                 url: Optional[str] = None,
                 stream: bool = True,
                 max_tokens: int = 1024,
                 model_id: str = '',
                 temperature: Optional[float] = None):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         retry=retry)
        # Kept for token counting; not used elsewhere in this wrapper as
        # written.
        import tiktoken
        self.tiktoken = tiktoken
        self.temperature = temperature
        self.url = url
        self.path = path
        self.stream = stream
        self.max_tokens = max_tokens
        self.model_id = model_id
        self.appcode = appcode

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def _generate(self, input: PromptType, max_out_len: int) -> str:
        """Generate results given an input.

        Args:
            input (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            # Map OpenCompass roles onto the chat roles the API expects.
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)
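
        # For example, the PromptList entry {'role': 'HUMAN', 'prompt': 'Hi'}
        # becomes {'role': 'user', 'content': 'Hi'} in the request payload.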

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {
                'Content-Type': 'application/json',
                'Authorization': 'Bearer %s' % self.appcode
            }
            data = {
                'model': self.model_id,
                'messages': messages,
                # The constructor-level max_tokens takes precedence over the
                # per-call max_out_len here.
                'max_tokens': self.max_tokens,
                # parse_event_data below only understands streamed responses,
                # so self.stream should normally stay True.
                'stream': self.stream
            }
            if self.temperature is not None:
                # Apply the documented temperature override.
                data['temperature'] = self.temperature

            try:
                raw_response = requests.post(url=self.url,
                                             headers=header,
                                             json=data,
                                             stream=True)
            except Exception as err:
                self.logger.error('Request Error: {}'.format(err))
                time.sleep(2)
                max_num_retries += 1
                continue

            try:
                response = self.parse_event_data(raw_response)
            except Exception as err:
                self.logger.error('Response Error: {}'.format(err))
                response = None
            self.release()

            if response is None:
                self.logger.error('Connection error, reconnect.')
                self.wait()
                max_num_retries += 1
                continue
            if isinstance(response, str):
                self.logger.error('Get stream result error, retry.')
                self.wait()
                max_num_retries += 1
                continue
            try:
                msg = response['full_text']
                self.logger.debug(f'Generated: {msg}')
                return msg
            except KeyError:
                return ''

        raise RuntimeError('Calling Jiutian API failed after '
                           f'{max_num_retries} retries.')

    def parse_event_data(self, resp) -> Union[Dict, str]:
        """Parse event-stream data from a raw streaming response.

        Returns:
            The payload dict carrying a non-empty ``full_text``, or an
            empty/error string when no such payload is found.
        """
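
        # The endpoint is assumed to answer with a server-sent event stream
        # whose payload lines look roughly like the following (field names
        # are inferred from the parsing below, not from official API docs):
        #
        #     event: message
        #     data: {"full_text": "the answer generated so far"}
        #
        # "event" and "ping" lines carry no payload and are skipped; the
        # first chunk with a non-empty "full_text" is returned.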

        def _deal_data(data: str):
            # Strip the SSE "data:" prefix before decoding the JSON payload.
            if data.startswith('data'):
                data = data.split('data:')[-1]
            try:
                d_data = json.loads(data)
                if 'full_text' in d_data and d_data['full_text']:
                    self.logger.debug(f'client, request response={data}')
                    return True, d_data
            except Exception as e:
                self.logger.error(
                    f'client, request response={data}, error={e}')
            return False, {}

        try:
            if resp.encoding is None:
                resp.encoding = 'utf-8'
            for chunk in resp.iter_lines(decode_unicode=True):
                # Skip keep-alive and event-name lines; only "data" lines
                # carry a JSON payload.
                if chunk.startswith(('event', 'ping')):
                    continue
                flag, data = _deal_data(chunk)
                if flag:
                    return data
            return ''
        except Exception as e:
            self.logger.error(f'client, get stream response error={e}')
            return 'get parse_event_data error'
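

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module. The url,
    # appcode and model_id values below are placeholders (assumptions), so
    # this only runs end-to-end against a real Jiutian endpoint. In an
    # OpenCompass config the wrapper would typically appear as a model dict,
    # e.g. dict(type=JiutianApi, path='cmri_base', url=..., appcode=...,
    # model_id=..., batch_size=8).
    model = JiutianApi(path='cmri_base',
                       url='https://jiutian.example.com/v1/chat/completions',
                       appcode='YOUR_APPCODE',
                       model_id='your-model-id',
                       max_seq_len=4096)
    print(model.generate(['Hello! Please introduce yourself.'],
                         max_out_len=64)[0])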