mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
262 lines
8.0 KiB
Python
262 lines
8.0 KiB
Python
import base64
|
||
import hashlib
|
||
import hmac
|
||
import random
|
||
import string
|
||
import time
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional, Union
|
||
|
||
import requests
|
||
|
||
from opencompass.utils.prompt import PromptList
|
||
|
||
from .base_api import BaseAPIModel
|
||
|
||
# A single model input: either raw text or an OpenCompass PromptList.
PromptType = Union[PromptList, str]
|
||
|
||
|
||
def generate_random_string(length=16):
    """Generate a random alphanumeric string, e.g. for a request nonce.

    Characters are drawn with ``random.SystemRandom`` (backed by the OS
    entropy source) instead of the module-level Mersenne Twister generator,
    whose output is predictable and therefore unsuitable for values sent in
    signed authentication headers.

    :param length: Length of the string; defaults to 16. A non-positive
        length yields an empty string.
    :return: A string of ``length`` characters from ``[A-Za-z0-9]``.
    """
    letters = string.ascii_letters + string.digits
    rng = random.SystemRandom()
    return ''.join(rng.choice(letters) for _ in range(length))
|
||
|
||
|
||
def get_current_time(format='%Y-%m-%d %H:%M:%S'):
    """Return the current local time as a formatted string.

    :param format: ``strftime`` format string; defaults to
        ``'%Y-%m-%d %H:%M:%S'``.
    :return: The current local (naive) time rendered with ``format``.
    """
    return datetime.now().strftime(format)
|
||
|
||
|
||
def get_current_timestamp():
    """Return the current Unix time in milliseconds as a decimal string.

    :return: Milliseconds since the epoch, rounded to the nearest integer,
        rendered as a string of digits.
    """
    now_ms = time.time() * 1000
    return '%d' % round(now_ms)
|
||
|
||
|
||
def encode_base64_string(s):
    """Base64-encode ``s`` and return the result as text.

    :param s: Data to encode. Bytes-like objects are encoded as-is; a
        ``str`` is first encoded to UTF-8.
    :return: The Base64 representation as an ASCII ``str``.
    """
    if isinstance(s, str):
        s = s.encode('utf-8')
    # Base64 output is pure ASCII, so decoding with 'ascii' is lossless.
    return base64.b64encode(s).decode('ascii')
|
||
|
||
|
||
def get_current_time_gmt_format():
    """Return the current UTC time formatted for the ``Date`` request header.

    The output looks like ``'Mon, 02 Jan 2006 15:04:05GMT+00:00'``. The
    layout, including the missing space before ``GMT``, is kept exactly as
    before because the server may depend on it. ``%a`` and ``%b`` follow the
    process locale.

    :return: The current time in UTC, formatted as above.
    """
    from datetime import timezone

    GMT_FORMAT = '%a, %d %b %Y %H:%M:%SGMT+00:00'
    # The string is labelled GMT+00:00, so it must be rendered from UTC,
    # not from the machine's local time.
    return datetime.now(timezone.utc).strftime(GMT_FORMAT)
