This commit is contained in:
HAOCHENYE 2024-12-15 18:49:03 +08:00
parent aeded4c4db
commit aa48a2843d
7 changed files with 1356 additions and 2 deletions

View File

@@ -0,0 +1,369 @@
import asyncio
import contextlib
import multiprocessing
import os
import re
import threading
import weakref
from contextlib import contextmanager
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union, cast
import jieba
from opencompass.models.base_api import AsyncTokenBucket, BaseAPIModel
from opencompass.utils.prompt import PromptList
PromptType = Union[PromptList, str]
OPENAI_API_BASE = os.path.join(
os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/'),
'chat/completions')
class _APIModelState:
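"""Process-wide registry of per-endpoint API state.
Instances are deduplicated by ``name`` (api_base + model name) through
``__new__`` and weak references, so every model pointing at the same
endpoint shares one rate limiter and one concurrency budget.
"""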
_instance: Dict[str, weakref.ReferenceType["_APIModelState"]] = {}
_count: int
_concurrency: int
_locks = [threading.Lock(), multiprocessing.Lock()]
def __init__(self, *, name: str, concurrency: int, query_per_second: int = 1) -> None:
# `__new__` may hand back an already-initialized shared instance, so only
# set up the state once; each additional owner just registers itself.
if not hasattr(self, '_initialized'):
self._name = name
self._count = 0
self._max_concurrency = concurrency
self._token_bucket = AsyncTokenBucket(rate=query_per_second)
self._initialized = True
self._count += 1
# Split the per-endpoint concurrency budget evenly among all owners.
self._concurrency = max(1, self._max_concurrency // self._count)
@property
def concurrency(self) -> int:
# If update and concurrency are called simultaneously, the values
# returned here may be inaccurate, but the impact is likely minimal
return self._concurrency
async def acquire(self):
return await self._token_bucket.acquire()
@property
def rpm(self):
return self._token_bucket.rpm
@property
def name(self) -> str:
return self._name
@property
def count(self):
return self._count
@classmethod
def _cleanup(cls, ref: weakref.ReferenceType["_APIModelState"]):
# The weakref callback fires once the referent is gone, so `ref()` is
# already None here; drop the registry entry by matching the ref object.
with cls._lock():
for name, stored_ref in list(cls._instance.items()):
if stored_ref is ref:
cls._instance.pop(name)
break
def __new__(cls, name: str, *args, **kwargs) -> "_APIModelState":
with cls._lock():
if name not in cls._instance:
self = super().__new__(cls)
cls._instance[name] = weakref.ref(self, cls._cleanup)
return cls._instance[name]() # type: ignore
@classmethod
@contextmanager
def _lock(cls):
with contextlib.ExitStack() as stack:
[stack.enter_context(lock) for lock in cls._locks]
yield
class AsyncOpenAISDK(BaseAPIModel):
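"""Asynchronous OpenAI-compatible chat-completions model.
Requests are issued through the official ``AsyncOpenAI`` client; all
instances that target the same ``openai_api_base`` and ``path`` share a
``_APIModelState`` for rate limiting and concurrency control.
"""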
states: Dict[str, _APIModelState] = {}
def __init__(
self,
path: str = 'gpt-3.5-turbo',
max_seq_len: int | None = None, # type: ignore
query_per_second: int = 1,
retry: int = 2,
key: str = 'ENV',
org: str | List[str] | None = None,
meta_template: Dict | None = None,
openai_api_base: str = OPENAI_API_BASE,
openai_proxy_url: Optional[str] = None,
mode: Literal['none', 'front', 'mid', 'rear'] = 'none',
logprobs: bool | None = False,
top_logprobs: int | None = None,
temperature: float | None = None,
tokenizer_path: str | None = None,
extra_body: Dict | None = None,
max_completion_tokens: int = 16384,
verbose: bool = False,
concurrency: int = 64,
status_code_mappings: dict = {},
):
from openai import AsyncOpenAI
assert mode in ['none', 'front', 'mid', 'rear']
self.mode = mode
state_key = self._get_state_key(api_base=openai_api_base, model_name=path)
if state_key not in AsyncOpenAISDK.states:
AsyncOpenAISDK.states[state_key] = _APIModelState(
name=state_key,
concurrency=concurrency,
query_per_second=query_per_second,
)
self.state = AsyncOpenAISDK.states[state_key]
self.openai_client = AsyncOpenAI(base_url=openai_api_base, api_key=key)
if max_seq_len is None:
if '16k' in path:
max_seq_len = 16384
elif 'gpt-4' in path:
max_seq_len = 8192
elif 'gpt-3.5' in path:
max_seq_len = 4097
else:
max_seq_len = 32768
super().__init__(path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry,
verbose=verbose)
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.tokenizer_path = tokenizer_path
self.hf_tokenizer = None
self.extra_body = extra_body
self.max_completion_tokens = max_completion_tokens
self.temperature = temperature
self.openai_api_base = openai_api_base
self.concurrency = concurrency
self.status_code_mappings = status_code_mappings
if openai_proxy_url == 'ENV':
if 'OPENAI_PROXY_URL' not in os.environ:
raise ValueError('OPENAI_PROXY_URL is not set.')
self.proxy_url = os.getenv('OPENAI_PROXY_URL')
else:
self.proxy_url = openai_proxy_url
async def generate(self, # type: ignore
inputs: Iterable[PromptType],
max_out_len: int = 512,
temperature: float = 0.7,
**kwargs) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic. Defaults to 0.7.
Returns:
List[str]: A list of generated strings.
"""
if self.temperature is not None:
temperature = self.temperature
# TODO: This should be an AsyncGenerator once a real `AsyncInference` has been implemented
tasks_queue: List[asyncio.Future] = []
results_queue: List[Tuple[int, str]] = []
inputs_iter = enumerate(inputs)
data_stop = False
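# Scheduling loop: keep at most `concurrency` requests in flight, refill
# the task queue as requests complete, and collect (index, output) pairs
# so results can be restored to the original input order afterwards.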
while not (data_stop and not tasks_queue):
concurrency = self.state.concurrency
if tasks_queue:
done, pending = await asyncio.wait(tasks_queue, return_when=asyncio.FIRST_COMPLETED)
tasks_queue = list(pending)
for finished_task in done:
result: Tuple[int, str] = finished_task.result()
results_queue.append(result)
while not data_stop and len(tasks_queue) < concurrency:
try:
index, _input = next(inputs_iter)
except StopIteration:
data_stop = True
break
tasks_queue.append(
asyncio.create_task(
self._generate(
input=_input,
max_out_len=self.max_completion_tokens or max_out_len,
temperature=temperature,
index=index,
)
)
)
results_queue.sort()
return [item[1] for item in results_queue]
async def generate_from_template(self, templates: List[PromptType], # type: ignore
max_out_len: int, **kwargs):
"""Generate completion from a list of templates.
Args:
templates (List[PromptType]): A list of templates.
max_out_len (int): The maximum length of the output.
"""
inputs = self.parse_template(templates, mode='gen') # type: ignore
return await self.generate(inputs, max_out_len=max_out_len, **kwargs)
async def _generate(self, input: PromptList | str, max_out_len: int,
temperature: float, index: int) -> Tuple[int, str]:
from openai import APIStatusError, BadRequestError
assert isinstance(input, (str, PromptList))
# max num token for gpt-3.5-turbo is 4097
# Most models' token limits are above 32k
# will leave 100 tokens as prompt buffer, triggered if input is str
if isinstance(input, str) and self.mode != 'none':
context_window = self.max_seq_len
input = self.bin_trim(
input,
context_window - 100 - max_out_len,
cast(Literal['front', 'mid', 'rear'], self.mode),
)
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
else:
messages = []
for item in input:
msg = {'content': item['prompt']}
if item['role'] == 'HUMAN':
msg['role'] = 'user'
elif item['role'] == 'BOT':
msg['role'] = 'assistant'
elif item['role'] == 'SYSTEM':
msg['role'] = 'system'
messages.append(msg)
# Hold out 100 tokens due to potential errors in tiktoken calculation
# try:
# max_out_len = min(
# max_out_len,
# context_window - self.get_token_len(str(input)) - 100)
# except KeyError:
# max_out_len = max_out_len
# if max_out_len <= 0:
# return ''
num_retries = 0
while num_retries < self.retry:
await self.state.acquire()
query_data = dict(
model=self.path,
max_tokens=max_out_len,
n=1,
# `temperature` already reflects `self.temperature` when it is set
# (see `generate`), so pass the resolved value through.
temperature=temperature,
messages=messages,
extra_body=self.extra_body,
timeout=600,
)
try:
if self.verbose:
self.logger.info('Start calling OpenAI API')
responses = await self.openai_client.chat.completions.create(**query_data)
if self.verbose:
self.logger.info(
'Successfully get response from OpenAI API')
if self.verbose:
self.logger.info(responses)
if not responses.choices:
# Do not index an empty choice list; fall through to the
# retry logic below instead.
self.logger.error(
'Response is empty, it is an internal server error '
'from the API provider.')
else:
return index, responses.choices[0].message.content
except (BadRequestError, APIStatusError) as e:
# Handle BadRequest / APIStatus errors. `self.status_code_mappings` can map
# a status code to a canned return message to bypass API sensitivity blocks,
# e.g. status_code_mappings={400: 'Input data may contain inappropriate content.'}
status_code = e.status_code
if (status_code is not None
and status_code in self.status_code_mappings):
error_message = self.status_code_mappings[status_code]
self.logger.info(f'Status Code: {status_code},\n'
f'Original Error Message: {e},\n'
f'Return Message: {error_message} ')
return index, error_message
else:
self.logger.warning(f"Failed to get response for {e}, retry {num_retries}/{self.retry}")
except Exception as e:
self.logger.warning(f"Failed to get response for {e}, retry {num_retries}/{self.retry}")
num_retries += 1
raise RuntimeError('Calling OpenAI API failed after retrying for '
f'{self.retry} times. Check the logs for details.')
def _get_state_key(self, api_base: str, model_name: str):
return api_base + model_name
def bin_trim(self, prompt: str, num_token: int, mode: Literal['front', 'mid', 'rear']) -> str:
"""Trim the prompt to at most `num_token` tokens via binary search.
Args:
prompt (str): Input string.
num_token (int): The upper bound of token numbers.
mode (str): Which part to keep: 'front' keeps the tail, 'rear'
keeps the head, and 'mid' keeps both ends.
Returns:
str: The trimmed prompt.
"""
token_len = self.get_token_len(prompt)
if token_len <= num_token:
return prompt
pattern = re.compile(r'[\u4e00-\u9fa5]')
if pattern.search(prompt):
words = list(jieba.cut(prompt, cut_all=False))
sep = ''
else:
words = prompt.split(' ')
sep = ' '
l, r = 1, len(words)
while l + 2 < r:
mid = (l + r) // 2
if mode == 'front':
cur_prompt = sep.join(words[-mid:])
elif mode == 'mid':
cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
elif mode == 'rear':
cur_prompt = sep.join(words[:mid])
if self.get_token_len(cur_prompt) <= num_token:
l = mid  # noqa: E741
else:
r = mid
if mode == 'front':
prompt = sep.join(words[-l:])
elif mode == 'mid':
prompt = sep.join(words[:l]) + sep.join(words[-l:])
elif mode == 'rear':
prompt = sep.join(words[:l])
return prompt
return prompt

View File

@@ -4,6 +4,7 @@ import threading
import time
import warnings
from abc import abstractmethod
from collections import deque
from copy import deepcopy
from queue import Queue
from time import sleep
@@ -12,6 +13,8 @@ from typing import Dict, List, Optional, Tuple, Union
from opencompass.utils import get_logger
from opencompass.utils.prompt import PromptList
import asyncio
from .base import BaseModel
PromptType = Union[PromptList, str]
@@ -51,7 +54,7 @@ class BaseAPIModel(BaseModel):
self.retry = retry
self.query_per_second = query_per_second
self.token_bucket = TokenBucket(query_per_second, rpm_verbose)
self.template_parser = APITemplateParser(meta_template)  # type: ignore
self.logger = get_logger()
self.generation_kwargs = generation_kwargs
self.verbose = verbose
@@ -459,4 +462,69 @@ class TokenBucket:
else:
break
self._request_queue.put(cur_time)
self.logger.info(f"Current RPM {self._request_queue.qsize()}.")
class AsyncTokenBucket:
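"""Asyncio-friendly token bucket.
Tokens are refilled continuously at ``rate`` tokens per second up to a
burst capacity of ``rate * 60``; ``rpm`` reports how many requests were
made in the trailing 60-second window.
"""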
def __init__(self, rate: int = 1):
self._rate = rate
self._max_tokens = rate * 60
self._tokens: float = float(self._max_tokens)
self._last_refill_time: float | None = None
self._request_timestamps: deque[float] = deque()
self._max_window_size = 60
self._token_available = asyncio.Event()
async def _release(self) -> None:
if self._last_refill_time is None:
self._last_refill_time: float = time.monotonic()
now = time.monotonic()
elapsed = now - self._last_refill_time
tokens_to_add = elapsed * self._rate
self._tokens = min(self._max_tokens, self._tokens + tokens_to_add)
self._last_refill_time = now
async def acquire(self) -> bool:
while True:
await self._release()
if self._tokens >= 1:
self._tokens -= 1
now = time.monotonic()
self._request_timestamps.append(now)
while (
self._request_timestamps
and now - self._request_timestamps[0] > self._max_window_size
):
self._request_timestamps.popleft()
self._token_available.set()
return True
# Bucket is empty. Wait until another coroutine signals availability
# or until the next token should have been refilled; a plain
# `wait()` here could block every waiter forever.
self._token_available.clear()
try:
await asyncio.wait_for(self._token_available.wait(),
timeout=max(1.0 / self._rate, 0.05))
except asyncio.TimeoutError:
pass
@property
def rpm(self) -> int:
now = time.monotonic()
while (
self._request_timestamps
and now - self._request_timestamps[0] > self._max_window_size
):
self._request_timestamps.popleft()
return len(self._request_timestamps)
@property
def available_tokens(self) -> float:
return self._tokens
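# Minimal usage sketch (illustrative, not part of the original change):
#   bucket = AsyncTokenBucket(rate=2)   # ~2 requests per second
#   await bucket.acquire()              # waits until a token is available
#   print(bucket.rpm)                   # requests seen in the last 60 seconds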

View File

@@ -0,0 +1,397 @@
"""Chat Inferencer."""
import os
import os.path as osp
from typing import List, Optional, Union
import mmengine
from mmengine import is_list_of
from tqdm import tqdm
from opencompass.models import APITemplateParser as _APITemplateParser
from opencompass.models import BaseModel
from opencompass.models import LMTemplateParser as _LMTemplateParser
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils.prompt import PromptList
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict
logger = get_logger(__name__)
def promptlist_to_openai(prompt: Union[str, PromptList]):
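"""Convert an OpenCompass ``PromptList`` (or plain string) into a list of
OpenAI-style ``{'role': ..., 'content': ...}`` messages."""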
output = []
if isinstance(prompt, str):
return [dict(role='user', content=prompt)]
for item in prompt:
if 'section' in item:
continue
if isinstance(item, str) and item:
output.append(dict(role='user', content=item))
elif item['role'] == 'SYSTEM':
output.append(dict(role='system', content=item['prompt']))
elif item['role'] == 'HUMAN':
output.append(dict(role='user', content=item['prompt']))
elif item['role'] == 'BOT':
output.append(dict(role='assistant', content=item['prompt']))
return output
class LMTemplateParser:
"""LMTemplateParser accepts OpenAI format dialog inputs."""
def __init__(self, meta_template: Optional[dict] = None):
self.meta_template = meta_template
self.roles = {}
role_mapping = {
'SYSTEM': 'system',
'HUMAN': 'user',
'BOT': 'assistant',
}
if meta_template:
for item in meta_template.get('round', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
for item in meta_template.get('reserved_roles', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
def parse_template(self, chat: List[dict], mode='gen') -> Union[str, List[str]]:
if is_list_of(chat, list):
# Handle batch inputs
return [self.parse_template(item) for item in chat]
assert is_list_of(chat, dict)
prompt = ''
if self.roles:
for dialog in chat:
role_cfg = self.roles.get(dialog['role'], {})
prompt += (role_cfg.get('begin') or '')
prompt += (dialog.get('content') or '')
prompt += (role_cfg.get('end') or '')
prompt += (self.roles['assistant'].get('begin') or '')
else:
# in case the model does not have any meta template
last_sep = ''
for item in chat:
prompt += last_sep + (item.get('content') or '')
last_sep = '\n'
return prompt
class APITemplateParser:
"""APITemplateParser accepts OpenAI format dialog inputs."""
def __init__(self, meta_template: Optional[dict] = None):
self.meta_template = meta_template
self.roles = {}
role_mapping = {
'SYSTEM': 'system',
'HUMAN': 'user',
'BOT': 'assistant',
}
if meta_template:
for item in meta_template.get('round', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
for item in meta_template.get('reserved_roles', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
else:
self.roles = dict(
system=dict(api_role='SYSTEM'),
user=dict(api_role='HUMAN'),
assistant=dict(api_role='BOT', generate=True),
)
def parse_template(self, chat: List[dict], mode='gen') -> Union[PromptList, List[PromptList]]:
if is_list_of(chat, list):
# Handle batch inputs
return [self.parse_template(item) for item in chat]
assert is_list_of(chat, dict)
prompt = []
for dialog in chat:
if dialog['role'] in self.roles:
role = self.roles[dialog['role']]['api_role']
else:
role = dialog['role']
prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
return PromptList(prompt)
class ChatOutputHandler:
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, osp.join(save_dir, filename))
def save_results(self,
origin_prompt: list,
prediction: str,
idx: int,
gold: Optional[str] = None):
result_dict = {}
if gold:
result_dict['gold'] = gold
result_dict.update({
'prediction': prediction,
'origin_prompt': origin_prompt,
})
self.results_dict[str(idx)] = result_dict
def save_multiround_results(self,
origin_prompt: list,
prediction: str,
idx: int,
gold: Optional[str] = None):
result_dict = self.results_dict.get(str(idx), {
'gold': [],
'prediction': [],
'origin_prompt': [],
})
result_dict['gold'].append(gold)
result_dict['prediction'].append(prediction)
result_dict['origin_prompt'].append(origin_prompt)
self.results_dict[str(idx)] = result_dict
@ICL_INFERENCERS.register_module()
class AsyncChatInferencer(BaseInferencer):
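"""Asynchronous counterpart of ``ChatInferencer``.
Generation calls are awaited rather than blocking, so multiple inference
coroutines can be scheduled concurrently in a single event loop (see
``OpenICLAsyncInferTask``).
"""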
HandlerType = ChatOutputHandler
def __init__(
self,
model,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
infer_mode: str = 'last',
max_out_len: int = 512,
**kwargs) -> None:
super().__init__(
model=model,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
assert infer_mode in ['last', 'every', 'every_with_gt']
self.infer_mode = infer_mode
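# infer_mode:
#   'last'          - generate only the final assistant turn.
#   'every'         - generate every assistant turn, feeding the model's own
#                     outputs back into the chat history.
#   'every_with_gt' - generate every assistant turn while keeping the
#                     ground-truth assistant answers in the history.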
self.model: BaseModel
self._set_meta_template(self.model)
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
self.dialogue_mode = False
self.max_out_len = max_out_len
def _set_meta_template(self, model):
origin = model.template_parser
if isinstance(origin, _APITemplateParser):
model.template_parser = APITemplateParser(origin.meta_template)
if isinstance(origin, _LMTemplateParser):
model.template_parser = LMTemplateParser(origin.meta_template)
async def inference(self, # type: ignore
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> dict:
# 1. Preparation for output logs
output_handler = self.HandlerType()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
chat_list = self.get_chat_list(
ice_idx_list,
retriever,
prompt_template=prompt_template,
)
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
try:
tmp_result_dict = mmengine.load(tmp_json_filepath)
except Exception:
pass
else:
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(chat_list[index:], batch_size=1)
# 5. Inference for prompts in each batch
logger.debug('Starting inference process...')
for datum in tqdm(dataloader, disable=not self.is_main_process):
chat = datum[0]
if self.infer_mode == 'last':
await self.infer_last(chat, index, output_handler)
elif self.infer_mode == 'every':
await self.infer_every(chat, index, output_handler)
elif self.infer_mode == 'every_with_gt':
await self.infer_every_with_gt(chat, index, output_handler)
index += 1
# Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return output_handler.results_dict
def get_chat_list(self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
input_columns = retriever.dataset_reader.input_columns
output_column = retriever.dataset_reader.output_column
def chat_from_entry(entry):
if prompt_template is None and len(input_columns) == 1:
# Directly use the input column as the user input
user = entry.get(input_columns[0])
assistant = entry.get(output_column, '')
return [
dict(role='user', content=user),
dict(role='assistant', content=assistant),
]
elif prompt_template is not None:
# Use prompt template to generate chat history
chat = promptlist_to_openai(
prompt_template.generate_item(entry))
gold = entry.get(output_column, '')
if chat[-1]['role'] != 'assistant':
chat.append(dict(role='assistant', content=gold))
return chat
else:
raise ValueError('Cannot construct chat from the given entry.')
for idx, ice_idx in enumerate(ice_idx_list):
# NOTE: The in-context examples won't be used by now.
item = {
k: v
for k, v in retriever.test_ds[idx].items()
if k in input_columns or k == output_column
}
if all(isinstance(value, str) for value in item.values()):
# Every column is a single string
chat = chat_from_entry(item)
elif all(is_list_of(value, str) for value in item.values()):
# Every column is a list of string for multi-round chat
entries = [dict(zip(item, v)) for v in zip(*item.values())]
chat = sum((chat_from_entry(entry) for entry in entries), [])
elif len(input_columns) == 1 and is_list_of(
item[input_columns[0]], dict):
# Single input column and it's already a chat.
chat = item[input_columns[0]]
elif 'dialogue' in input_columns:
chat = item['dialogue']
self.dialogue_mode = True
else:
raise ValueError('Cannot construct chat from the dataset.')
prompt_list.append(chat)
return prompt_list
async def infer_last(self, chat: List[dict], index: int, output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
history = chat[:assistant_indices[-1]]
output = (await self.model.generate_from_template(
[history], max_out_len=self.max_out_len))[0]
output_handler.save_results(
origin_prompt=history,
prediction=output,
idx=index,
gold=chat[assistant_indices[-1]]['content'],
)
async def infer_every(self, chat: List[dict], index: int, output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
index_copy = index
for i in assistant_indices:
history = chat[:i]
output = (await self.model.generate_from_template(
[history], max_out_len=self.max_out_len))[0]
chat[i]['content'] = output
if not self.dialogue_mode:
output_handler.save_multiround_results(
origin_prompt=history[-1]['content'],
prediction=output,
idx=index,
gold=chat[i]['content'],
)
# index += 1
if self.dialogue_mode:
# dialogue mode for subjective evaluation
assert len(chat) % 2 == 0
round_num = int(len(chat) / 2)
preds_list = []
for i in range(round_num):
temp_dict = {
'round': i + 1,
'user': chat[i * 2]['content'],
'assistant': chat[i * 2 + 1]['content']
}
preds_list.append(temp_dict)
output_handler.save_results(
origin_prompt=None,
prediction=preds_list,
idx=index_copy,
gold=None,
)
async def infer_every_with_gt(self, chat: List[dict], index: int,
output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
for i in assistant_indices:
history = chat[:i]
output = (await self.model.generate_from_template(
[history], max_out_len=self.max_out_len))[0]
output_handler.save_multiround_results(
origin_prompt=history[-1]['content'],
prediction=output,
idx=index,
gold=chat[i]['content'],
)
index += 1

View File

@@ -0,0 +1,239 @@
"""Direct Generation Inferencer."""
import inspect
import json
import os
import os.path as osp
import time
from typing import List, Optional
import mmengine
import torch
from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils import batched
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class AsyncGenInferencer(BaseInferencer):
"""Generation Inferencer class to directly evaluate by generation.
Attributes:
model (:obj:`BaseModelWrapper`, optional): The module to inference.
max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
allowed by the LM.
min_out_len (:obj:`int`, optional): Minimum number of generated tokens
by the LM
batch_size (:obj:`int`, optional): Batch size for the
:obj:`DataLoader`.
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
gen_field_replace_token (:obj:`str`, optional): Used to replace the
generation field token when generating prompts.
save_every (:obj:`int`, optional): Save intermediate results every
`save_every` iters. Defaults to 1.
generation_kwargs (:obj:`Dict`, optional): Parameters for the
:obj:`model.generate()` method.
"""
def __init__(
self,
model: BaseModel,
max_out_len: int,
stopping_criteria: List[str] = [],
max_seq_len: Optional[int] = None,
min_out_len: Optional[int] = None,
batch_size: Optional[int] = 1,
gen_field_replace_token: Optional[str] = '',
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
**kwargs) -> None:
super().__init__(
model=model,
max_seq_len=max_seq_len,
batch_size=batch_size,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
self.min_out_len = min_out_len
self.stopping_criteria = stopping_criteria
self.dump_timer = kwargs.get('dump_timer', False)
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
async def inference(self, # type: ignore
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = GenInferencerOutputHandler()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
ice_idx_list,
retriever,
self.gen_field_replace_token,
max_seq_len=self.max_seq_len,
ice_template=ice_template,
prompt_template=prompt_template)
# 3.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
if ds_reader.output_column:
gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
prompt_list = list(zip(prompt_list, gold_ans))
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
try:
tmp_result_dict = mmengine.load(tmp_json_filepath)
except Exception:
pass
else:
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
logger.debug('Starting build dataloader')
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
# 5. Inference for prompts in each batch
logger.debug('Starting inference process...')
start_time_stamp = time.time()
num_sample = 0
# TODO: the batched dataloader should be replaced with async fetching
for datum in dataloader:
if ds_reader.output_column:
entry, golds = list(zip(*datum))
else:
entry = datum
golds = [None for _ in range(len(entry))]
# 5-1. Inference with local model
extra_gen_kwargs = {}
sig = inspect.signature(self.model.generate)
if 'stopping_criteria' in sig.parameters:
extra_gen_kwargs['stopping_criteria'] = self.stopping_criteria
if 'min_out_len' in sig.parameters:
extra_gen_kwargs['min_out_len'] = self.min_out_len
with torch.no_grad():
parsed_entries = self.model.parse_template(entry, mode='gen')
results = await self.model.generate_from_template(
entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
generated = results
num_return_sequences = getattr(self.model, 'generation_kwargs',
{}).get('num_return_sequences', 1)
# 5-3. Save current output
for prompt, prediction, gold in zip(
parsed_entries, batched(generated, num_return_sequences),
golds):
if num_return_sequences == 1:
prediction = prediction[0]
output_handler.save_results(prompt,
prediction,
index,
gold=gold)
index = index + 1
# 5-4. Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
num_sample += len(datum)
end_time_stamp = time.time()
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
if self.dump_timer and self.is_main_process:
timer_filepath = os.path.join(output_json_filepath, 'timer',
'time.jsonl')
os.makedirs(os.path.dirname(timer_filepath), exist_ok=True)
time_dict = {
'dataset_name': output_json_filename.removesuffix('.json'),
'time': end_time_stamp - start_time_stamp,
'num_sample': num_sample
}
with open(timer_filepath, 'a') as f:
f.write(json.dumps(time_dict) + '\n')
return [
sample['prediction']
for sample in output_handler.results_dict.values()
]
def get_generation_prompt_list_from_retriever_indices(
self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
gen_field_replace_token: str,
max_seq_len: Optional[int] = None,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
for idx, ice_idx in enumerate(ice_idx_list):
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
if max_seq_len is not None:
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
ice_idx = ice_idx[:-1]
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
prompt_list.append(prompt)
return prompt_list

View File

@@ -0,0 +1,112 @@
from math import prod
import os
import os.path as osp
import re
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from threading import Lock
from typing import Any, Dict, List, Tuple
import mmengine
import numpy as np
from mmengine.config import ConfigDict
from mmengine.device import is_npu_available
from tqdm import tqdm
from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger, model_abbr_from_cfg
from .base import BaseRunner
from typing import TypedDict, Optional
from multiprocessing.managers import Namespace
import threading
import uuid
import enum
import signal
from enum import IntEnum
import asyncio
import traceback
class Status(IntEnum):
SUCCESS = 0
FAILED = -1
INTERRUPT = signal.SIGINT
@RUNNERS.register_module()
class AsyncRunner(BaseRunner):
"""Local runner. Start tasks by local python.
Args:
task (ConfigDict): Task type config.
max_num_workers (int): Max number of workers to run in parallel.
Defaults to 16.
max_workers_per_gpu (int): Max number of workers to run for one GPU.
Defaults to 1.
debug (bool): Whether to run in debug mode.
lark_bot_url (str): Lark bot url.
"""
# These is a fake typehint
def __init__(self,
task: ConfigDict,
debug: bool = False,
*,
max_num_workers: int = 16,
keep_tmp_file: bool = False,
**kwargs):
super().__init__(task=task, debug=debug)
self.max_num_workers = max_num_workers
self.keep_tmp_file = keep_tmp_file
logger = get_logger()
for k, v in kwargs.items():
logger.warning(f'Ignored argument in `AsyncRunner`: {k}={v}')
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, Status]]: # type: ignore
"""Launch multiple tasks.
Args:
tasks (list[dict]): A list of task configs, usually generated by
Partitioner.
Returns:
list[tuple[str, int]]: A list of (task name, exit code).
"""
from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask
if not tasks:
return [("", Status.SUCCESS)]
assert len(tasks) == 1, 'Task num must be 1 for `AsyncRunner`'
task_cfg = tasks[0]
task: OpenICLAsyncInferTask = TASKS.build(dict(cfg=task_cfg, type=self.task_cfg['type']))
task_name = task.name
# No shell command is built; the task runs in-process via asyncio.
mmengine.mkdir_or_exist('tmp/')
try:
asyncio.run(task.run())
except KeyboardInterrupt:
return [(task_name, Status.INTERRUPT)]
except Exception:
traceback.print_exc()
return [(task_name, Status.FAILED)]
else:
return [(task_name, Status.SUCCESS)]
def __call__(self, tasks: List[Dict[str, Any]]):
"""Launch multiple tasks and summarize the results.
Args:
tasks (list[dict]): A list of task configs, usually generated by
Partitioner.
"""
status = self.launch(tasks)
status_list = list(status) # change into list format
self.summarize(status_list)
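# Illustrative config wiring for this runner (hypothetical values):
#   infer = dict(
#       runner=dict(type='AsyncRunner',
#                   task=dict(type='OpenICLAsyncInferTask'),
#                   max_num_workers=16),
#   )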

View File

@@ -1,3 +1,4 @@
from .openicl_attack import *  # noqa: F401, F403
from .openicl_eval import *  # noqa: F401, F403
from .openicl_infer import *  # noqa: F401, F403
from .openicl_async_task import * # noqa: F401, F403

View File

@@ -0,0 +1,168 @@
import argparse
import os
import os.path as osp
import random
import sys
import time
from typing import Any
from tqdm.asyncio import tqdm
from mmengine.config import Config, ConfigDict
import inspect
from mmengine.utils import mkdir_or_exist
from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
get_infer_output_path, get_logger,
task_abbr_from_cfg)
from opencompass.openicl.icl_inferencer.icl_gen_async_inferencer import AsyncGenInferencer
from opencompass.openicl.icl_inferencer.icl_chat_async_inferencer import AsyncChatInferencer
from opencompass.openicl.icl_inferencer import GenInferencer, ChatInferencer
from concurrent.futures import ThreadPoolExecutor
import asyncio
import resource
from more_itertools import consume
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
# Raise the soft open-file limit for many concurrent requests, capped at the hard limit.
target_soft = 8192 if hard == resource.RLIM_INFINITY else min(8192, hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, target_soft), hard))
@TASKS.register_module()
class OpenICLAsyncInferTask(BaseTask):
"""OpenICL Inference Task.
This task is used to run the inference process.
"""
name_prefix = 'OpenICLInfer'
log_subdir = 'logs/infer'
output_subdir = 'predictions'
def __init__(self, cfg: ConfigDict):
super().__init__(cfg)
run_cfg = self.model_cfgs[0].get('run_cfg', {})
self.nproc = run_cfg.get('nproc_per_worker', 16)
def get_command(self, cfg_path, template) -> str:
# TODO: `AsyncRunner` currently runs the task in-process, so no shell
# command is built yet.
raise NotImplementedError()
async def run(self): # type: ignore
_dataset_cfgs = []
infer_cfgs = []
sub_cfgs = []
datasets = []
model_cfgs = []
for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
self.max_out_len = model_cfg.get('max_out_len', None)
self.batch_size = model_cfg.get('batch_size', None)
self.min_out_len = model_cfg.get('min_out_len', None)
for dataset_cfg in dataset_cfgs:
self.dataset_cfg = dataset_cfg
out_path = get_infer_output_path(
model_cfg, dataset_cfg,
osp.join(self.work_dir, 'predictions'))
if osp.exists(out_path):
continue
_dataset_cfgs.append(dataset_cfg)
datasets.append(build_dataset_from_cfg(dataset_cfg))
infer_cfgs.append(dataset_cfg['infer_cfg'])
model_cfgs.append(model_cfg)
sub_cfg = {
'models': [model_cfg],
'datasets': [[dataset_cfg]],
}
sub_cfgs.append(sub_cfg)
tasks = []
args = list(zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs))
for arg in tqdm(
args,
total=len(args),
desc=f"Starting building tasks..."
):
tasks.append(asyncio.create_task(self._inference(*arg)))
bar = tqdm(desc="Inferencing...", total=len(tasks))
bar.refresh()
while tasks:
done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for _ in done:
bar.update()
bar.refresh()
# TODO: Needs a debug mode
# for arg in zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs):
# await self._inference(*arg)
async def _inference(self, dataset_cfg, infer_cfg, dataset, model_cfg, sub_cfg):
model = build_model_from_cfg(model_cfg)
assert hasattr(infer_cfg, 'ice_template') or hasattr(infer_cfg, 'prompt_template'), \
'`ice_template` and `prompt_template` cannot both be missing.'  # noqa: E501
infer_kwargs: dict = {}
if hasattr(infer_cfg, 'ice_template'):
ice_template = ICL_PROMPT_TEMPLATES.build(
infer_cfg['ice_template'])
infer_kwargs['ice_template'] = ice_template
if hasattr(infer_cfg, 'prompt_template'):
prompt_template = ICL_PROMPT_TEMPLATES.build(
infer_cfg['prompt_template'])
infer_kwargs['prompt_template'] = prompt_template
retriever_cfg = infer_cfg['retriever'].copy()
retriever_cfg['dataset'] = dataset
retriever = ICL_RETRIEVERS.build(retriever_cfg)
# Set the inferencer's default values according to the model's config.
inferencer_cfg: dict = infer_cfg['inferencer']
inferencer_cfg['model'] = model
inferencer_cfg['max_seq_len'] = model_cfg.get('max_seq_len')
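# Existing configs declare the synchronous inferencers; swap them here for
# their awaitable counterparts so the coroutine-based task can drive them.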
infer_type = inferencer_cfg["type"]
if inspect.isclass(infer_type):
infer_name = infer_type.__name__
else:
infer_name = infer_type
if infer_name.split(".")[-1] == "ChatInferencer":
inferencer_cfg["type"] = AsyncChatInferencer
elif infer_name.split(".")[-1] == "GenInferencer":
inferencer_cfg["type"] = AsyncGenInferencer
inferencer_cfg.setdefault('max_out_len', self.max_out_len)
inferencer_cfg.setdefault('min_out_len', self.min_out_len)
inferencer_cfg.setdefault('batch_size', self.batch_size)
inferencer = ICL_INFERENCERS.build(inferencer_cfg)
out_path = get_infer_output_path(
model_cfg, dataset_cfg,
osp.join(self.work_dir, 'predictions'))
out_dir, out_file = osp.split(out_path)
mkdir_or_exist(out_dir)
infer_kwargs['output_json_filepath'] = out_dir
infer_kwargs['output_json_filename'] = out_file
await inferencer.inference(retriever, **infer_kwargs)
def parse_args():
parser = argparse.ArgumentParser(description='Model Inferencer')
parser.add_argument('config', help='Config file path')
args = parser.parse_args()
return args
if __name__ == '__main__':
# TODO:
raise NotImplementedError()