Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support verbose for OpenAI API (#1546)
Parent: a81bbb85bf
Commit: ee058e25b2
@@ -43,7 +43,8 @@ class BaseAPIModel(BaseModel):
                  retry: int = 2,
                  max_seq_len: int = 2048,
                  meta_template: Optional[Dict] = None,
-                 generation_kwargs: Dict = dict()):
+                 generation_kwargs: Dict = dict(),
+                 verbose: bool = False):
         self.path = path
         self.max_seq_len = max_seq_len
         self.meta_template = meta_template
@@ -53,6 +54,7 @@ class BaseAPIModel(BaseModel):
         self.template_parser = APITemplateParser(meta_template)
         self.logger = get_logger()
         self.generation_kwargs = generation_kwargs
+        self.verbose = verbose

     @abstractmethod
     def generate(self, inputs: List[PromptType],
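
Note: the base class now stores the flag as self.verbose, so every API-backed model can gate its diagnostics on it. A minimal standalone sketch of the pattern this enables (the class and method names below are illustrative, not from the commit):

import logging

# Sketch of the verbose-gated logging pattern this commit introduces:
# a constructor flag decides whether per-request diagnostics are logged,
# so the default log stream stays quiet.
class VerboseClient:
    def __init__(self, verbose: bool = False):
        self.verbose = verbose
        self.logger = logging.getLogger(__name__)

    def send(self, url: str) -> None:
        if self.verbose:
            self.logger.debug('Start send query to %s', url)
        # ... perform the actual request here ...
        if self.verbose:
            self.logger.debug('Get response from %s', url)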
@@ -90,14 +90,16 @@ class OpenAI(BaseAPIModel):
                  temperature: Optional[float] = None,
                  tokenizer_path: Optional[str] = None,
                  extra_body: Optional[Dict] = None,
-                 max_completion_tokens: int = 16384):
+                 max_completion_tokens: int = 16384,
+                 verbose: bool = False):

         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template,
                          query_per_second=query_per_second,
                          rpm_verbose=rpm_verbose,
-                         retry=retry)
+                         retry=retry,
+                         verbose=verbose)
         import tiktoken
         self.tiktoken = tiktoken
         self.temperature = temperature
@@ -310,7 +312,9 @@ class OpenAI(BaseAPIModel):
                     'http': self.proxy_url,
                     'https': self.proxy_url,
                 }

+                if self.verbose:
+                    self.logger.debug(
+                        f'Start send query to {self.proxy_url}')
                 raw_response = requests.post(
                     url,
                     headers=header,
@@ -318,6 +322,10 @@ class OpenAI(BaseAPIModel):
                     proxies=proxies,
                 )

+                if self.verbose:
+                    self.logger.debug(
+                        f'Get response from {self.proxy_url}')
+
             except requests.ConnectionError:
                 self.logger.error('Got connection error, retrying...')
                 continue
@@ -371,27 +379,44 @@ class OpenAI(BaseAPIModel):
         """
         assert self.tokenizer_path or self.path
         try:
+            if self.verbose:
+                self.logger.info(f'Used tokenizer_path: {self.tokenizer_path}')
             tokenizer_path = self.tokenizer_path if self.tokenizer_path \
                 else self.path
             try:
+                if self.verbose:
+                    self.logger.info(
+                        f'Start load tiktoken encoding: {tokenizer_path}')
                 enc = self.tiktoken.encoding_for_model(tokenizer_path)
+                if self.verbose:
+                    self.logger.info(
+                        f'Successfully tiktoken encoding: {tokenizer_path}')
                 return len(enc.encode(prompt))
             except Exception as e:
                 self.logger.warn(f'{e}, tiktoken encoding cannot load '
                                  f'{tokenizer_path}')
                 from transformers import AutoTokenizer
                 if self.hf_tokenizer is None:
+                    if self.verbose:
+                        self.logger.info(
+                            f'Start load hf tokenizer: {tokenizer_path}')
                     self.hf_tokenizer = AutoTokenizer.from_pretrained(
                         tokenizer_path, trust_remote_code=True)
                     self.logger.info(
-                        f'Tokenizer is loaded from {tokenizer_path}')
+                        f'Successfully load HF Tokenizer from {tokenizer_path}'
+                    )
                 return len(self.hf_tokenizer(prompt).input_ids)
         except Exception:
             self.logger.warn(
                 'Can not get tokenizer automatically, '
                 'will use default tokenizer gpt-4 for length calculation.')
             default_tokenizer = 'gpt-4'

             enc = self.tiktoken.encoding_for_model(default_tokenizer)
+            if self.verbose:
+                self.logger.info(
+                    f'Successfully load default tiktoken tokenizer: '
+                    f' {default_tokenizer}')
             return len(enc.encode(prompt))

     def bin_trim(self, prompt: str, num_token: int) -> str:
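
The hunk above also makes the tokenizer resolution order explicit: a tiktoken encoding matched to the model name is tried first, then a Hugging Face tokenizer, and finally the gpt-4 tiktoken encoding as a default. A condensed standalone sketch of that fallback chain (assuming tiktoken and transformers are installed; the caching of the HF tokenizer done in the real method is omitted):

import tiktoken

def count_tokens(prompt: str, tokenizer_path: str) -> int:
    """Condensed sketch of the fallback order used in get_token_len."""
    try:
        # 1. Prefer a tiktoken encoding matched to the model name.
        enc = tiktoken.encoding_for_model(tokenizer_path)
        return len(enc.encode(prompt))
    except Exception:
        pass
    try:
        # 2. Fall back to a Hugging Face tokenizer of the same name.
        from transformers import AutoTokenizer
        tok = AutoTokenizer.from_pretrained(tokenizer_path,
                                            trust_remote_code=True)
        return len(tok(prompt).input_ids)
    except Exception:
        # 3. Last resort: approximate with the gpt-4 encoding.
        enc = tiktoken.encoding_for_model('gpt-4')
        return len(enc.encode(prompt))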
@@ -458,12 +483,26 @@ class OpenAISDK(OpenAI):
                  temperature: float | None = None,
                  tokenizer_path: str | None = None,
                  extra_body: Dict | None = None,
-                 max_completion_tokens: int = 16384):
-        super().__init__(path, max_seq_len, query_per_second, rpm_verbose,
-                         retry, key, org, meta_template, openai_api_base,
-                         openai_proxy_url, mode, logprobs, top_logprobs,
-                         temperature, tokenizer_path, extra_body,
-                         max_completion_tokens)
+                 max_completion_tokens: int = 16384,
+                 verbose: bool = False):
+        super().__init__(path,
+                         max_seq_len,
+                         query_per_second,
+                         rpm_verbose,
+                         retry,
+                         key,
+                         org,
+                         meta_template,
+                         openai_api_base,
+                         openai_proxy_url,
+                         mode,
+                         logprobs,
+                         top_logprobs,
+                         temperature,
+                         tokenizer_path,
+                         extra_body,
+                         verbose=verbose,
+                         max_completion_tokens=max_completion_tokens)
         from openai import OpenAI

         if self.proxy_url is None:
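
Usage sketch (not part of the commit): with the constructor rewritten above, the flag can be enabled when instantiating the SDK-backed model. Note that verbose and max_completion_tokens are now passed to super().__init__ as keywords, so they cannot be silently mis-ordered in the long positional chain. The model name and key below are placeholders, and the import path is assumed:

from opencompass.models import OpenAISDK  # import path assumed

model = OpenAISDK(path='gpt-4o-2024-05-13',  # illustrative model name
                  key='sk-...',              # placeholder API key
                  verbose=True)              # enable the new diagnostics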
@@ -478,6 +517,8 @@ class OpenAISDK(OpenAI):
                 base_url=openai_api_base,
                 api_key=key,
                 http_client=httpx.Client(proxies=proxies))
+        if self.verbose:
+            self.logger.info(f'Used openai_client: {self.openai_client}')

     def _generate(self, input: PromptList | str, max_out_len: int,
                   temperature: float) -> str:
@@ -553,8 +594,13 @@ class OpenAISDK(OpenAI):
             )

             try:
+                if self.verbose:
+                    self.logger.info('Start calling OpenAI API')
                 responses = self.openai_client.chat.completions.create(
                     **query_data)
+                if self.verbose:
+                    self.logger.info(
+                        'Successfully get response from OpenAI API')
                 return responses.choices[0].message.content
             except Exception as e:
                 self.logger.error(e)
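
With verbose=True, a successful request is now bracketed by the two info messages added above ('Start calling OpenAI API' / 'Successfully get response from OpenAI API'), while a failure surfaces the exception through self.logger.error. This makes it possible to tell from the logs whether time is being spent inside the API call itself or elsewhere in the pipeline.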
@@ -127,6 +127,7 @@ class GenInferencer(BaseInferencer):
         index = len(tmp_result_dict)

         # 4. Wrap prompts with Dataloader
+        logger.info('Starting build dataloader')
         dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)

         # 5. Inference for prompts in each batch