diff --git a/configs/eval_gpt3.5.py b/configs/eval_gpt3.5.py
new file mode 100644
index 00000000..987ef47e
--- /dev/null
+++ b/configs/eval_gpt3.5.py
@@ -0,0 +1,36 @@
+from mmengine.config import read_base
+from opencompass.models import OpenAI
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.collections.chat_medium import datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+
+api_meta_template = dict(
+    round=[
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=True),
+    ],
+)
+
+models = [
+    dict(abbr='GPT-3.5-turbo-0613',
+         type=OpenAI, path='gpt-3.5-turbo-0613',
+         key='ENV',  # the key will be read from $OPENAI_API_KEY, but you can also write your key here
+         meta_template=api_meta_template,
+         query_per_second=1,
+         max_out_len=2048, max_seq_len=2048, batch_size=8),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalRunner,
+        max_num_workers=8,
+        task=dict(type=OpenICLInferTask)),
+)
diff --git a/opencompass/models/openai_api.py b/opencompass/models/openai_api.py
index cc5e1f8d..364cc7fb 100644
--- a/opencompass/models/openai_api.py
+++ b/opencompass/models/openai_api.py
@@ -1,7 +1,11 @@
+import json
 import os
 from concurrent.futures import ThreadPoolExecutor
+from threading import Lock
 from typing import Dict, List, Optional, Union

+import requests
+
 from opencompass.registry import MODELS
 from opencompass.utils.prompt import PromptList

@@ -22,39 +26,54 @@ class OpenAI(BaseAPIModel):
         query_per_second (int): The maximum queries allowed per second
             between two consecutive calls of the API. Defaults to 1.
         retry (int): Number of retries if the API call fails. Defaults to 2.
-        key (str): OpenAI key. In particular, when it is set to "ENV", the key
-            will be fetched from the environment variable $OPENAI_API_KEY, as
-            how openai defaults to be. Defaults to 'ENV'
+        key (str or List[str]): OpenAI key(s). In particular, when it
+            is set to "ENV", the key will be fetched from the environment
+            variable $OPENAI_API_KEY, matching the default behaviour of the
+            openai package. If it is a list, the keys will be used in a
+            round-robin manner. Defaults to 'ENV'.
+        org (str or List[str], optional): OpenAI organization(s). If not
+            specified, OpenAI uses the default organization bound to each API
+            key. If specified, the orgs will be posted with each request in a
+            round-robin manner. Defaults to None.
         meta_template (Dict, optional): The model's meta prompt template if
             needed, in case the requirement of injecting or wrapping of any
             meta instructions.
         openai_api_base (str): The base url of OpenAI's API. Defaults to
-            'https://api.openai.com/v1'.
+            'https://api.openai.com/v1/chat/completions'.
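+
+    Example:
+        A hypothetical instantiation illustrating the round-robin key/org
+        interface; the key and org values below are placeholders, not real
+        credentials.
+
+        >>> model = OpenAI(path='gpt-3.5-turbo-0613',
+        ...                key=['sk-xxx1', 'sk-xxx2'],
+        ...                org=['org-A', 'org-B'])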
""" is_api: bool = True - def __init__(self, - path: str, - max_seq_len: int = 2048, - query_per_second: int = 1, - retry: int = 2, - key: str = 'ENV', - meta_template: Optional[Dict] = None, - openai_api_base: str = 'https://api.openai.com/v1'): + def __init__( + self, + path: str, + max_seq_len: int = 2048, + query_per_second: int = 1, + retry: int = 2, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = None, + openai_api_base: str = 'https://api.openai.com/v1/chat/completions' + ): # noqa super().__init__(path=path, max_seq_len=max_seq_len, meta_template=meta_template, query_per_second=query_per_second, retry=retry) - import openai import tiktoken - self.openai = openai self.tiktoken = tiktoken - self.openai.api_key = os.getenv( - 'OPENAI_API_KEY') if key == 'ENV' else key - self.openai.api_rase = openai_api_base + if isinstance(key, str): + self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] + else: + self.keys = key + self.key_ctr = 0 + if isinstance(org, str): + self.orgs = [org] + else: + self.orgs = org + self.org_ctr = 0 + self.url = openai_api_base def generate( self, @@ -103,9 +122,6 @@ class OpenAI(BaseAPIModel): """ assert isinstance(input, (str, PromptList)) - # max num token for gpt-3.5-turbo is 4097 - max_out_len = min(max_out_len, 4000 - self.get_token_len(str(input))) - if isinstance(input, str): messages = [{'role': 'user', 'content': input}] else: @@ -120,11 +136,32 @@ class OpenAI(BaseAPIModel): msg['role'] = 'system' messages.append(msg) + # max num token for gpt-3.5-turbo is 4097 + max_out_len = min(max_out_len, 4000 - self.get_token_len(str(input))) + if max_out_len <= 0: + return '' + max_num_retries = 0 while max_num_retries < self.retry: self.wait() + if hasattr(self, 'keys'): + with Lock(): + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + header = { + 'Authorization': f'Bearer {self.keys[self.key_ctr]}', + 'content-type': 'application/json', + } + if self.orgs: + with Lock(): + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + try: - response = self.openai.ChatCompletion.create( + data = dict( model=self.path, messages=messages, max_tokens=max_out_len, @@ -132,12 +169,28 @@ class OpenAI(BaseAPIModel): stop=None, temperature=temperature, ) - except self.openai.error.RateLimitError: - max_num_retries -= 1 + raw_response = requests.post(self.url, + headers=header, + data=json.dumps(data)) + except requests.ConnectionError: + self.logger.error('Got connection error, retrying...') + continue + try: + response = raw_response.json() + except requests.JSONDecodeError: + self.logger.error('JsonDecode error, got', + str(raw_response.content)) + try: + return response['choices'][0]['message']['content'].strip() + except KeyError: + if 'error' in response: + self.logger.error('Find error message in response: ', + str(response['error'])) max_num_retries += 1 - result = response.choices[0].message.content.strip() - return result + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + 'details.') def get_token_len(self, prompt: str) -> int: """Get lengths of the tokenized string. Only English and Chinese