OpenCompass/opencompass/openicl/icl_inferencer/icl_base_inferencer.py

"""Basic Inferencer."""
import json
import os
from pathlib import Path
from typing import List, Optional

import numpy as np
from mmengine.dist import is_main_process
from torch.utils.data import DataLoader

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever


class BaseInferencer:
    """Base Inferencer class for all evaluation Inferencers.

    Attributes:
        model (:obj:`BaseModel`, optional): The model to run inference with.
        max_seq_len (:obj:`int`, optional): Maximum number of tokenized
            words allowed by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for the output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for the output
            `JSON` file.
        api_name (:obj:`str`, optional): Name of the API service.
        call_api (:obj:`bool`): If ``True``, an API for LM models will be
            used, determined by :obj:`api_name`.
    """
    model = None

    def __init__(
        self,
        model,
        max_seq_len: Optional[int] = None,
        batch_size: Optional[int] = 1,
        output_json_filepath: Optional[str] = './icl_inference_output',
        output_json_filename: Optional[str] = 'predictions',
        **kwargs,
    ) -> None:
        self.model = model
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.output_json_filepath = output_json_filepath
        self.output_json_filename = output_json_filename
        self.is_main_process = is_main_process()
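
        # Make sure the output directory exists before any result files are
        # written by the concrete inferencers.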
        if not os.path.exists(self.output_json_filepath):
            os.makedirs(self.output_json_filepath)

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        """Perform in-context inference given a retriever and optional
        templates.

        Args:
            retriever (:obj:`BaseRetriever`): An instance of a Retriever
                class that will be used to retrieve in-context examples.
            ice_template (:obj:`PromptTemplate`, optional): A template for
                generating the in-context examples prompt. Defaults to None.
            prompt_template (:obj:`PromptTemplate`, optional): A template for
                generating the final prompt. Defaults to None.
            output_json_filepath (:obj:`str`, optional): The file path to
                save the results as a `JSON` file. Defaults to None.
            output_json_filename (:obj:`str`, optional): The file name to
                save the results as a `JSON` file. Defaults to None.

        Raises:
            NotImplementedError: If the method is not implemented in the
                subclass.

        Returns:
            :obj:`List`: A list of strings, each representing the result of
            one inference.
        """
        raise NotImplementedError("Method hasn't been implemented yet")

    @staticmethod
    def get_dataloader(datalist: List[List], batch_size: int) -> DataLoader:
        """Return a dataloader of the input data list."""
        dataloader = DataLoader(datalist,
                                batch_size=batch_size,
                                collate_fn=lambda x: x)
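        # The identity collate_fn above keeps each batch as a plain Python
        # list (e.g. of prompt entries) instead of letting PyTorch try to
        # tensor-collate the items.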
        return dataloader


def dump_results_dict(results_dict, filename):
    with open(filename, 'w', encoding='utf-8') as json_file:
        json.dump(results_dict, json_file, indent=4, ensure_ascii=False)


class GenInferencerOutputHandler:
    origin_prompt_dict = {}
    output_dict = {}
    prediction_dict = {}
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_results(self, origin_prompt, prediction, idx):
        self.results_dict[str(idx)] = {
            'origin_prompt': origin_prompt,
            'prediction': prediction,
        }
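        # The dumped JSON therefore holds one entry per example index,
        # roughly:
        #   {"0": {"origin_prompt": "...", "prediction": "..."}, "1": {...}}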


class PPLInferencerOutputHandler:
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_ice(self, ice):
        for idx, example in enumerate(ice):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['in-context examples'] = example

    def save_predictions(self, predictions):
        for idx, prediction in enumerate(predictions):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['prediction'] = prediction

    def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
        if str(idx) not in self.results_dict.keys():
            self.results_dict[str(idx)] = {}
        if 'label: ' + str(label) not in self.results_dict[str(idx)].keys():
            self.results_dict[str(idx)]['label: ' + str(label)] = {}
        self.results_dict[str(idx)]['label: ' +
                                    str(label)]['testing input'] = input
        self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
        self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
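        # Each index then maps every candidate label to its prompt and
        # perplexity, e.g.
        #   {"0": {"label: A": {"testing input": "...", "prompt": "...",
        #                       "PPL": 3.21}, "label: B": {...}}}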

    def save_prompt_and_condprob(self, input, prompt, cond_prob, idx,
                                 choices):
        if str(idx) not in self.results_dict.keys():
            self.results_dict[str(idx)] = {}
        # TODO: in the single-token situation, the testing input is currently
        # always "yes"
        self.results_dict[str(idx)]['testing input'] = input
        self.results_dict[str(idx)]['prompt'] = prompt
        # TODO: the choices are hard-coded here
        self.results_dict[str(idx)]['choices'] = choices
        # To calculate AUC scores, store the conditional probabilities as the
        # prediction.
        self.results_dict[str(idx)]['prediction'] = cond_prob
        # Also store the predicted label (argmax over choices) in case it is
        # needed.
        self.results_dict[str(idx)]['pred_label'] = int(np.argmax(cond_prob))
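

# A minimal usage sketch (the variable names below are illustrative, not part
# of this module): a concrete inferencer typically fills a handler while
# iterating over its dataloader, then lets only the main process dump the
# JSON file.
#
#   handler = GenInferencerOutputHandler()
#   for idx, (prompt, prediction) in enumerate(zip(prompts, predictions)):
#       handler.save_results(prompt, prediction, idx)
#   if is_main_process():
#       handler.write_to_json(output_json_filepath, output_json_filename)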