[Feature] Added extra_body support for OpenAISDK; Added support for proxy URL when connecting to OpenAI's API. (#1467)

* fix lint issues

* fix lint issues
This commit is contained in:
Alexander Lam 2024-08-29 00:43:43 +08:00 committed by GitHub
parent a488b9b4f5
commit 8b39225259
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
from threading import Lock from threading import Lock
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import httpx
import jieba import jieba
import requests import requests
@@ -46,6 +47,10 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions. wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'. 'https://api.openai.com/v1/chat/completions'.
openai_proxy_url (str, optional): An optional proxy url to use when
connecting to OpenAI's API. When set to 'ENV', the url will be
fetched from the environment variable $OPENAI_PROXY_URL.
Defaults to None.
mode (str, optional): The method of input truncation when input length mode (str, optional): The method of input truncation when input length
exceeds max_seq_len. 'front','mid' and 'rear' represents the part exceeds max_seq_len. 'front','mid' and 'rear' represents the part
of input to truncate. Defaults to 'none'. of input to truncate. Defaults to 'none'.
@@ -71,6 +76,7 @@ class OpenAI(BaseAPIModel):
org: Optional[Union[str, List[str]]] = None, org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None, meta_template: Optional[Dict] = None,
openai_api_base: str = OPENAI_API_BASE, openai_api_base: str = OPENAI_API_BASE,
openai_proxy_url: Optional[str] = None,
mode: str = 'none', mode: str = 'none',
logprobs: Optional[bool] = False, logprobs: Optional[bool] = False,
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
@@ -116,6 +122,14 @@ class OpenAI(BaseAPIModel):
self.orgs = org self.orgs = org
self.org_ctr = 0 self.org_ctr = 0
self.url = openai_api_base self.url = openai_api_base
if openai_proxy_url == 'ENV':
if 'OPENAI_PROXY_URL' not in os.environ:
raise ValueError('OPENAI_PROXY_URL is not set.')
self.proxy_url = os.getenv('OPENAI_PROXY_URL')
else:
self.proxy_url = openai_proxy_url
self.path = path self.path = path
def generate(self, def generate(self,
@@ -258,9 +272,24 @@
url = self.url[random.randint(0, len(self.url) - 1)] url = self.url[random.randint(0, len(self.url) - 1)]
else: else:
url = self.url url = self.url
raw_response = requests.post(url,
headers=header, if self.proxy_url is None:
data=json.dumps(data)) raw_response = requests.post(url,
headers=header,
data=json.dumps(data))
else:
proxies = {
'http': self.proxy_url,
'https': self.proxy_url,
}
raw_response = requests.post(
url,
headers=header,
data=json.dumps(data),
proxies=proxies,
)
except requests.ConnectionError: except requests.ConnectionError:
self.logger.error('Got connection error, retrying...') self.logger.error('Got connection error, retrying...')
continue continue
@@ -394,6 +423,7 @@ class OpenAISDK(OpenAI):
org: str | List[str] | None = None, org: str | List[str] | None = None,
meta_template: Dict | None = None, meta_template: Dict | None = None,
openai_api_base: str = OPENAI_API_BASE, openai_api_base: str = OPENAI_API_BASE,
openai_proxy_url: Optional[str] = None,
mode: str = 'none', mode: str = 'none',
logprobs: bool | None = False, logprobs: bool | None = False,
top_logprobs: int | None = None, top_logprobs: int | None = None,
@@ -401,11 +431,23 @@ class OpenAISDK(OpenAI):
tokenizer_path: str | None = None, tokenizer_path: str | None = None,
extra_body: Dict | None = None): extra_body: Dict | None = None):
super().__init__(path, max_seq_len, query_per_second, rpm_verbose, super().__init__(path, max_seq_len, query_per_second, rpm_verbose,
retry, key, org, meta_template, openai_api_base, mode, retry, key, org, meta_template, openai_api_base,
logprobs, top_logprobs, temperature, tokenizer_path, openai_proxy_url, mode, logprobs, top_logprobs,
extra_body) temperature, tokenizer_path, extra_body)
from openai import OpenAI from openai import OpenAI
self.opeanai_cleint = OpenAI(base_url=openai_api_base, api_key=key)
if self.proxy_url is None:
self.openai_client = OpenAI(base_url=openai_api_base, api_key=key)
else:
proxies = {
'http://': self.proxy_url,
'https://': self.proxy_url,
}
self.openai_client = OpenAI(
base_url=openai_api_base,
api_key=key,
http_client=httpx.Client(proxies=proxies))
def _generate(self, input: PromptList | str, max_out_len: int, def _generate(self, input: PromptList | str, max_out_len: int,
temperature: float) -> str: temperature: float) -> str:
@@ -456,12 +498,14 @@ class OpenAISDK(OpenAI):
while num_retries < self.retry: while num_retries < self.retry:
self.wait() self.wait()
try: try:
responses = self.opeanai_cleint.chat.completions.create( responses = self.openai_client.chat.completions.create(
model=self.path, model=self.path,
max_tokens=max_out_len, max_tokens=max_out_len,
n=1, n=1,
temperature=self.temperature, temperature=self.temperature,
messages=messages) messages=messages,
extra_body=self.extra_body,
)
return responses.choices[0].message.content return responses.choices[0].message.content
except Exception as e: except Exception as e:
self.logger.error(e) self.logger.error(e)