OpenCompass/opencompass/models/yayi_api.py
2025-05-28 11:28:52 +02:00

250 lines
8.0 KiB
Python

import base64
import hashlib
import hmac
import random
import string
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, List, Optional, Union
import requests
from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
def generate_random_string(length=16):
    """Return a random alphanumeric string.

    Each character is drawn independently (with replacement) from ASCII
    letters and digits via ``random.choice``. The ``random`` module is not
    a cryptographically secure source.

    :param length: number of characters to produce, default 16
    :return: the random string
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def get_current_time(format='%Y-%m-%d %H:%M:%S'):
    """Return the current local time formatted as a string.

    :param format: ``strftime`` format string, default
        ``'%Y-%m-%d %H:%M:%S'`` (the name shadows the ``format`` builtin
        but is kept for keyword-argument compatibility)
    :return: the current naive local time rendered with ``format``
    """
    return datetime.now().strftime(format)
def get_current_timestamp():
    """Return the current Unix time in milliseconds as a decimal string."""
    millis = round(time.time() * 1000)
    return str(int(millis))
def encode_base64_string(s):
    """Base64-encode ``s`` and return the result as an ASCII ``str``.

    :param s: data to encode. Bytes-like objects (e.g. an HMAC digest) are
        encoded as-is; a ``str`` is first encoded as UTF-8.
    :return: the Base64 text
    """
    if isinstance(s, str):
        # base64.b64encode only accepts bytes-like input.
        s = s.encode('utf-8')
    return base64.b64encode(s).decode()
def get_current_time_gmt_format():
    """Return the current UTC time formatted for the request ``Date`` header.

    The output looks like ``'Wed, 28 May 2025 09:28:52GMT+00:00'``. The time
    must be taken in UTC: the string claims ``GMT+00:00``, and a naive
    ``datetime.now()`` would be off by the host's UTC offset.

    NOTE(review): ``%a``/``%b`` follow the process locale, and there is no
    space before ``GMT`` (unlike RFC 7231 dates). The exact shape is kept
    as-is because it is what the server currently receives.
    """
    from datetime import timezone

    GMT_FORMAT = '%a, %d %b %Y %H:%M:%SGMT+00:00'
    return datetime.now(timezone.utc).strftime(GMT_FORMAT)
class Yayi(BaseAPIModel):
    """Model wrapper around Yayi.

    Each request is signed with HMAC-SHA256 over the method, Accept,
    Content-Type, Date and URL path (see ``generate_signature``).

    Args:
        path (str): The name of Yayi model.
        url (str): The base URL for the API.
        url_path (str): The specific path for the API endpoint; it is part
            of the string that gets signed.
        x_tilake_app_key (str): The application key for authentication.
        x_tilake_app_secret (str): The application secret, used as the HMAC
            key for request signatures.
        x_tilake_ca_sginature_method (str): The value sent in the
            ``x-tilake-ca-signature-method`` header. (The misspelled
            parameter name is kept for backward compatibility.)
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 2.
        max_seq_len (int): The maximum sequence length. Defaults to 8192.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of attempts made for each request before
            giving up. Defaults to 2.
        temperature (float): The sampling temperature sent with every
            request. Defaults to 0.0.
    """

    def __init__(
        self,
        path: str,
        url: str,
        url_path: str,
        x_tilake_app_key: str,
        x_tilake_app_secret: str,
        x_tilake_ca_sginature_method: str,
        query_per_second: int = 2,
        max_seq_len: int = 8192,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        temperature: float = 0.0,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
        )
        self.url = url
        self.url_path = url_path
        self.X_TILAKE_APP_KEY = x_tilake_app_key
        self.X_TILAKE_APP_SECRET = x_tilake_app_secret
        self.X_TILAKE_CA_SGINATURE_METHOD = x_tilake_ca_sginature_method
        self.temperature = temperature
        self.model = path

    def generate_signature(self, method, accept, content_type, date, url_path):
        """Compute the Base64-encoded request signature.

        The string to sign is ``method``, ``accept``, ``content_type``,
        ``date`` and ``url_path`` joined by ``'\\n'``. It is HMAC'd with the
        app secret.

        Note: the digest is always SHA-256, whatever value was configured
        for ``x_tilake_ca_sginature_method``.
        """
        string_to_sign = '\n'.join(
            [method, accept, content_type, date, url_path]).encode('utf-8')
        secret_key = self.X_TILAKE_APP_SECRET.encode('utf-8')
        signature = hmac.new(secret_key, string_to_sign,
                             hashlib.sha256).digest()
        return encode_base64_string(signature)

    def generate_header(self, content_type, accept, date, signature):
        """Build the HTTP headers for a signed request.

        A fresh millisecond timestamp and a random nonce are generated on
        every call.
        """
        headers = {
            'x-tilake-app-key': self.X_TILAKE_APP_KEY,
            'x-tilake-ca-signature-method': self.X_TILAKE_CA_SGINATURE_METHOD,
            'x-tilake-ca-timestamp': get_current_timestamp(),
            'x-tilake-ca-nonce': generate_random_string(),
            'x-tilake-ca-signature': signature,
            'Date': date,
            'Content-Type': content_type,
            'Accept': accept,
        }
        return headers

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 8192,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Inputs are processed concurrently in a thread pool; the output order
        matches the input order.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _build_messages(self, prompt: PromptType) -> List[Dict[str, str]]:
        """Convert a prompt into the chat ``messages`` list the API expects.

        A plain string becomes a single user message. For a PromptList,
        ``'BOT'`` items map to ``'assistant'`` and every other role maps to
        ``'user'``. Consecutive items with the same mapped role are merged
        with newlines. The caller's PromptList is not modified.
        """
        if isinstance(prompt, str):
            return [{'role': 'user', 'content': prompt}]
        messages = []
        msg_buffer, last_role = [], None
        for item in prompt:
            # Use a local instead of writing back into item, so reusing
            # the same PromptList does not re-map 'assistant' to 'user'.
            role = 'assistant' if item['role'] == 'BOT' else 'user'
            if role != last_role and last_role is not None:
                messages.append({
                    'content': '\n'.join(msg_buffer),
                    'role': last_role
                })
                msg_buffer = []
            msg_buffer.append(item['prompt'])
            last_role = role
        if msg_buffer:
            messages.append({
                'content': '\n'.join(msg_buffer),
                'role': last_role
            })
        return messages

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 8192,
    ) -> str:
        """Generate a response for a single input.

        Args:
            input (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output; sent to
                the API as ``max_new_tokens``.

        Returns:
            str: The generated string.

        Raises:
            RuntimeError: If none of the ``self.retry`` attempts yields a
                usable response.
        """
        assert isinstance(input, (str, PromptList))
        messages = self._build_messages(input)

        content_type = 'application/json'
        accept = '*/*'
        method = 'POST'
        data = {
            'id': '001',
            'messages': messages,
            'max_new_tokens': max_out_len,
            'temperature': self.temperature,
            'presence_penalty': 0.0,
            'frequency_penalty': 0.0,
            'top_p': 1.0,
            'top_k': -1,
        }

        for _ in range(self.retry):
            # Sign every attempt separately so the Date header (and the
            # timestamp/nonce added by generate_header) are fresh on retries.
            date = get_current_time_gmt_format()
            signature_str = self.generate_signature(method=method,
                                                    accept=accept,
                                                    content_type=content_type,
                                                    date=date,
                                                    url_path=self.url_path)
            headers = self.generate_header(content_type=content_type,
                                           accept=accept,
                                           date=date,
                                           signature=signature_str)
            # Rate-limit through the base class so query_per_second applies.
            self.acquire()
            try:
                raw_response = requests.post(self.url,
                                             json=data,
                                             headers=headers)
            except requests.RequestException as e:
                print(e)
                continue
            finally:
                self.release()

            try:
                response = raw_response.json()
            except ValueError as e:
                # The body was not JSON (e.g. a gateway error page).
                print(e)
                continue

            try:
                return response['choices'][0]['message']['content']
            except (KeyError, IndexError, TypeError) as e:
                # Valid JSON that is not a completion (e.g. an error
                # payload): show it to aid debugging, then retry.
                print(e)
                print(response)
                continue
        raise RuntimeError(f'Failed to respond in {self.retry} retrys')