import base64
import hashlib
import hmac
import logging
import random
import string
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]

logger = logging.getLogger(__name__)

# OS-backed RNG for request nonces: nonces should not be reproducible from
# the seedable module-level Mersenne Twister behind `random.choice`.
_SYSTEM_RANDOM = random.SystemRandom()


def generate_random_string(length=16):
    """Generate a random alphanumeric string (used as the request nonce).

    :param length: number of characters, defaults to 16
    :return: string of ASCII letters and digits
    """
    letters = string.ascii_letters + string.digits
    return ''.join(_SYSTEM_RANDOM.choice(letters) for _ in range(length))


def get_current_time(format='%Y-%m-%d %H:%M:%S'):
    """Return the current local time as a formatted string.

    :param format: strftime format, defaults to '%Y-%m-%d %H:%M:%S'
    :return: formatted current time
    """
    return datetime.now().strftime(format)


def get_current_timestamp():
    """Return the current Unix time in milliseconds, as a decimal string."""
    return str(int(round(time.time() * 1000)))


def encode_base64_string(s):
    """Base64-encode raw bytes.

    :param s: bytes to encode (e.g. an HMAC digest)
    :return: the Base64 text as ``str``
    """
    return base64.b64encode(s).decode()


def get_current_time_gmt_format():
    """Return the current UTC time formatted for the ``Date`` header.

    Produces e.g. ``'Mon, 01 Jan 2024 12:00:00GMT+00:00'``.
    """
    GMT_FORMAT = '%a, %d %b %Y %H:%M:%SGMT+00:00'
    # The format hard-codes a GMT+00:00 suffix, so read an aware UTC clock
    # rather than naive local time.
    return datetime.now(timezone.utc).strftime(GMT_FORMAT)


class Yayi(BaseAPIModel):
    """Model wrapper around Yayi.

    Args:
        path (str): The name of Yayi model.
        url (str): The base URL for the API.
        url_path (str): The specific path for the API endpoint; it is part
            of the signed string.
        x_tilake_app_key (str): The application key for authentication.
        x_tilake_app_secret (str): The application secret used as the HMAC
            key for request signatures.
        x_tilake_ca_sginature_method (str): The signature method sent in the
            ``x-tilake-ca-signature-method`` header.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 2.
        max_seq_len (int): The maximum sequence length. Defaults to 8192.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of attempts made before giving up.
            Defaults to 2.
        temperature (float): Sampling temperature sent with every request.
            Defaults to 0.0.
    """

    def __init__(
        self,
        path: str,
        url: str,
        url_path: str,
        x_tilake_app_key: str,
        x_tilake_app_secret: str,
        x_tilake_ca_sginature_method: str,
        query_per_second: int = 2,
        max_seq_len: int = 8192,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        temperature: float = 0.0,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
        )

        self.url = url
        self.url_path = url_path
        self.X_TILAKE_APP_KEY = x_tilake_app_key
        self.X_TILAKE_APP_SECRET = x_tilake_app_secret
        self.X_TILAKE_CA_SGINATURE_METHOD = x_tilake_ca_sginature_method
        self.temperature = temperature
        self.model = path

    def generate_signature(self, method, accept, content_type, date,
                           url_path):
        """Compute the request signature.

        The method, Accept, Content-Type, Date and URL path are joined with
        newlines, HMAC-SHA256-signed with the app secret, and the digest is
        returned Base64-encoded.
        """
        string_to_sign = '\n'.join(
            [method, accept, content_type, date, url_path]).encode('utf-8')
        secret_key = self.X_TILAKE_APP_SECRET.encode('utf-8')
        digest = hmac.new(secret_key, string_to_sign, hashlib.sha256).digest()
        return encode_base64_string(digest)

    def generate_header(self, content_type, accept, date, signature):
        """Build the authenticated request headers.

        A fresh millisecond timestamp and random nonce are generated on every
        call.
        """
        headers = {
            'x-tilake-app-key': self.X_TILAKE_APP_KEY,
            'x-tilake-ca-signature-method': self.X_TILAKE_CA_SGINATURE_METHOD,
            'x-tilake-ca-timestamp': get_current_timestamp(),
            'x-tilake-ca-nonce': generate_random_string(),
            'x-tilake-ca-signature': signature,
            'Date': date,
            'Content-Type': content_type,
            'Accept': accept,
        }
        return headers

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 8192,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings, in input order.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    @staticmethod
    def _build_messages(prompt: PromptType) -> List[Dict[str, str]]:
        """Convert a prompt into chat messages.

        A plain string becomes one user message. For a PromptList, 'BOT'
        items map to 'assistant' and every other role to 'user', and
        consecutive items with the same mapped role are joined with newlines.
        """
        if isinstance(prompt, str):
            return [{'role': 'user', 'content': prompt}]

        messages = []
        msg_buffer, last_role = [], None
        for item in prompt:
            # Keep the mapped role local: writing it back into `item` would
            # mutate the caller's prompt.
            role = 'assistant' if item['role'] == 'BOT' else 'user'
            if role != last_role and last_role is not None:
                messages.append({
                    'content': '\n'.join(msg_buffer),
                    'role': last_role
                })
                msg_buffer = []
            msg_buffer.append(item['prompt'])
            last_role = role
        if last_role is not None:
            messages.append({
                'content': '\n'.join(msg_buffer),
                'role': last_role
            })
        return messages

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 8192,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output; sent as
                ``max_new_tokens``.

        Returns:
            str: The generated string.

        Raises:
            RuntimeError: If no attempt out of ``self.retry`` produced a
                usable response.
        """
        assert isinstance(input, (str, PromptList))
        messages = self._build_messages(input)

        content_type = 'application/json'
        accept = '*/*'
        method = 'POST'

        data = {
            'id': "001",
            'messages': messages,
            'max_new_tokens': max_out_len,
            'temperature': self.temperature,
            'presence_penalty': 0.0,
            'frequency_penalty': 0.0,
            'top_p': 1.0,
            'top_k': -1,
        }

        # NOTE(review): query_per_second is handed to the base class but no
        # throttling call is made here; confirm whether BaseAPIModel offers a
        # rate limiter that should wrap this request. requests.post also has
        # no timeout, so a stalled connection blocks this worker indefinitely.
        for _ in range(self.retry):
            # Re-stamp and re-sign on each attempt so retries never carry a
            # stale Date header.
            date = get_current_time_gmt_format()
            signature_str = self.generate_signature(method=method,
                                                    accept=accept,
                                                    content_type=content_type,
                                                    date=date,
                                                    url_path=self.url_path)
            headers = self.generate_header(content_type=content_type,
                                           accept=accept,
                                           date=date,
                                           signature=signature_str)
            logger.debug('Yayi request payload: %s', data)
            try:
                response = requests.post(self.url, json=data, headers=headers)
            except requests.RequestException as e:
                logger.warning('Yayi request failed: %s', e)
                continue
            try:
                payload = response.json()
            except ValueError as e:
                logger.warning('Yayi returned a non-JSON body: %s', e)
                continue
            logger.debug('Yayi response: %s', payload)
            try:
                return payload['choices'][0]['message']['content']
            except (KeyError, IndexError, TypeError) as e:
                logger.warning('Unexpected Yayi response %s (%r)', payload, e)
                continue

        raise RuntimeError(f'Failed to respond in {self.retry} retrys')