# OpenCompass/opencompass/models/base_api.py

import re
import sys
import threading
import time
import warnings
from abc import abstractmethod
from copy import deepcopy
from queue import Queue
from time import sleep
from typing import Dict, List, Optional, Tuple, Union

from opencompass.utils import get_logger
from opencompass.utils.prompt import PromptList

from .base import BaseModel

PromptType = Union[PromptList, str]


class BaseAPIModel(BaseModel):
"""Base class for API model wrapper.

    Args:
path (str): The path to the model.
query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1.
        rpm_verbose (bool): Whether to log the real-time requests-per-minute
            rate while waiting for the rate limiter. Defaults to False.
        retry (int): Number of retries if the API call fails. Defaults to 2.
max_seq_len (int): The maximum sequence length of the model. Defaults
to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the model requires injecting or
            wrapping of any meta instructions.
        generation_kwargs (Dict, optional): The generation kwargs for the
            model. Defaults to dict().
        verbose (bool): Whether to log more detailed information during API
            requests. Defaults to False.
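
    Example:
        A minimal, hypothetical sketch of how a concrete subclass might be
        used (``MyAPIModel`` is illustrative and not part of this module)::

            model = MyAPIModel(path='my-api-model',
                               query_per_second=2,
                               retry=3,
                               max_seq_len=4096)
            outputs = model.generate(['Hello, world!'], max_out_len=64)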
"""

    is_api: bool = True

    def __init__(self,
path: str,
query_per_second: int = 1,
rpm_verbose: bool = False,
retry: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
generation_kwargs: Dict = dict(),
verbose: bool = False):
self.path = path
self.max_seq_len = max_seq_len
self.meta_template = meta_template
self.retry = retry
self.query_per_second = query_per_second
self.token_bucket = TokenBucket(query_per_second, rpm_verbose)
self.template_parser = APITemplateParser(meta_template)
self.logger = get_logger()
self.generation_kwargs = generation_kwargs
self.verbose = verbose

    @abstractmethod
def generate(self, inputs: List[PromptType],
max_out_len: int) -> List[str]:
"""Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.

        Returns:
List[str]: A list of generated strings.
"""
raise NotImplementedError(f'{self.__class__.__name__} does not support'
' gen-based evaluation yet, try ppl-based '
'instead.')

    def flush(self):
        """Ensure simultaneous flushing of stdout and stderr when concurrent
        resources are available.

        When employing multiprocessing with standard I/O redirected to files,
        it is crucial to flush internal buffers so that logs can be inspected
        and are not lost in case of system failures.
        """
if hasattr(self, 'tokens'):
sys.stdout.flush()
sys.stderr.flush()

    def acquire(self):
        """Acquire concurrent resources if they exist.

        This behavior falls back to waiting according to query_per_second
        if there are no concurrent resources.
"""
if hasattr(self, 'tokens'):
self.tokens.acquire()
else:
self.wait()

    def release(self):
        """Release concurrent resources if acquired.

        This behavior falls back to doing nothing if there are no concurrent
        resources.
"""
if hasattr(self, 'tokens'):
self.tokens.release()

    @abstractmethod
    def get_ppl(self,
inputs: List[PromptType],
mask_length: Optional[List[int]] = None) -> List[float]:
"""Get perplexity scores given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings.
mask_length (Optional[List[int]]): A list of mask lengths. If
provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
List[float]: A list of perplexity scores.
"""
raise NotImplementedError(f'{self.__class__.__name__} does not support'
' ppl-based evaluation yet, try gen-based '
'instead.')

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if a more accurate length is needed.

        Args:
prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
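
        Example:
            An illustrative count (``model`` here stands for any concrete
            subclass instance); each English word and each Chinese character
            adds one::

                model.get_token_len('Hello world 你好')  # -> 4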
"""
english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
# Count English words
english_count = sum(len(part.split()) for part in english_parts)
        # Count Chinese characters
chinese_count = sum(len(part) for part in chinese_parts)
return english_count + chinese_count

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
"""
return self.token_bucket.get_token()

    def to(self, device):
pass


class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    Args:
meta_template (Dict): The meta template for the model.
"""

    def __init__(self, meta_template: Optional[Dict] = None):
self.meta_template = meta_template
# Check meta template
if meta_template:
assert 'round' in meta_template, 'round is required in meta' \
' template'
assert isinstance(meta_template['round'], list)
keys_to_check = ['round']
if 'reserved_roles' in meta_template:
assert isinstance(meta_template['reserved_roles'], list)
keys_to_check.append('reserved_roles')
self.roles: Dict[str, dict] = dict() # maps role name to config
for meta_key in keys_to_check:
for item in meta_template[meta_key]:
assert isinstance(item, (str, dict))
if isinstance(item, dict):
assert item['role'] not in self.roles, \
'role in meta prompt must be unique!'
self.roles[item['role']] = item.copy()

    def parse_template(self, prompt_template: PromptType,
                       mode: str) -> PromptType:
        """Parse the intermediate prompt template, and wrap it with meta
        template if applicable. When the meta template is set and the input is
        a PromptList, the return value will be a PromptList containing the
        full conversation history. Each item looks like:

        .. code-block:: python

            {'role': 'user', 'prompt': '...'}
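
        For example (illustrative only; the actual role names depend on the
        meta template's ``api_role`` settings), a two-round dialogue parsed
        in ``'gen'`` mode might yield:

        .. code-block:: python

            [
                {'role': 'HUMAN', 'prompt': 'What is 1 + 1?'},
                {'role': 'BOT', 'prompt': '2'},
                {'role': 'HUMAN', 'prompt': 'And 2 + 2?'},
            ]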

        Args:
            prompt_template (List[PromptType]): An intermediate prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            List[PromptType]: The finalized prompt or a conversation.
"""
assert isinstance(prompt_template, (str, list, PromptList, tuple))
if not isinstance(prompt_template, (str, PromptList)):
return [self.parse_template(p, mode=mode) for p in prompt_template]
assert mode in ['ppl', 'gen']
if isinstance(prompt_template, str):
return prompt_template
if self.meta_template:
prompt = PromptList()
# Whether to keep generating the prompt
generate = True
section_stack = [] # stores tuples: (section_name, start_idx)
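            # Walk through the template: bare strings are ignored (with a
            # warning if non-empty), 'section' markers push/pop the section
            # stack, completed 'round'/'ice' sections are split into rounds
            # and mapped onto the meta template's role configs, and items
            # inside 'begin'/'end' sections are converted directly.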
for i, item in enumerate(prompt_template):
if not generate:
break
if isinstance(item, str):
if item.strip():
# TODO: logger
warnings.warn('Non-empty string in prompt template '
'will be ignored in API models.')
elif isinstance(item, dict) and 'section' in item:
if item['pos'] == 'end':
section_name, start_idx = section_stack.pop(-1)
assert section_name == item['section']
if section_name in ['round', 'ice']:
dialogue = prompt_template[start_idx:i]
round_ranges = self._split_rounds(
dialogue, self.meta_template['round'])
# Consider inserting multiple round examples into
# template
for i in range(len(round_ranges) - 1):
start = round_ranges[i]
end = round_ranges[i + 1]
round_template = dialogue[start:end]
role_dict = self._update_role_dict(
round_template)
api_prompts, generate = self._prompt2api(
self.meta_template['round'],
role_dict,
# Start generating only when the mode is in
# generation and the template reaches the
# last round
for_gen=mode == 'gen'
and section_name == 'round'
and i == len(round_ranges) - 2)
prompt += api_prompts
elif item['pos'] == 'begin':
assert item['section'] in [
'begin', 'round', 'end', 'ice'
]
section_stack.append((item['section'], i + 1))
else:
raise ValueError(f'Invalid pos {item["pos"]}')
elif section_stack[-1][0] in ['begin', 'end']:
role_dict = self._update_role_dict(item)
api_prompts, generate = self._prompt2api(
item, role_dict, for_gen=mode == 'gen')
prompt.append(api_prompts)
# merge the consecutive prompts assigned to the same role
new_prompt = PromptList([prompt[0]])
last_role = prompt[0]['role']
for item in prompt[1:]:
if item['role'] == last_role:
new_prompt[-1]['prompt'] += '\n' + item['prompt']
else:
last_role = item['role']
new_prompt.append(item)
prompt = new_prompt
if self.meta_template.get('begin', None):
prompt.insert(0, self.meta_template['begin'])
else:
# in case the model does not have any meta template
prompt = ''
last_sep = ''
for item in prompt_template:
if isinstance(item, dict) and set(['section', 'pos']) == set(
item.keys()):
continue
if isinstance(item, str):
if item:
prompt += last_sep + item
elif item.get('prompt', ''):
prompt += last_sep + item.get('prompt', '')
last_sep = '\n'
return prompt

    def _update_role_dict(self, prompts: Union[List, str]) -> Dict[str, Dict]:
"""Update the default role dict with the given prompts."""
role_dict = deepcopy(self.roles)
if isinstance(prompts, str):
return role_dict
elif isinstance(prompts, dict):
prompts = [prompts]
for prompt in prompts:
if isinstance(prompt, dict):
role = prompt['role']
if role not in self.roles:
role = prompt.get('fallback_role', None)
if not role:
print(f'{prompt} neither has an appropriate role nor '
'a fallback role.')
role_dict[role].update(prompt)
return role_dict

    def _split_rounds(
self, prompt_template: List[Union[str, Dict]],
single_round_template: List[Union[str, Dict]]) -> List[int]:
"""Split the prompt template into rounds, based on single round
        template.

        Return the index ranges of each round. Specifically,
prompt_template[res[i]:res[i+1]] represents the i-th round in the
template.
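
        Example:
            With a single round template of ``[HUMAN, BOT]`` and a dialogue
            of ``[HUMAN, BOT, HUMAN, BOT]``, the returned indices are
            ``[0, 2, 4]``, i.e. two rounds of two items each.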
"""
role_idxs = {
role_cfg['role']: i
for i, role_cfg in enumerate(single_round_template)
if not isinstance(role_cfg, str)
}
last_role_idx = -1
cutoff_idxs = [0]
for idx, template in enumerate(prompt_template):
if isinstance(template, str):
continue
role_idx = role_idxs.get(template['role'], None)
if role_idx is None:
try:
role_idx = role_idxs[template['fallback_role']]
except KeyError:
raise KeyError(f'{template} neither has an appropriate '
'role nor a fallback role.')
if role_idx <= last_role_idx:
cutoff_idxs.append(idx)
last_role_idx = role_idx
cutoff_idxs.append(len(prompt_template))
return cutoff_idxs

    def _prompt2api(self,
prompts: Union[List, str],
role_dict: Dict[str, Dict],
for_gen: bool = False) -> Tuple[List, bool]:
        """Convert the prompts to API-style prompts, given an updated
        role_dict.

        Args:
prompts (Union[List, str]): The prompts to be converted.
role_dict (Dict[str, Dict]): The updated role dict.
for_gen (bool): If True, the prompts will be converted for
generation tasks. The conversion stops before the first
role whose "generate" is set to True.

        Returns:
            Tuple[List, bool]: The converted prompts, and whether the
                follow-up conversion should proceed.
"""
cont = True
if isinstance(prompts, str):
return prompts, cont
elif isinstance(prompts, dict):
api_role, cont = self._role2api_role(prompts, role_dict, for_gen)
return api_role, cont
res = []
for prompt in prompts:
if isinstance(prompt, str):
raise TypeError('Mixing str without explicit role is not '
'allowed in API models!')
else:
api_role, cont = self._role2api_role(prompt, role_dict,
for_gen)
if api_role:
res.append(api_role)
if not cont:
break
return res, cont

    def _role2api_role(self,
role_prompt: Dict,
role_dict: Dict[str, Dict],
for_gen: bool = False) -> Tuple[Dict, bool]:
        """Convert a role prompt to an API-style dict, given an updated
        role_dict.

        Args:
role_prompt (Dict): The role prompt to be converted.
role_dict (Dict[str, Dict]): The updated role dict.
for_gen (bool): If True, the prompts will be converted for
generation tasks. The conversion stops before the first
role whose "generate" is set to True.

        Returns:
            Tuple[Dict, bool]: The converted prompt dict, and whether the
                follow-up conversion should proceed.
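
        Example:
            Assuming ``role_dict['HUMAN']`` has already been updated by
            ``_update_role_dict`` to ``{'role': 'HUMAN', 'api_role': 'HUMAN',
            'prompt': 'Hi'}``, the return value is
            ``({'role': 'HUMAN', 'prompt': 'Hi'}, True)``.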
"""
merged_prompt = role_dict.get(
role_prompt['role'],
role_dict.get(role_prompt.get('fallback_role')))
# res_api_prompt = dict(type='', )
if for_gen and merged_prompt.get('generate', False):
return None, False
res = {}
res['role'] = merged_prompt['api_role']
res['prompt'] = merged_prompt.get('begin', '')
res['prompt'] += merged_prompt.get('prompt', '')
res['prompt'] += merged_prompt.get('end', '')
return res, True


class TokenBucket:
    """A token bucket for rate limiting.

    Args:
        rate (float): The number of tokens added to the bucket per second,
            i.e. the allowed query rate.
        verbose (bool): Whether to log the real-time requests-per-minute
            rate. Defaults to False.
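
    Example:
        A rough usage sketch; ``call_remote_api`` is a hypothetical
        placeholder for the actual request::

            bucket = TokenBucket(rate=2)  # about two calls per second
            for _ in range(10):
                bucket.get_token()        # blocks until a token is available
                call_remote_api()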
"""

    def __init__(self, rate, verbose=False):
self._rate = rate
self._tokens = threading.Semaphore(0)
self.started = False
self._request_queue = Queue()
self.logger = get_logger()
self.verbose = verbose

    def _add_tokens(self):
"""Add tokens to the bucket."""
while True:
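            # Refill loop (runs in a daemon thread started by get_token):
            # release tokens over time, but keep at most `rate` outstanding,
            # capping the query rate at roughly `rate` per second.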
if self._tokens._value < self._rate:
self._tokens.release()
sleep(1 / self._rate)

    def get_token(self):
"""Get a token from the bucket."""
if not self.started:
self.started = True
threading.Thread(target=self._add_tokens, daemon=True).start()
self._tokens.acquire()
if self.verbose:
cur_time = time.time()
while not self._request_queue.empty():
if cur_time - self._request_queue.queue[0] > 60:
self._request_queue.get()
else:
break
self._request_queue.put(cur_time)
self.logger.info(f'Current RPM {self._request_queue.qsize()}.')