[Feature] Added extra_body support for OpenAISDK; Added support for proxy URL when connecting to OpenAI's API. (#1467)

* fix lint issues

* fix lint issues
This commit is contained in:
Alexander Lam 2024-08-29 00:43:43 +08:00 committed by GitHub
parent a488b9b4f5
commit 8b39225259
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union
import httpx
import jieba
import requests
@ -46,6 +47,10 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'.
openai_proxy_url (str, optional): An optional proxy url to use when
connecting to OpenAI's API. When set to 'ENV', the url will be
fetched from the environment variable $OPENAI_PROXY_URL.
Defaults to None.
mode (str, optional): The method of input truncation when input length
exceeds max_seq_len. 'front','mid' and 'rear' represents the part
of input to truncate. Defaults to 'none'.
@ -71,6 +76,7 @@ class OpenAI(BaseAPIModel):
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None,
openai_api_base: str = OPENAI_API_BASE,
openai_proxy_url: Optional[str] = None,
mode: str = 'none',
logprobs: Optional[bool] = False,
top_logprobs: Optional[int] = None,
@ -116,6 +122,14 @@ class OpenAI(BaseAPIModel):
self.orgs = org
self.org_ctr = 0
self.url = openai_api_base
if openai_proxy_url == 'ENV':
if 'OPENAI_PROXY_URL' not in os.environ:
raise ValueError('OPENAI_PROXY_URL is not set.')
self.proxy_url = os.getenv('OPENAI_PROXY_URL')
else:
self.proxy_url = openai_proxy_url
self.path = path
def generate(self,
@ -258,9 +272,24 @@ class OpenAI(BaseAPIModel):
url = self.url[random.randint(0, len(self.url) - 1)]
else:
url = self.url
raw_response = requests.post(url,
headers=header,
data=json.dumps(data))
if self.proxy_url is None:
raw_response = requests.post(url,
headers=header,
data=json.dumps(data))
else:
proxies = {
'http': self.proxy_url,
'https': self.proxy_url,
}
raw_response = requests.post(
url,
headers=header,
data=json.dumps(data),
proxies=proxies,
)
except requests.ConnectionError:
self.logger.error('Got connection error, retrying...')
continue
@ -394,6 +423,7 @@ class OpenAISDK(OpenAI):
org: str | List[str] | None = None,
meta_template: Dict | None = None,
openai_api_base: str = OPENAI_API_BASE,
openai_proxy_url: Optional[str] = None,
mode: str = 'none',
logprobs: bool | None = False,
top_logprobs: int | None = None,
@ -401,11 +431,23 @@ class OpenAISDK(OpenAI):
tokenizer_path: str | None = None,
extra_body: Dict | None = None):
super().__init__(path, max_seq_len, query_per_second, rpm_verbose,
retry, key, org, meta_template, openai_api_base, mode,
logprobs, top_logprobs, temperature, tokenizer_path,
extra_body)
retry, key, org, meta_template, openai_api_base,
openai_proxy_url, mode, logprobs, top_logprobs,
temperature, tokenizer_path, extra_body)
from openai import OpenAI
self.opeanai_cleint = OpenAI(base_url=openai_api_base, api_key=key)
if self.proxy_url is None:
self.openai_client = OpenAI(base_url=openai_api_base, api_key=key)
else:
proxies = {
'http://': self.proxy_url,
'https://': self.proxy_url,
}
self.openai_client = OpenAI(
base_url=openai_api_base,
api_key=key,
http_client=httpx.Client(proxies=proxies))
def _generate(self, input: PromptList | str, max_out_len: int,
temperature: float) -> str:
@ -456,12 +498,14 @@ class OpenAISDK(OpenAI):
while num_retries < self.retry:
self.wait()
try:
responses = self.opeanai_cleint.chat.completions.create(
responses = self.openai_client.chat.completions.create(
model=self.path,
max_tokens=max_out_len,
n=1,
temperature=self.temperature,
messages=messages)
messages=messages,
extra_body=self.extra_body,
)
return responses.choices[0].message.content
except Exception as e:
self.logger.error(e)