mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Added extra_body
support for OpenAISDK; Added support for proxy URL when connecting to OpenAI's API. (#1467)
* fix lint issues * fix lint issues
This commit is contained in:
parent
a488b9b4f5
commit
8b39225259
@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
import jieba
|
import jieba
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -46,6 +47,10 @@ class OpenAI(BaseAPIModel):
|
|||||||
wrapping of any meta instructions.
|
wrapping of any meta instructions.
|
||||||
openai_api_base (str): The base url of OpenAI's API. Defaults to
|
openai_api_base (str): The base url of OpenAI's API. Defaults to
|
||||||
'https://api.openai.com/v1/chat/completions'.
|
'https://api.openai.com/v1/chat/completions'.
|
||||||
|
openai_proxy_url (str, optional): An optional proxy url to use when
|
||||||
|
connecting to OpenAI's API. When set to 'ENV', the url will be
|
||||||
|
fetched from the environment variable $OPENAI_PROXY_URL.
|
||||||
|
Defaults to None.
|
||||||
mode (str, optional): The method of input truncation when input length
|
mode (str, optional): The method of input truncation when input length
|
||||||
exceeds max_seq_len. 'front','mid' and 'rear' represents the part
|
exceeds max_seq_len. 'front','mid' and 'rear' represents the part
|
||||||
of input to truncate. Defaults to 'none'.
|
of input to truncate. Defaults to 'none'.
|
||||||
@ -71,6 +76,7 @@ class OpenAI(BaseAPIModel):
|
|||||||
org: Optional[Union[str, List[str]]] = None,
|
org: Optional[Union[str, List[str]]] = None,
|
||||||
meta_template: Optional[Dict] = None,
|
meta_template: Optional[Dict] = None,
|
||||||
openai_api_base: str = OPENAI_API_BASE,
|
openai_api_base: str = OPENAI_API_BASE,
|
||||||
|
openai_proxy_url: Optional[str] = None,
|
||||||
mode: str = 'none',
|
mode: str = 'none',
|
||||||
logprobs: Optional[bool] = False,
|
logprobs: Optional[bool] = False,
|
||||||
top_logprobs: Optional[int] = None,
|
top_logprobs: Optional[int] = None,
|
||||||
@ -116,6 +122,14 @@ class OpenAI(BaseAPIModel):
|
|||||||
self.orgs = org
|
self.orgs = org
|
||||||
self.org_ctr = 0
|
self.org_ctr = 0
|
||||||
self.url = openai_api_base
|
self.url = openai_api_base
|
||||||
|
|
||||||
|
if openai_proxy_url == 'ENV':
|
||||||
|
if 'OPENAI_PROXY_URL' not in os.environ:
|
||||||
|
raise ValueError('OPENAI_PROXY_URL is not set.')
|
||||||
|
self.proxy_url = os.getenv('OPENAI_PROXY_URL')
|
||||||
|
else:
|
||||||
|
self.proxy_url = openai_proxy_url
|
||||||
|
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
@ -258,9 +272,24 @@ class OpenAI(BaseAPIModel):
|
|||||||
url = self.url[random.randint(0, len(self.url) - 1)]
|
url = self.url[random.randint(0, len(self.url) - 1)]
|
||||||
else:
|
else:
|
||||||
url = self.url
|
url = self.url
|
||||||
raw_response = requests.post(url,
|
|
||||||
headers=header,
|
if self.proxy_url is None:
|
||||||
data=json.dumps(data))
|
raw_response = requests.post(url,
|
||||||
|
headers=header,
|
||||||
|
data=json.dumps(data))
|
||||||
|
else:
|
||||||
|
proxies = {
|
||||||
|
'http': self.proxy_url,
|
||||||
|
'https': self.proxy_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
raw_response = requests.post(
|
||||||
|
url,
|
||||||
|
headers=header,
|
||||||
|
data=json.dumps(data),
|
||||||
|
proxies=proxies,
|
||||||
|
)
|
||||||
|
|
||||||
except requests.ConnectionError:
|
except requests.ConnectionError:
|
||||||
self.logger.error('Got connection error, retrying...')
|
self.logger.error('Got connection error, retrying...')
|
||||||
continue
|
continue
|
||||||
@ -394,6 +423,7 @@ class OpenAISDK(OpenAI):
|
|||||||
org: str | List[str] | None = None,
|
org: str | List[str] | None = None,
|
||||||
meta_template: Dict | None = None,
|
meta_template: Dict | None = None,
|
||||||
openai_api_base: str = OPENAI_API_BASE,
|
openai_api_base: str = OPENAI_API_BASE,
|
||||||
|
openai_proxy_url: Optional[str] = None,
|
||||||
mode: str = 'none',
|
mode: str = 'none',
|
||||||
logprobs: bool | None = False,
|
logprobs: bool | None = False,
|
||||||
top_logprobs: int | None = None,
|
top_logprobs: int | None = None,
|
||||||
@ -401,11 +431,23 @@ class OpenAISDK(OpenAI):
|
|||||||
tokenizer_path: str | None = None,
|
tokenizer_path: str | None = None,
|
||||||
extra_body: Dict | None = None):
|
extra_body: Dict | None = None):
|
||||||
super().__init__(path, max_seq_len, query_per_second, rpm_verbose,
|
super().__init__(path, max_seq_len, query_per_second, rpm_verbose,
|
||||||
retry, key, org, meta_template, openai_api_base, mode,
|
retry, key, org, meta_template, openai_api_base,
|
||||||
logprobs, top_logprobs, temperature, tokenizer_path,
|
openai_proxy_url, mode, logprobs, top_logprobs,
|
||||||
extra_body)
|
temperature, tokenizer_path, extra_body)
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
self.opeanai_cleint = OpenAI(base_url=openai_api_base, api_key=key)
|
|
||||||
|
if self.proxy_url is None:
|
||||||
|
self.openai_client = OpenAI(base_url=openai_api_base, api_key=key)
|
||||||
|
else:
|
||||||
|
proxies = {
|
||||||
|
'http://': self.proxy_url,
|
||||||
|
'https://': self.proxy_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.openai_client = OpenAI(
|
||||||
|
base_url=openai_api_base,
|
||||||
|
api_key=key,
|
||||||
|
http_client=httpx.Client(proxies=proxies))
|
||||||
|
|
||||||
def _generate(self, input: PromptList | str, max_out_len: int,
|
def _generate(self, input: PromptList | str, max_out_len: int,
|
||||||
temperature: float) -> str:
|
temperature: float) -> str:
|
||||||
@ -456,12 +498,14 @@ class OpenAISDK(OpenAI):
|
|||||||
while num_retries < self.retry:
|
while num_retries < self.retry:
|
||||||
self.wait()
|
self.wait()
|
||||||
try:
|
try:
|
||||||
responses = self.opeanai_cleint.chat.completions.create(
|
responses = self.openai_client.chat.completions.create(
|
||||||
model=self.path,
|
model=self.path,
|
||||||
max_tokens=max_out_len,
|
max_tokens=max_out_len,
|
||||||
n=1,
|
n=1,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
messages=messages)
|
messages=messages,
|
||||||
|
extra_body=self.extra_body,
|
||||||
|
)
|
||||||
return responses.choices[0].message.content
|
return responses.choices[0].message.content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(e)
|
self.logger.error(e)
|
||||||
|
Loading…
Reference in New Issue
Block a user