[Feature] Enhance OpenAI API, add example config for GPT evaluation (#53)

* [Feature] Enhance OpenAI API, add example config for GPT evaluation

* fix
This commit is contained in:
Tong Gao 2023-07-12 16:43:46 +08:00 committed by GitHub
parent f5103f93dd
commit 7ee5a86fee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 25 deletions

36
configs/eval_gpt3.5.py Normal file
View File

@ -0,0 +1,36 @@
from mmengine.config import read_base
from opencompass.models import OpenAI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

# Example OpenCompass config for evaluating GPT-3.5 through the OpenAI API.
with read_base():
    # choose a list of datasets
    from .datasets.collections.chat_medium import datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

# Meta prompt template: maps OpenCompass dialogue roles onto the roles the
# OpenAI chat API expects; `generate=True` marks the turn the model produces.
api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

# Model under evaluation: gpt-3.5-turbo-0613 via the OpenAI API wrapper.
models = [
    dict(abbr='GPT-3.5-turbo-0613',
         type=OpenAI, path='gpt-3.5-turbo-0613',
         key='ENV',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
         meta_template=api_meta_template,
         query_per_second=1,  # throttle between consecutive API calls
         max_out_len=2048, max_seq_len=2048, batch_size=8),
]

# Inference stage: no dataset partitioning (NaivePartitioner) and local
# execution with at most 8 concurrent worker processes.
infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,
        task=dict(type=OpenICLInferTask)),
)

View File

@ -1,7 +1,11 @@
import json
import os import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import requests
from opencompass.registry import MODELS from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList from opencompass.utils.prompt import PromptList
@ -22,39 +26,54 @@ class OpenAI(BaseAPIModel):
query_per_second (int): The maximum queries allowed per second query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1. between two consecutive calls of the API. Defaults to 1.
retry (int): Number of retries if the API call fails. Defaults to 2. retry (int): Number of retries if the API call fails. Defaults to 2.
key (str): OpenAI key. In particular, when it is set to "ENV", the key key (str or List[str]): OpenAI key(s). In particular, when it
will be fetched from the environment variable $OPENAI_API_KEY, as is set to "ENV", the key will be fetched from the environment
how openai defaults to be. Defaults to 'ENV' variable $OPENAI_API_KEY, as how openai defaults to be. If it's a
list, the keys will be used in round-robin manner. Defaults to
'ENV'.
org (str or List[str], optional): OpenAI organization(s). If not
specified, OpenAI uses the default organization bound to each API
key. If specified, the orgs will be posted with each request in
round-robin manner. Defaults to None.
meta_template (Dict, optional): The model's meta prompt meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or template if needed, in case the requirement of injecting or
wrapping of any meta instructions. wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1'. 'https://api.openai.com/v1/chat/completions'.
""" """
is_api: bool = True is_api: bool = True
def __init__(self, def __init__(
path: str, self,
max_seq_len: int = 2048, path: str,
query_per_second: int = 1, max_seq_len: int = 2048,
retry: int = 2, query_per_second: int = 1,
key: str = 'ENV', retry: int = 2,
meta_template: Optional[Dict] = None, key: Union[str, List[str]] = 'ENV',
openai_api_base: str = 'https://api.openai.com/v1'): org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None,
openai_api_base: str = 'https://api.openai.com/v1/chat/completions'
): # noqa
super().__init__(path=path, super().__init__(path=path,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
meta_template=meta_template, meta_template=meta_template,
query_per_second=query_per_second, query_per_second=query_per_second,
retry=retry) retry=retry)
import openai
import tiktoken import tiktoken
self.openai = openai
self.tiktoken = tiktoken self.tiktoken = tiktoken
self.openai.api_key = os.getenv( if isinstance(key, str):
'OPENAI_API_KEY') if key == 'ENV' else key self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
self.openai.api_rase = openai_api_base else:
self.keys = key
self.key_ctr = 0
if isinstance(org, str):
self.orgs = [org]
else:
self.orgs = org
self.org_ctr = 0
self.url = openai_api_base
def generate( def generate(
self, self,
@ -103,9 +122,6 @@ class OpenAI(BaseAPIModel):
""" """
assert isinstance(input, (str, PromptList)) assert isinstance(input, (str, PromptList))
# max num token for gpt-3.5-turbo is 4097
max_out_len = min(max_out_len, 4000 - self.get_token_len(str(input)))
if isinstance(input, str): if isinstance(input, str):
messages = [{'role': 'user', 'content': input}] messages = [{'role': 'user', 'content': input}]
else: else:
@ -120,11 +136,32 @@ class OpenAI(BaseAPIModel):
msg['role'] = 'system' msg['role'] = 'system'
messages.append(msg) messages.append(msg)
# max num token for gpt-3.5-turbo is 4097
max_out_len = min(max_out_len, 4000 - self.get_token_len(str(input)))
if max_out_len <= 0:
return ''
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.wait() self.wait()
if hasattr(self, 'keys'):
with Lock():
self.key_ctr += 1
if self.key_ctr == len(self.keys):
self.key_ctr = 0
header = {
'Authorization': f'Bearer {self.keys[self.key_ctr]}',
'content-type': 'application/json',
}
if self.orgs:
with Lock():
self.org_ctr += 1
if self.org_ctr == len(self.orgs):
self.org_ctr = 0
header['OpenAI-Organization'] = self.orgs[self.org_ctr]
try: try:
response = self.openai.ChatCompletion.create( data = dict(
model=self.path, model=self.path,
messages=messages, messages=messages,
max_tokens=max_out_len, max_tokens=max_out_len,
@ -132,12 +169,28 @@ class OpenAI(BaseAPIModel):
stop=None, stop=None,
temperature=temperature, temperature=temperature,
) )
except self.openai.error.RateLimitError: raw_response = requests.post(self.url,
max_num_retries -= 1 headers=header,
data=json.dumps(data))
except requests.ConnectionError:
self.logger.error('Got connection error, retrying...')
continue
try:
response = raw_response.json()
except requests.JSONDecodeError:
self.logger.error('JsonDecode error, got',
str(raw_response.content))
try:
return response['choices'][0]['message']['content'].strip()
except KeyError:
if 'error' in response:
self.logger.error('Find error message in response: ',
str(response['error']))
max_num_retries += 1 max_num_retries += 1
result = response.choices[0].message.content.strip() raise RuntimeError('Calling OpenAI failed after retrying for '
return result f'{max_num_retries} times. Check the logs for '
'details.')
def get_token_len(self, prompt: str) -> int: def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized string. Only English and Chinese """Get lengths of the tokenized string. Only English and Chinese