|
||
|
||
|
||
class Yayi(BaseAPIModel):
    """Model wrapper around the Yayi chat API, called through an
    HMAC-signed HTTP endpoint.

    Args:
        path (str): The name of the Yayi model; sent as the ``model`` field
            of every request.
        url (str): Full URL that requests are POSTed to.
        url_path (str): Request path included in the string to sign.
        x_tilake_app_key (str): Application key, sent in the
            ``x-tilake-app-key`` header.
        x_tilake_app_secret (str): Application secret, used as the
            HMAC-SHA256 key when signing requests.
        x_tilake_ca_sginature_method (str): Value sent verbatim in the
            ``x-tilake-ca-signature-method`` header. The signature itself is
            always computed with HMAC-SHA256 regardless of this value. (The
            misspelled parameter name is kept for compatibility.)
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 2.
        max_seq_len (int): Unused here. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of request attempts per input before giving up.
            Defaults to 2.
        temperature (float): Sampling temperature sent with every request.
            Defaults to 0.4.
    """

    def __init__(
        self,
        path: str,
        url: str,
        url_path: str,
        x_tilake_app_key: str,
        x_tilake_app_secret: str,
        x_tilake_ca_sginature_method: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        temperature: float = 0.4,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
        )

        self.url = url
        self.url_path = url_path
        self.X_TILAKE_APP_KEY = x_tilake_app_key
        self.X_TILAKE_APP_SECRET = x_tilake_app_secret
        self.X_TILAKE_CA_SGINATURE_METHOD = x_tilake_ca_sginature_method
        self.temperature = temperature
        self.model = path

    def generate_signature(self, method, accept, content_type, date, url_path):
        """Compute the request signature.

        The string to sign is ``method``, ``accept``, ``content_type``,
        ``date`` and ``url_path`` joined by newlines. It is signed with
        HMAC-SHA256 keyed by the app secret.

        :param method: HTTP method, e.g. ``'POST'``.
        :param accept: Value of the ``Accept`` header.
        :param content_type: Value of the ``Content-Type`` header.
        :param date: Value of the ``Date`` header.
        :param url_path: Request path.
        :return: The Base64-encoded HMAC-SHA256 digest.
        """
        string_to_sign = (method + '\n' + accept + '\n' + content_type + '\n' +
                          date + '\n' + url_path)
        string_to_sign = string_to_sign.encode('utf-8')
        secret_key = self.X_TILAKE_APP_SECRET.encode('utf-8')
        signature = hmac.new(secret_key, string_to_sign,
                             hashlib.sha256).digest()
        return encode_base64_string(signature)

    def generate_header(self, content_type, accept, date, signature):
        """Build the HTTP headers for one request.

        A fresh millisecond timestamp and a fresh random nonce are generated
        on every call.

        :param content_type: Value of the ``Content-Type`` header.
        :param accept: Value of the ``Accept`` header.
        :param date: Value of the ``Date`` header; it must be the same value
            that was signed.
        :param signature: Signature returned by ``generate_signature``.
        :return: Dict of header names to values.
        """
        headers = {
            'x-tilake-app-key': self.X_TILAKE_APP_KEY,
            'x-tilake-ca-signature-method': self.X_TILAKE_CA_SGINATURE_METHOD,
            'x-tilake-ca-timestamp': get_current_timestamp(),
            'x-tilake-ca-nonce': generate_random_string(),
            'x-tilake-ca-signature': signature,
            'Date': date,
            'Content-Type': content_type,
            'Accept': accept,
        }
        return headers

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings, in input order.
        """
        # NOTE(review): inputs are sent concurrently by a default-sized
        # thread pool and this class applies no throttling of its own, so
        # ``query_per_second`` may not be honoured. Confirm whether a
        # BaseAPIModel rate-limit hook should be called per request.
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    @staticmethod
    def _build_messages(prompt: PromptType) -> List[Dict[str, str]]:
        """Convert a prompt into the ``messages`` list sent to the API.

        A plain string becomes a single ``user`` message. For a PromptList,
        items with role ``'BOT'`` become ``'yayi'`` and every other role
        becomes ``'user'``. Consecutive items with the same resulting role
        are merged into one message, joined by newlines. The items of
        ``prompt`` are not modified.
        """
        if isinstance(prompt, str):
            return [{'role': 'user', 'content': prompt}]

        messages = []
        msg_buffer, last_role = [], None
        for item in prompt:
            # Compute the role locally instead of writing it back into
            # ``item``. Mutating the caller's prompt would make a second pass
            # over the same PromptList see 'yayi' instead of 'BOT' and
            # mislabel bot turns as 'user'.
            role = 'yayi' if item['role'] == 'BOT' else 'user'
            if role != last_role and last_role is not None:
                messages.append({
                    'content': '\n'.join(msg_buffer),
                    'role': last_role
                })
                msg_buffer = []
            msg_buffer.append(item['prompt'])
            last_role = role
        # NOTE(review): an empty PromptList yields a single message with
        # empty content and role None, as before; confirm whether that
        # should be rejected instead.
        messages.append({'content': '\n'.join(msg_buffer), 'role': last_role})
        return messages

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.

        Raises:
            RuntimeError: If no usable response is obtained within
                ``self.retry`` attempts.
        """
        assert isinstance(input, (str, PromptList))

        messages = self._build_messages(input)

        content_type = 'application/json'
        accept = '*/*'
        method = 'POST'
        data = {
            'id': '001',  # Request id; does not need to change.
            'model': self.model,
            'messages': messages,
            # max_new_tokens and the sampling parameters below may be tuned
            # for the task at hand.
            'max_new_tokens': max_out_len,
            'temperature': self.temperature,
            'presence_penalty': 0.85,
            'frequency_penalty': 0.16,
            'do_sample': True,
            'top_p': 1.0,
            'top_k': -1,
        }

        for _ in range(self.retry):
            # Take a fresh Date for every attempt, matching the fresh
            # timestamp and nonce that generate_header produces per attempt.
            date = get_current_time_gmt_format()
            signature_str = self.generate_signature(method=method,
                                                    accept=accept,
                                                    content_type=content_type,
                                                    date=date,
                                                    url_path=self.url_path)
            headers = self.generate_header(content_type=content_type,
                                           accept=accept,
                                           date=date,
                                           signature=signature_str)

            try:
                response = requests.post(self.url, json=data, headers=headers)
            except Exception as e:  # deliberate best-effort: retry
                print(e)
                continue
            try:
                response = response.json()
            except ValueError as e:  # body is not valid JSON
                print(e)
                continue
            print(response)
            try:
                return response['data']['choices'][0]['message']['content']
            except (KeyError, IndexError, TypeError) as e:
                # Unexpected payload shape, e.g. an error body.
                print(e)
                continue

        raise RuntimeError(f'Failed to respond in {self.retry} retrys')
|