"""Chat Inferencer."""
|
|
import os
|
|
import os.path as osp
|
|
from typing import List, Optional, Union
|
|
|
|
import mmengine
|
|
from mmengine import is_list_of
|
|
from tqdm import tqdm
|
|
|
|
from opencompass.models import APITemplateParser as _APITemplateParser
|
|
from opencompass.models import BaseModel
|
|
from opencompass.models import LMTemplateParser as _LMTemplateParser
|
|
from opencompass.registry import ICL_INFERENCERS
|
|
from opencompass.utils.prompt import PromptList
|
|
|
|
from ..icl_prompt_template import PromptTemplate
|
|
from ..icl_retriever import BaseRetriever
|
|
from ..utils.logging import get_logger
|
|
from .icl_base_inferencer import BaseInferencer, dump_results_dict
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def promptlist_to_openai(prompt: Union[str, PromptList]):
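    """Convert a plain string or an OpenCompass ``PromptList`` into the
    OpenAI chat format (a list of ``{'role': ..., 'content': ...}`` dicts).

    ``SYSTEM``/``HUMAN``/``BOT`` entries map to ``system``/``user``/
    ``assistant``; section markers are skipped.
    """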
    output = []
    if isinstance(prompt, str):
        return [dict(role='user', content=prompt)]

    for item in prompt:
        if 'section' in item:
            continue
        if isinstance(item, str) and item:
            output.append(dict(role='user', content=item))
        elif item['role'] == 'SYSTEM':
            output.append(dict(role='system', content=item['prompt']))
        elif item['role'] == 'HUMAN':
            output.append(dict(role='user', content=item['prompt']))
        elif item['role'] == 'BOT':
            output.append(dict(role='assistant', content=item['prompt']))
    return output


class LMTemplateParser:
    """LMTemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()

    def parse_template(self, chat: List[dict], mode='gen') -> str:
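        """Render an OpenAI-format chat (or a batch of chats) into a single
        prompt string, using the ``begin``/``end`` markers from the meta
        template when one is configured, and newline-joined contents
        otherwise.
        """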
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = ''
        if self.roles:
            for dialog in chat:
                role_cfg = self.roles.get(dialog['role'], {})
                prompt += (role_cfg.get('begin') or '')
                prompt += (dialog.get('content') or '')
                prompt += (role_cfg.get('end') or '')
            prompt += (self.roles['assistant'].get('begin') or '')
        else:
            # in case the model does not have any meta template
            last_sep = ''
            for item in chat:
                prompt += last_sep + (item.get('content') or '')
                last_sep = '\n'
        return prompt


class APITemplateParser:
    """APITemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
        else:
            self.roles = dict(
                system=dict(api_role='SYSTEM'),
                user=dict(api_role='HUMAN'),
                assistant=dict(api_role='BOT', generate=True),
            )

    def parse_template(self, chat: List[dict], mode='gen') -> str:
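        """Map an OpenAI-format chat (or a batch of chats) onto OpenCompass
        API roles and return it as a ``PromptList`` suitable for API models.
        """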
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = []
        for dialog in chat:
            if dialog['role'] in self.roles:
                role = self.roles[dialog['role']]['api_role']
            else:
                role = dialog['role']
            prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
        return PromptList(prompt)


class ChatOutputHandler:
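    """Collect chat inference results and dump them to a JSON file."""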

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, osp.join(save_dir, filename))

    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     idx: int,
                     gold: str = None):
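        """Record a single-round result under index ``idx``."""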
        result_dict = {}
        if gold:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
        })
        self.results_dict[str(idx)] = result_dict

    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                idx: int,
                                gold: str = None):
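        """Append one round of a multi-round result to the entry at ``idx``,
        creating the entry if it does not exist yet.
        """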
        result_dict = self.results_dict.get(str(idx), {
            'gold': [],
            'prediction': [],
            'origin_prompt': [],
        })
        result_dict['gold'].append(gold)
        result_dict['prediction'].append(prediction)
        result_dict['origin_prompt'].append(origin_prompt)
        self.results_dict[str(idx)] = result_dict


@ICL_INFERENCERS.register_module()
class AsyncChatInferencer(BaseInferencer):
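    """Asynchronous chat-style inferencer.

    Depending on ``infer_mode``, it generates a reply for the last assistant
    turn only (``'last'``), for every assistant turn conditioned on the
    model's own previous replies (``'every'``), or for every assistant turn
    conditioned on the ground-truth history (``'every_with_gt'``).

    Example (illustrative sketch; the ``model`` and ``retriever`` objects are
    assumed to be built elsewhere from an OpenCompass config):

        inferencer = AsyncChatInferencer(model, infer_mode='every')
        results = await inferencer.inference(retriever)
    """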
    HandlerType = ChatOutputHandler

    def __init__(
            self,
            model,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            infer_mode: str = 'last',
            max_out_len: int = 512,
            **kwargs) -> None:
        super().__init__(
            model=model,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )
        assert infer_mode in ['last', 'every', 'every_with_gt']
        self.infer_mode = infer_mode
        self.model: BaseModel
        self._set_meta_template(self.model)

        if self.model.is_api and save_every is None:
            save_every = 1
        self.save_every = save_every
        self.dialogue_mode = False
        self.max_out_len = max_out_len

    def _set_meta_template(self, model):
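        """Swap the model's template parser for the OpenAI-format-aware
        parsers defined in this module, keeping its meta template.
        """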
        origin = model.template_parser
        if isinstance(origin, _APITemplateParser):
            model.template_parser = APITemplateParser(origin.meta_template)
        if isinstance(origin, _LMTemplateParser):
            model.template_parser = LMTemplateParser(origin.meta_template)

    async def inference(self,  # type: ignore
                        retriever: BaseRetriever,
                        ice_template: Optional[PromptTemplate] = None,
                        prompt_template: Optional[PromptTemplate] = None,
                        output_json_filepath: Optional[str] = None,
                        output_json_filename: Optional[str] = None) -> dict:
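        """Run chat inference over the retriever's test set and return the
        accumulated results dict. Intermediate results are periodically
        written to a temporary JSON file so interrupted runs can resume.
        """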
        # 1. Preparation for output logs
        output_handler = self.HandlerType()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        chat_list = self.get_chat_list(
            ice_idx_list,
            retriever,
            prompt_template=prompt_template,
        )

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        dataloader = self.get_dataloader(chat_list[index:], batch_size=1)

        # 5. Inference for prompts in each batch
        logger.debug('Starting inference process...')
        for datum in tqdm(dataloader, disable=not self.is_main_process):
            chat = datum[0]
            if self.infer_mode == 'last':
                await self.infer_last(chat, index, output_handler)
            elif self.infer_mode == 'every':
                await self.infer_every(chat, index, output_handler)
            elif self.infer_mode == 'every_with_gt':
                await self.infer_every_with_gt(chat, index, output_handler)
            index += 1

            # Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return output_handler.results_dict

    def get_chat_list(self,
                      ice_idx_list: List[List[int]],
                      retriever: BaseRetriever,
                      prompt_template: Optional[PromptTemplate] = None):
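        """Build a chat (list of OpenAI-format messages) for every test
        example. In-context example indices are accepted but currently
        unused.
        """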
        prompt_list = []
        input_columns = retriever.dataset_reader.input_columns
        output_column = retriever.dataset_reader.output_column

        def chat_from_entry(entry):
            if prompt_template is None and len(input_columns) == 1:
                # Directly use the input column as the user input
                user = entry.get(input_columns[0])
                assistant = entry.get(output_column, '')
                return [
                    dict(role='user', content=user),
                    dict(role='assistant', content=assistant),
                ]
            elif prompt_template is not None:
                # Use prompt template to generate chat history
                chat = promptlist_to_openai(
                    prompt_template.generate_item(entry))
                gold = entry.get(output_column, '')
                if chat[-1]['role'] != 'assistant':
                    chat.append(dict(role='assistant', content=gold))
                return chat
            else:
                raise ValueError('A prompt template is required when there '
                                 'are multiple input columns.')

        for idx, ice_idx in enumerate(ice_idx_list):
            # NOTE: The in-context examples won't be used by now.

            item = {
                k: v
                for k, v in retriever.test_ds[idx].items()
                if k in input_columns or k == output_column
            }
            if all(isinstance(value, str) for value in item.values()):
                # Every column is a single string
                chat = chat_from_entry(item)
            elif all(is_list_of(value, str) for value in item.values()):
                # Every column is a list of strings for multi-round chat
                entries = [dict(zip(item, v)) for v in zip(*item.values())]
                chat = sum((chat_from_entry(entry) for entry in entries), [])
            elif len(input_columns) == 1 and is_list_of(
                    item[input_columns[0]], dict):
                # Single input column and it's already a chat.
                chat = item[input_columns[0]]
            elif 'dialogue' in input_columns:
                chat = item['dialogue']
                self.dialogue_mode = True
            else:
                raise ValueError('Cannot construct chat from the dataset.')

            prompt_list.append(chat)
        return prompt_list

    async def infer_last(self, chat: List[dict], index: int, output_handler):
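        """Generate a reply only for the final assistant turn, using all
        preceding turns as the history.
        """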
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        history = chat[:assistant_indices[-1]]
        output = (await self.model.generate_from_template(
            [history], max_out_len=self.max_out_len))[0]
        output_handler.save_results(
            origin_prompt=history,
            prediction=output,
            idx=index,
            gold=chat[assistant_indices[-1]]['content'],
        )

    async def infer_every(self, chat: List[dict], index: int, output_handler):
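        """Generate a reply for every assistant turn; each turn is
        conditioned on the model's own previous replies. In dialogue mode,
        the per-round predictions are stored as a single list for subjective
        evaluation.
        """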
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]
        index_copy = index

        for i in assistant_indices:
            history = chat[:i]
            output = (await self.model.generate_from_template(
                [history], max_out_len=self.max_out_len))[0]
            # Keep the reference answer before it is overwritten, so the
            # saved gold is not simply the prediction itself.
            gold = chat[i]['content']
            chat[i]['content'] = output
            if not self.dialogue_mode:
                output_handler.save_multiround_results(
                    origin_prompt=history[-1]['content'],
                    prediction=output,
                    idx=index,
                    gold=gold,
                )
                # index += 1
        if self.dialogue_mode:
            # dialogue mode for subjective evaluation
            assert len(chat) % 2 == 0
            round_num = int(len(chat) / 2)
            preds_list = []
            for i in range(round_num):
                temp_dict = {
                    'round': i + 1,
                    'user': chat[i * 2]['content'],
                    'assistant': chat[i * 2 + 1]['content']
                }
                preds_list.append(temp_dict)
            output_handler.save_results(
                origin_prompt=None,
                prediction=preds_list,
                idx=index_copy,
                gold=None,
            )

    async def infer_every_with_gt(self, chat: List[dict], index: int,
                                  output_handler):
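        """Generate a reply for every assistant turn; each turn is
        conditioned on the ground-truth history rather than on previous
        predictions.
        """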
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        for i in assistant_indices:
            history = chat[:i]
            output = (await self.model.generate_from_template(
                [history], max_out_len=self.max_out_len))[0]
            output_handler.save_multiround_results(
                origin_prompt=history[-1]['content'],
                prediction=output,
                idx=index,
                gold=chat[i]['content'],
            )
            index += 1