[Feature] Support verbose for OpenAI API (#1546)

This commit is contained in:
Songyang Zhang 2024-09-20 17:12:52 +08:00 committed by GitHub
parent a81bbb85bf
commit ee058e25b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 11 deletions

View File

@@ -43,7 +43,8 @@ class BaseAPIModel(BaseModel):
retry: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
generation_kwargs: Dict = dict()):
generation_kwargs: Dict = dict(),
verbose: bool = False):
self.path = path
self.max_seq_len = max_seq_len
self.meta_template = meta_template
@@ -53,6 +54,7 @@ class BaseAPIModel(BaseModel):
self.template_parser = APITemplateParser(meta_template)
self.logger = get_logger()
self.generation_kwargs = generation_kwargs
self.verbose = verbose
@abstractmethod
def generate(self, inputs: List[PromptType],

View File

@@ -90,14 +90,16 @@ class OpenAI(BaseAPIModel):
temperature: Optional[float] = None,
tokenizer_path: Optional[str] = None,
extra_body: Optional[Dict] = None,
max_completion_tokens: int = 16384):
max_completion_tokens: int = 16384,
verbose: bool = False):
super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template,
query_per_second=query_per_second,
rpm_verbose=rpm_verbose,
retry=retry)
retry=retry,
verbose=verbose)
import tiktoken
self.tiktoken = tiktoken
self.temperature = temperature
@@ -310,7 +312,9 @@ class OpenAI(BaseAPIModel):
'http': self.proxy_url,
'https': self.proxy_url,
}
if self.verbose:
self.logger.debug(
f'Start send query to {self.proxy_url}')
raw_response = requests.post(
url,
headers=header,
@@ -318,6 +322,10 @@ class OpenAI(BaseAPIModel):
proxies=proxies,
)
if self.verbose:
self.logger.debug(
f'Get response from {self.proxy_url}')
except requests.ConnectionError:
self.logger.error('Got connection error, retrying...')
continue
@@ -371,27 +379,44 @@ class OpenAI(BaseAPIModel):
"""
assert self.tokenizer_path or self.path
try:
if self.verbose:
self.logger.info(f'Used tokenizer_path: {self.tokenizer_path}')
tokenizer_path = self.tokenizer_path if self.tokenizer_path \
else self.path
try:
if self.verbose:
self.logger.info(
f'Start load tiktoken encoding: {tokenizer_path}')
enc = self.tiktoken.encoding_for_model(tokenizer_path)
if self.verbose:
self.logger.info(
f'Successfully tiktoken encoding: {tokenizer_path}')
return len(enc.encode(prompt))
except Exception as e:
self.logger.warn(f'{e}, tiktoken encoding cannot load '
f'{tokenizer_path}')
from transformers import AutoTokenizer
if self.hf_tokenizer is None:
if self.verbose:
self.logger.info(
f'Start load hf tokenizer: {tokenizer_path}')
self.hf_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, trust_remote_code=True)
self.logger.info(
f'Tokenizer is loaded from {tokenizer_path}')
f'Successfully load HF Tokenizer from {tokenizer_path}'
)
return len(self.hf_tokenizer(prompt).input_ids)
except Exception:
self.logger.warn(
'Can not get tokenizer automatically, '
'will use default tokenizer gpt-4 for length calculation.')
default_tokenizer = 'gpt-4'
enc = self.tiktoken.encoding_for_model(default_tokenizer)
if self.verbose:
self.logger.info(
f'Successfully load default tiktoken tokenizer: '
f' {default_tokenizer}')
return len(enc.encode(prompt))
def bin_trim(self, prompt: str, num_token: int) -> str:
@@ -458,12 +483,26 @@ class OpenAISDK(OpenAI):
temperature: float | None = None,
tokenizer_path: str | None = None,
extra_body: Dict | None = None,
max_completion_tokens: int = 16384):
super().__init__(path, max_seq_len, query_per_second, rpm_verbose,
retry, key, org, meta_template, openai_api_base,
openai_proxy_url, mode, logprobs, top_logprobs,
temperature, tokenizer_path, extra_body,
max_completion_tokens)
max_completion_tokens: int = 16384,
verbose: bool = False):
super().__init__(path,
max_seq_len,
query_per_second,
rpm_verbose,
retry,
key,
org,
meta_template,
openai_api_base,
openai_proxy_url,
mode,
logprobs,
top_logprobs,
temperature,
tokenizer_path,
extra_body,
verbose=verbose,
max_completion_tokens=max_completion_tokens)
from openai import OpenAI
if self.proxy_url is None:
@@ -478,6 +517,8 @@ class OpenAISDK(OpenAI):
base_url=openai_api_base,
api_key=key,
http_client=httpx.Client(proxies=proxies))
if self.verbose:
self.logger.info(f'Used openai_client: {self.openai_client}')
def _generate(self, input: PromptList | str, max_out_len: int,
temperature: float) -> str:
@@ -553,8 +594,13 @@ class OpenAISDK(OpenAI):
)
try:
if self.verbose:
self.logger.info('Start calling OpenAI API')
responses = self.openai_client.chat.completions.create(
**query_data)
if self.verbose:
self.logger.info(
'Successfully get response from OpenAI API')
return responses.choices[0].message.content
except Exception as e:
self.logger.error(e)

View File

@@ -127,6 +127,7 @@ class GenInferencer(BaseInferencer):
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
logger.info('Starting build dataloader')
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
# 5. Inference for prompts in each batch