mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Added extra_body
support for OpenAISDK; Added support for proxy URL when connecting to OpenAI's API. (#1467)
* fix lint issues * fix lint issues
This commit is contained in:
parent
a488b9b4f5
commit
8b39225259
@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import jieba
|
||||
import requests
|
||||
|
||||
@ -46,6 +47,10 @@ class OpenAI(BaseAPIModel):
|
||||
wrapping of any meta instructions.
|
||||
openai_api_base (str): The base url of OpenAI's API. Defaults to
|
||||
'https://api.openai.com/v1/chat/completions'.
|
||||
openai_proxy_url (str, optional): An optional proxy url to use when
|
||||
connecting to OpenAI's API. When set to 'ENV', the url will be
|
||||
fetched from the environment variable $OPENAI_PROXY_URL.
|
||||
Defaults to None.
|
||||
mode (str, optional): The method of input truncation when input length
|
||||
exceeds max_seq_len. 'front','mid' and 'rear' represents the part
|
||||
of input to truncate. Defaults to 'none'.
|
||||
@ -71,6 +76,7 @@ class OpenAI(BaseAPIModel):
|
||||
org: Optional[Union[str, List[str]]] = None,
|
||||
meta_template: Optional[Dict] = None,
|
||||
openai_api_base: str = OPENAI_API_BASE,
|
||||
openai_proxy_url: Optional[str] = None,
|
||||
mode: str = 'none',
|
||||
logprobs: Optional[bool] = False,
|
||||
top_logprobs: Optional[int] = None,
|
||||
@ -116,6 +122,14 @@ class OpenAI(BaseAPIModel):
|
||||
self.orgs = org
|
||||
self.org_ctr = 0
|
||||
self.url = openai_api_base
|
||||
|
||||
if openai_proxy_url == 'ENV':
|
||||
if 'OPENAI_PROXY_URL' not in os.environ:
|
||||
raise ValueError('OPENAI_PROXY_URL is not set.')
|
||||
self.proxy_url = os.getenv('OPENAI_PROXY_URL')
|
||||
else:
|
||||
self.proxy_url = openai_proxy_url
|
||||
|
||||
self.path = path
|
||||
|
||||
def generate(self,
|
||||
@ -258,9 +272,24 @@ class OpenAI(BaseAPIModel):
|
||||
url = self.url[random.randint(0, len(self.url) - 1)]
|
||||
else:
|
||||
url = self.url
|
||||
|
||||
if self.proxy_url is None:
|
||||
raw_response = requests.post(url,
|
||||
headers=header,
|
||||
data=json.dumps(data))
|
||||
else:
|
||||
proxies = {
|
||||
'http': self.proxy_url,
|
||||
'https': self.proxy_url,
|
||||
}
|
||||
|
||||
raw_response = requests.post(
|
||||
url,
|
||||
headers=header,
|
||||
data=json.dumps(data),
|
||||
proxies=proxies,
|
||||
)
|
||||
|
||||
except requests.ConnectionError:
|
||||
self.logger.error('Got connection error, retrying...')
|
||||
continue
|
||||
@ -394,6 +423,7 @@ class OpenAISDK(OpenAI):
|
||||
org: str | List[str] | None = None,
|
||||
meta_template: Dict | None = None,
|
||||
openai_api_base: str = OPENAI_API_BASE,
|
||||
openai_proxy_url: Optional[str] = None,
|
||||
mode: str = 'none',
|
||||
logprobs: bool | None = False,
|
||||
top_logprobs: int | None = None,
|
||||
@ -401,11 +431,23 @@ class OpenAISDK(OpenAI):
|
||||
tokenizer_path: str | None = None,
|
||||
extra_body: Dict | None = None):
|
||||
super().__init__(path, max_seq_len, query_per_second, rpm_verbose,
|
||||
retry, key, org, meta_template, openai_api_base, mode,
|
||||
logprobs, top_logprobs, temperature, tokenizer_path,
|
||||
extra_body)
|
||||
retry, key, org, meta_template, openai_api_base,
|
||||
openai_proxy_url, mode, logprobs, top_logprobs,
|
||||
temperature, tokenizer_path, extra_body)
|
||||
from openai import OpenAI
|
||||
self.opeanai_cleint = OpenAI(base_url=openai_api_base, api_key=key)
|
||||
|
||||
if self.proxy_url is None:
|
||||
self.openai_client = OpenAI(base_url=openai_api_base, api_key=key)
|
||||
else:
|
||||
proxies = {
|
||||
'http://': self.proxy_url,
|
||||
'https://': self.proxy_url,
|
||||
}
|
||||
|
||||
self.openai_client = OpenAI(
|
||||
base_url=openai_api_base,
|
||||
api_key=key,
|
||||
http_client=httpx.Client(proxies=proxies))
|
||||
|
||||
def _generate(self, input: PromptList | str, max_out_len: int,
|
||||
temperature: float) -> str:
|
||||
@ -456,12 +498,14 @@ class OpenAISDK(OpenAI):
|
||||
while num_retries < self.retry:
|
||||
self.wait()
|
||||
try:
|
||||
responses = self.opeanai_cleint.chat.completions.create(
|
||||
responses = self.openai_client.chat.completions.create(
|
||||
model=self.path,
|
||||
max_tokens=max_out_len,
|
||||
n=1,
|
||||
temperature=self.temperature,
|
||||
messages=messages)
|
||||
messages=messages,
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return responses.choices[0].message.content
|
||||
except Exception as e:
|
||||
self.logger.error(e)
|
||||
|
Loading…
Reference in New Issue
Block a user