import concurrent import concurrent.futures import os import socket import traceback from typing import Dict, List, Optional, Union import requests from requests.adapters import HTTPAdapter from retrying import retry from urllib3.connection import HTTPConnection from opencompass.utils.prompt import PromptList from .base_api import BaseAPIModel PromptType = Union[PromptList, str] class HTTPAdapterWithSocketOptions(HTTPAdapter): def __init__(self, *args, **kwargs): self._socket_options = HTTPConnection.default_socket_options + [ (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), (socket.SOL_TCP, socket.TCP_KEEPIDLE, 75), (socket.SOL_TCP, socket.TCP_KEEPINTVL, 30), (socket.SOL_TCP, socket.TCP_KEEPCNT, 120), ] super(HTTPAdapterWithSocketOptions, self).__init__(*args, **kwargs) def init_poolmanager(self, *args, **kwargs): if self._socket_options is not None: kwargs["socket_options"] = self._socket_options super(HTTPAdapterWithSocketOptions, self).init_poolmanager(*args, **kwargs) class BailingAPI(BaseAPIModel): """Model wrapper around Bailing Service. Args: ouput_key (str): key for prediction query_per_second (int): The maximum queries allowed per second between two consecutive calls of the API. Defaults to 1. generation_kwargs: other params retry (int): Number of retires if the API call fails. Defaults to 2. """ def __init__( self, path: str, token: str, url: str, meta_template: Optional[Dict] = None, query_per_second: int = 1, retry: int = 3, generation_kwargs: Dict = {}, max_seq_len=4096, ): super().__init__( path=path, max_seq_len=max_seq_len, query_per_second=query_per_second, meta_template=meta_template, retry=retry, generation_kwargs=generation_kwargs, ) self.logger.info(f"Bailing API Model Init path: {path} url={url}") if not token: token = os.environ.get("BAILING_API_KEY") if token: self._headers = {"Authorization": f"Bearer {token}"} else: raise RuntimeError(f"There is not valid token.") self._headers["Content-Type"] = "application/json" self._url = url if url else "https://bailingchat.alipay.com/chat/completions" self._model = path self._sessions = [] self._num = ( int(os.environ.get("BAILING_API_PARALLEL_NUM")) if os.environ.get("BAILING_API_PARALLEL_NUM") else 1 ) try: for _ in range(self._num): adapter = HTTPAdapterWithSocketOptions() sess = requests.Session() sess.mount("http://", adapter) sess.mount("https://", adapter) self._sessions.append(sess) except Exception as e: self.logger.error(f"Fail to setup the session. {e}") raise e def generate( self, inputs: Union[List[str], PromptList], max_out_len: int = 4096, ) -> List[str]: """Generate results given a list of inputs. Args: inputs (Union[List[str], PromptList]): A list of strings or PromptDicts. The PromptDict should be organized in OpenCompass' API format. max_out_len (int): The maximum length of the output. Returns: List[str]: A list of generated strings. """ with concurrent.futures.ThreadPoolExecutor( max_workers=self._num, ) as executor: future_to_m = { executor.submit( self._generate, self._sessions[i % self._num], input, max_out_len, ): i for i, input in enumerate(inputs) } results = [] for future in concurrent.futures.as_completed(future_to_m): m = future_to_m[future] resp = future.result() if resp and resp.status_code == 200: try: result = resp.json() except: results.append("") else: if ( result.get("choices") and result["choices"][0].get("message") and result["choices"][0]["message"].get("content") ): results.append(result["choices"][0]["message"]["content"]) else: results.append("") self.flush() return results def _generate( self, sess, input: Union[str, PromptList], max_out_len: int, ) -> str: """Generate results given an input. Args: inputs (str or PromptList): A string or PromptDict. The PromptDict should be organized in OpenCompass' API format. max_out_len (int): The maximum length of the output. Returns: str: The generated string. """ if isinstance(input, str): messages = [{"role": "user", "content": input}] else: messages = [] for item in input: content = item["prompt"] if not content: continue message = {"content": content} if item["role"] == "HUMAN": message["role"] = "user" elif item["role"] == "BOT": message["role"] = "assistant" elif item["role"] == "SYSTEM": message["role"] = "system" else: message["role"] = item["role"] messages.append(message) request = { "model": self._model, "messages": messages, "max_seq_len": max( max_out_len if max_out_len else 4096, self.max_seq_len if self.max_seq_len else 4096, ), } request.update(self.generation_kwargs) try: retry_num = 0 while retry_num < self.retry: response = self._infer_result(request, sess) if response.status_code == 200: break # success elif response.status_code == 426: retry_num += 1 # retry else: raise ValueError(f"Status code = {response.status_code}") else: raise ValueError( f"Exceed the maximal retry times. Last status code = {response.status_code}" ) except Exception as e: self.logger.error( f"Fail to inference request={request}; model_name={self.path}; error={e}, stack:{traceback.format_exc()}" ) raise e return response @retry(stop_max_attempt_number=3, wait_fixed=16000) # ms def _infer_result(self, request, sess): response = sess.request( "POST", self._url, json=request, headers=self._headers, timeout=500, ) return response