# Mirror of https://github.com/open-compass/opencompass.git
# synced 2025-05-30 16:03:24 +08:00
import asyncio
import contextlib
import multiprocessing
import os
import re
import threading
import weakref
from contextlib import contextmanager
from typing import (Dict, Iterable, List, Literal, Optional, Tuple, Union,
                    cast)

import jieba

from opencompass.models.base_api import AsyncTokenBucket, BaseAPIModel
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]
OPENAI_API_BASE = os.path.join(
    os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/'),
    'chat/completions')
class _APIModelState:
    """Endpoint-level state shared by every model instance with the same name.

    ``__new__`` keys instances by ``name``, so models that point at the same
    API endpoint/model pair share a single token bucket and concurrency
    budget.
    """

    _instance: Dict[str, weakref.ReferenceType["_APIModelState"]] = {}
    _count: int
    _concurrency: int
    _locks = [threading.Lock(), multiprocessing.Lock()]

    def __init__(self, *, name: str, concurrency: int, query_per_second=1) -> None:
        # Note: ``__new__`` may return an already existing instance for this
        # name, in which case ``__init__`` runs again on that shared object.
        self._name = name
        self._count = 0
        self._concurrency = concurrency
        self._token_bucket = AsyncTokenBucket(rate=query_per_second)

        self._count += 1
        self._concurrency = max(1, self._concurrency // self._count)
    @property
    def concurrency(self) -> int:
        # If update and concurrency are called simultaneously, the values
        # returned here may be inaccurate, but the impact is likely minimal.
        return self._concurrency

    async def acquire(self):
        return await self._token_bucket.acquire()

    @property
    def rpm(self):
        return self._token_bucket.rpm

    @property
    def name(self) -> str:
        return self._name

    @property
    def count(self):
        return self._count
    @classmethod
    def _cleanup(cls, ref: weakref.ReferenceType["_APIModelState"]):
        # By the time a weakref callback fires, ``ref()`` already returns
        # None, so locate the dead entry by the reference object itself.
        with cls._lock():
            for name, value in list(cls._instance.items()):
                if value is ref:
                    cls._instance.pop(name)
                    break
    def __new__(cls, name: str, *args, **kwargs) -> "_APIModelState":
        with cls._lock():
            if name not in cls._instance:
                self = super().__new__(cls)
                cls._instance[name] = weakref.ref(self, cls._cleanup)
            return cls._instance[name]()  # type: ignore
    @classmethod
    @contextmanager
    def _lock(cls):
        # Acquire the thread lock and the process lock together.
        with contextlib.ExitStack() as stack:
            for lock in cls._locks:
                stack.enter_context(lock)
            yield

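# A minimal sketch (hypothetical, not part of the original module) of the
# name-keyed sharing above: constructing a second state with the same name
# hands back the very same object, so its token bucket and concurrency
# budget are shared by every model targeting that endpoint/model pair.
#
#     s1 = _APIModelState(name='https://api.example.com/v1/gpt-4o',
#                         concurrency=64, query_per_second=2)
#     s2 = _APIModelState(name='https://api.example.com/v1/gpt-4o',
#                         concurrency=64, query_per_second=2)
#     assert s1 is s2
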
class AsyncOpenAISDK(BaseAPIModel):
    """Asynchronous OpenAI-compatible API model built on the official SDK."""

    states: Dict[str, _APIModelState] = {}
    def __init__(
        self,
        path: str = 'gpt-3.5-turbo',
        max_seq_len: int | None = None,  # type: ignore
        query_per_second: int = 1,
        retry: int = 2,
        key: str = 'ENV',
        org: str | List[str] | None = None,
        meta_template: Dict | None = None,
        openai_api_base: str = OPENAI_API_BASE,
        openai_proxy_url: Optional[str] = None,
        mode: Literal['none', 'front', 'mid', 'rear'] = 'none',
        logprobs: bool | None = False,
        top_logprobs: int | None = None,
        temperature: float | None = None,
        tokenizer_path: str | None = None,
        extra_body: Dict | None = None,
        max_completion_tokens: int = 16384,
        verbose: bool = False,
        concurrency: int = 64,
        status_code_mappings: dict = {},
    ):
        from openai import AsyncOpenAI

        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode
        # One shared state per (api_base, model) pair; key the lookup and the
        # stored entry consistently with state_key.
        state_key = self._get_state_key(api_base=openai_api_base, model_name=path)
        if state_key not in AsyncOpenAISDK.states:
            AsyncOpenAISDK.states[state_key] = _APIModelState(
                name=state_key,
                concurrency=concurrency,
                query_per_second=query_per_second,
            )
        self.state = AsyncOpenAISDK.states[state_key]
        # 'ENV' is a sentinel meaning "read the key from the environment".
        if key == 'ENV':
            if 'OPENAI_API_KEY' not in os.environ:
                raise ValueError('OPENAI_API_KEY is not set.')
            key = os.environ['OPENAI_API_KEY']
        self.openai_client = AsyncOpenAI(base_url=openai_api_base, api_key=key)

        if max_seq_len is None:
            # Infer a context window from the model name when none is given.
            if '16k' in path:
                max_seq_len = 16384
            elif 'gpt-4' in path:
                max_seq_len = 8192
            elif 'gpt-3.5' in path:
                max_seq_len = 4097
            else:
                max_seq_len = 32768

        super().__init__(path=path, max_seq_len=max_seq_len, meta_template=meta_template, retry=retry)
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.tokenizer_path = tokenizer_path
        self.hf_tokenizer = None
        self.extra_body = extra_body
        self.max_completion_tokens = max_completion_tokens
        self.temperature = temperature
        self.openai_api_base = openai_api_base
        self.concurrency = concurrency
        self.verbose = verbose  # checked before logging in `_generate`
        self.status_code_mappings = status_code_mappings

        if openai_proxy_url == 'ENV':
            if 'OPENAI_PROXY_URL' not in os.environ:
                raise ValueError('OPENAI_PROXY_URL is not set.')
            self.proxy_url = os.getenv('OPENAI_PROXY_URL')
        else:
            self.proxy_url = openai_proxy_url

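    # Hypothetical construction sketch (not from the original source); the
    # model name and rate settings below are illustrative only.
    #
    #     model = AsyncOpenAISDK(
    #         path='gpt-4o-mini',
    #         key='ENV',              # read OPENAI_API_KEY from the environment
    #         openai_api_base=OPENAI_API_BASE,
    #         query_per_second=2,
    #         concurrency=32,
    #     )
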
    async def generate(self,  # type: ignore
                       inputs: Iterable[PromptType],
                       max_out_len: int = 512,
                       temperature: float = 0.7,
                       **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.temperature is not None:
            temperature = self.temperature

        # TODO: This should be an AsyncGenerator once a real `AsyncInference`
        # has been implemented.
        tasks_queue: List[asyncio.Future] = []
        results_queue: List[Tuple[int, str]] = []
        inputs_iter = enumerate(inputs)

        data_stop = False
        while not (data_stop and not tasks_queue):
            concurrency = self.state.concurrency

            # Wait for at least one in-flight request to finish, then refill
            # the task queue up to the shared concurrency limit.
            if tasks_queue:
                done, pending = await asyncio.wait(tasks_queue, return_when=asyncio.FIRST_COMPLETED)
                tasks_queue = list(pending)
                for task in done:
                    result: Tuple[int, str] = task.result()
                    results_queue.append(result)

            while not data_stop and len(tasks_queue) < concurrency:
                try:
                    index, _input = next(inputs_iter)
                except StopIteration:
                    data_stop = True
                    break
                tasks_queue.append(
                    asyncio.create_task(
                        self._generate(
                            input=_input,
                            max_out_len=self.max_completion_tokens or max_out_len,
                            temperature=temperature,
                            index=index,
                        )
                    )
                )
        # Results finish in arbitrary order; sort by the original index so the
        # outputs line up with the inputs.
        results_queue.sort()
        return [item[1] for item in results_queue]

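    # Hypothetical usage sketch for `generate` (not part of the original
    # file): it is a coroutine, so it has to be driven by an event loop, and
    # the returned list preserves the order of the inputs.
    #
    #     outputs = asyncio.run(
    #         model.generate(['1 + 1 = ?', 'Name a prime number.'],
    #                        max_out_len=64))
    #     assert len(outputs) == 2
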
    async def generate_from_template(self, templates: List[PromptType],  # type: ignore
                                     max_out_len: int, **kwargs):
        """Generate completions from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates, mode='gen')  # type: ignore
        return await self.generate(inputs, max_out_len=max_out_len, **kwargs)
    async def _generate(self, input: PromptList | str, max_out_len: int,
                        temperature: float, index: int) -> Tuple[int, str]:
        from openai import APIStatusError, BadRequestError
        assert isinstance(input, (str, PromptList))

        # The max token count for gpt-3.5-turbo is 4097, while most newer
        # models allow 32k or more. Keep 100 tokens as a prompt buffer.
        # Trimming is only triggered for plain-string inputs.
        if isinstance(input, str) and self.mode != 'none':
            context_window = self.max_seq_len
            input = self.bin_trim(
                input,
                context_window - 100 - max_out_len,
                cast(Literal['front', 'mid', 'rear'], self.mode),
            )

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # Hold out 100 tokens due to potential errors in tiktoken calculation
        # try:
        #     max_out_len = min(
        #         max_out_len,
        #         context_window - self.get_token_len(str(input)) - 100)
        # except KeyError:
        #     max_out_len = max_out_len
        # if max_out_len <= 0:
        #     return ''

        num_retries = 0
        while num_retries < self.retry:
            await self.state.acquire()

            query_data = dict(
                model=self.path,
                max_tokens=max_out_len,
                n=1,
                temperature=temperature,
                messages=messages,
                extra_body=self.extra_body,
                timeout=600,
            )

            try:
                if self.verbose:
                    self.logger.info('Start calling OpenAI API')
                responses = await self.openai_client.chat.completions.create(**query_data)

                if self.verbose:
                    self.logger.info(
                        'Successfully get response from OpenAI API')
                    try:
                        self.logger.info(responses)
                    except Exception as e:  # noqa F841
                        pass
                if not responses.choices:
                    self.logger.error(
                        'Response is empty, it is an internal server error '
                        'from the API provider.')
                else:
                    return index, responses.choices[0].message.content

            except (BadRequestError, APIStatusError) as e:
                # Handle BadRequest status codes. ``status_code_mappings`` can
                # be used to bypass API sensitivity blocks, for example:
                # status_code_mappings={400: 'Input data may contain
                # inappropriate content.'}
                status_code = e.status_code
                if (status_code is not None
                        and status_code in self.status_code_mappings):
                    error_message = self.status_code_mappings[status_code]
                    self.logger.info(f'Status Code: {status_code},\n'
                                     f'Original Error Message: {e},\n'
                                     f'Return Message: {error_message} ')
                    return index, error_message
                else:
                    self.logger.warning(f'Failed to get response: {e}, '
                                        f'retry {num_retries}/{self.retry}')
            except Exception as e:
                self.logger.warning(f'Failed to get response: {e}, '
                                    f'retry {num_retries}/{self.retry}')
            num_retries += 1
        raise RuntimeError('Calling OpenAI API failed after retrying for '
                           f'{self.retry} times. Check the logs for details.')

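    # Hypothetical example of ``status_code_mappings`` (the value is taken
    # from the comment above): mapping a status code to a message makes
    # `_generate` return that message instead of retrying when the API
    # answers with that code.
    #
    #     status_code_mappings = {400: 'Input data may contain inappropriate content.'}
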
    def _get_state_key(self, api_base: str, model_name: str):
        return api_base + model_name
    def bin_trim(self, prompt: str, num_token: int, mode: Literal['front', 'mid', 'rear']) -> str:
        """Trim the prompt to no more than num_token tokens via binary search.

        Depending on ``mode``, trimming removes words from the front, the
        middle, or the rear of the prompt.

        Args:
            prompt (str): Input string.
            num_token (int): The upper bound of token numbers.
            mode (str): Which part of the prompt to trim away.

        Returns:
            str: The trimmed prompt.
        """
        token_len = self.get_token_len(prompt)
        if token_len <= num_token:
            return prompt
        # Use jieba for Chinese text so that trimming happens on word
        # boundaries; otherwise split on spaces.
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        if pattern.search(prompt):
            words = list(jieba.cut(prompt, cut_all=False))
            sep = ''
        else:
            words = prompt.split(' ')
            sep = ' '

        # Binary search for the largest number of words that still fits.
        l, r = 1, len(words)
        while l + 2 < r:
            mid = (l + r) // 2
            if mode == 'front':
                cur_prompt = sep.join(words[-mid:])
            elif mode == 'mid':
                cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
            elif mode == 'rear':
                cur_prompt = sep.join(words[:mid])

            if self.get_token_len(cur_prompt) <= num_token:
                l = mid  # noqa: E741
            else:
                r = mid

        if mode == 'front':
            prompt = sep.join(words[-l:])
        elif mode == 'mid':
            prompt = sep.join(words[:l]) + sep.join(words[-l:])
        elif mode == 'rear':
            prompt = sep.join(words[:l])
        return prompt

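# A minimal, hypothetical smoke test, not part of the original module. It
# assumes the OPENAI_API_KEY environment variable is set and that the default
# endpoint is reachable; the model name and prompt are illustrative only.
if __name__ == '__main__':
    async def _demo() -> None:
        model = AsyncOpenAISDK(path='gpt-4o-mini',
                               key=os.environ.get('OPENAI_API_KEY', 'ENV'),
                               query_per_second=1)
        outputs = await model.generate(['Say hello in one word.'],
                                       max_out_len=32)
        print(outputs[0])

    asyncio.run(_demo())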