From 8b39225259efa079dc35c5c105da392e576d42ba Mon Sep 17 00:00:00 2001 From: Alexander Lam Date: Thu, 29 Aug 2024 00:43:43 +0800 Subject: [PATCH] [Feature] Added `extra_body` support for OpenAISDK; Added support for proxy URL when connecting to OpenAI's API. (#1467) * fix lint issues * fix lint issues --- opencompass/models/openai_api.py | 62 +++++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/opencompass/models/openai_api.py b/opencompass/models/openai_api.py index 1ae48006..f572a846 100644 --- a/opencompass/models/openai_api.py +++ b/opencompass/models/openai_api.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor from threading import Lock from typing import Dict, List, Optional, Union +import httpx import jieba import requests @@ -46,6 +47,10 @@ class OpenAI(BaseAPIModel): wrapping of any meta instructions. openai_api_base (str): The base url of OpenAI's API. Defaults to 'https://api.openai.com/v1/chat/completions'. + openai_proxy_url (str, optional): An optional proxy url to use when + connecting to OpenAI's API. When set to 'ENV', the url will be + fetched from the environment variable $OPENAI_PROXY_URL. + Defaults to None. mode (str, optional): The method of input truncation when input length exceeds max_seq_len. 'front','mid' and 'rear' represents the part of input to truncate. Defaults to 'none'. @@ -71,6 +76,7 @@ class OpenAI(BaseAPIModel): org: Optional[Union[str, List[str]]] = None, meta_template: Optional[Dict] = None, openai_api_base: str = OPENAI_API_BASE, + openai_proxy_url: Optional[str] = None, mode: str = 'none', logprobs: Optional[bool] = False, top_logprobs: Optional[int] = None, @@ -116,6 +122,14 @@ class OpenAI(BaseAPIModel): self.orgs = org self.org_ctr = 0 self.url = openai_api_base + + if openai_proxy_url == 'ENV': + if 'OPENAI_PROXY_URL' not in os.environ: + raise ValueError('OPENAI_PROXY_URL is not set.') + self.proxy_url = os.getenv('OPENAI_PROXY_URL') + else: + self.proxy_url = openai_proxy_url + self.path = path def generate(self, @@ -258,9 +272,24 @@ class OpenAI(BaseAPIModel): url = self.url[random.randint(0, len(self.url) - 1)] else: url = self.url - raw_response = requests.post(url, - headers=header, - data=json.dumps(data)) + + if self.proxy_url is None: + raw_response = requests.post(url, + headers=header, + data=json.dumps(data)) + else: + proxies = { + 'http': self.proxy_url, + 'https': self.proxy_url, + } + + raw_response = requests.post( + url, + headers=header, + data=json.dumps(data), + proxies=proxies, + ) + except requests.ConnectionError: self.logger.error('Got connection error, retrying...') continue @@ -394,6 +423,7 @@ class OpenAISDK(OpenAI): org: str | List[str] | None = None, meta_template: Dict | None = None, openai_api_base: str = OPENAI_API_BASE, + openai_proxy_url: Optional[str] = None, mode: str = 'none', logprobs: bool | None = False, top_logprobs: int | None = None, @@ -401,11 +431,23 @@ class OpenAISDK(OpenAI): tokenizer_path: str | None = None, extra_body: Dict | None = None): super().__init__(path, max_seq_len, query_per_second, rpm_verbose, - retry, key, org, meta_template, openai_api_base, mode, - logprobs, top_logprobs, temperature, tokenizer_path, - extra_body) + retry, key, org, meta_template, openai_api_base, + openai_proxy_url, mode, logprobs, top_logprobs, + temperature, tokenizer_path, extra_body) from openai import OpenAI - self.opeanai_cleint = OpenAI(base_url=openai_api_base, api_key=key) + + if self.proxy_url is None: + self.openai_client = OpenAI(base_url=openai_api_base, api_key=key) + else: + proxies = { + 'http://': self.proxy_url, + 'https://': self.proxy_url, + } + + self.openai_client = OpenAI( + base_url=openai_api_base, + api_key=key, + http_client=httpx.Client(proxies=proxies)) def _generate(self, input: PromptList | str, max_out_len: int, temperature: float) -> str: @@ -456,12 +498,14 @@ class OpenAISDK(OpenAI): while num_retries < self.retry: self.wait() try: - responses = self.opeanai_cleint.chat.completions.create( + responses = self.openai_client.chat.completions.create( model=self.path, max_tokens=max_out_len, n=1, temperature=self.temperature, - messages=messages) + messages=messages, + extra_body=self.extra_body, + ) return responses.choices[0].message.content except Exception as e: self.logger.error(e)