2023-07-04 21:34:55 +08:00
|
|
|
"""PPL Inferencer."""
|
|
|
|
|
|
|
|
import os
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from tqdm import trange
|
|
|
|
|
|
|
|
from opencompass.models.base import BaseModel
|
|
|
|
from opencompass.registry import ICL_INFERENCERS
|
|
|
|
|
|
|
|
from ..icl_prompt_template import PromptTemplate
|
|
|
|
from ..icl_retriever import BaseRetriever
|
|
|
|
from ..utils import get_logger
|
|
|
|
from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
|
|
|
|
|
|
|
|
# Module-level logger; PPLInferencer.inference reports per-label progress
# through it.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
@ICL_INFERENCERS.register_module()
class PPLInferencer(BaseInferencer):
    """PPL Inferencer class to evaluate by perplexity.

    For every test sample one prompt per candidate label is built and scored
    with the model; the label whose prompt obtains the lowest score is taken
    as the prediction.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
            the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        labels (:obj:`List`, optional): A list of labels for all classes.
            If ``None``, labels are obtained from the retriever.
        fix_id_list (:obj:`List[int]`, optional): If non-empty, passed to
            ``retriever.retrieve`` instead of calling it without arguments.
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            labels: Optional[List] = None,
            fix_id_list: Optional[List[int]] = None,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.labels = labels
        self.fix_id_list = fix_id_list

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None,
                  normalizing_str: Optional[str] = None) -> List:
        """Predict a label for every test sample by comparing PPL scores.

        Args:
            retriever: Supplies in-context example indices, the candidate
                labels (when ``self.labels`` is ``None``) and the per-label
                prompts.
            ice_template: Template used to render in-context examples.
            prompt_template: Template used to render the label prompts.
            output_json_filepath: Output directory; falls back to the value
                given at construction time.
            output_json_filename: Output file name; falls back to the value
                given at construction time.
            normalizing_str: If set, every prompt is split at the template's
                ``sep_token`` into context and answer. The score becomes
                PPL(context + answer, with the context token length passed as
                ``mask_length``) minus PPL(normalizing_str + answer, with the
                ``normalizing_str`` token length passed as ``mask_length``).
                Prompts must be plain strings in this mode.

        Returns:
            List: The ``prediction`` entry of every sample recorded by the
            output handler (the lowest-scoring label per sample).
        """
        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler()
        ppl = []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        if self.fix_id_list:
            ice_idx_list = retriever.retrieve(self.fix_id_list)
        else:
            ice_idx_list = retriever.retrieve()

        # 3. Get labels of all the classes
        if self.labels is None:
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
        else:
            labels = self.labels

        # 4. Generate in-context examples for testing inputs
        ice = [
            retriever.generate_ice(idx_list, ice_template=ice_template)
            for idx_list in ice_idx_list
        ]
        output_handler.save_ice(self.model.parse_template(ice, mode='ppl'))

        # The separator must survive prompt generation whenever the prompt
        # is later split into context/answer for normalization.
        remain_sep = normalizing_str is not None

        def build_label_prompt(sample_idx, label):
            # Single place that renders a label prompt, so the initial build
            # and every truncation retry use identical arguments (including
            # ``remain_sep``). Reads ``ice[sample_idx]`` at call time, so it
            # picks up truncated in-context examples.
            return retriever.generate_label_prompt(
                sample_idx,
                ice[sample_idx],
                label,
                ice_template=ice_template,
                prompt_template=prompt_template,
                remain_sep=remain_sep)

        # Invariant across labels and samples: hoisted out of the loops.
        if normalizing_str is not None:
            if prompt_template is not None:
                sep_token = prompt_template.sep_token
            else:
                sep_token = ice_template.sep_token
            normalizing_str_len = self.model.get_token_len_from_template(
                normalizing_str, mode='ppl')

        # 5. Calculating PPL for prompts in each label's class
        for label in labels:
            index = 0
            prompt_list = []
            sub_ppl_list = []
            normalizing_prompt_list = []
            context_length_list = []

            # 5.1 Generate prompts of current label and truncate
            for idx in range(len(ice_idx_list)):
                prompt = build_label_prompt(idx, label)
                if self.max_seq_len is not None:
                    prompt_token_num = self.model.get_token_len_from_template(
                        prompt, mode='ppl')
                    while (len(ice_idx_list[idx]) > 0
                           and prompt_token_num > self.max_seq_len):
                        # Drop the last in-context example and re-render
                        # until the prompt fits. The truncation persists in
                        # ice_idx_list / ice for subsequent labels.
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice[idx] = retriever.generate_ice(
                            ice_idx_list[idx], ice_template=ice_template)
                        prompt = build_label_prompt(idx, label)
                        prompt_token_num = \
                            self.model.get_token_len_from_template(
                                prompt, mode='ppl')

                if normalizing_str is not None:
                    assert isinstance(prompt, str), \
                        'Prompt must be a string when normalizing_str is set.'
                    # NOTE(review): if sep_token is absent, find() returns -1
                    # and the split below is meaningless; remain_sep above is
                    # what guarantees the token is present.
                    sep_pos = prompt.find(sep_token)
                    context = prompt[:sep_pos]
                    answer = prompt[sep_pos:].replace(sep_token, '')
                    prompt = context + answer
                    normalizing_prompt_list.append(normalizing_str + answer)
                    context_length_list.append(
                        self.model.get_token_len_from_template(context,
                                                               mode='ppl'))
                prompt_list.append(prompt)

            # 5.2 Get PPL
            logger.info(f"Calculating PPL for prompts labeled '{label}'")
            for idx in trange(0,
                              len(prompt_list),
                              self.batch_size,
                              disable=not self.is_main_process):
                batch_end = idx + self.batch_size
                sub_prompt_list = prompt_list[idx:batch_end]

                with torch.no_grad():
                    if normalizing_str is not None:
                        res1 = self.model.get_ppl_from_template(
                            sub_prompt_list,
                            mask_length=context_length_list[idx:batch_end])
                        res2 = self.model.get_ppl_from_template(
                            normalizing_prompt_list[idx:batch_end],
                            mask_length=[normalizing_str_len] *
                            len(sub_prompt_list))
                        # Convert like the plain branch so both paths yield
                        # Python scalars rather than array/tensor elements.
                        sub_res = (res1 - res2).tolist()
                    else:
                        sub_res = self.model.get_ppl_from_template(
                            sub_prompt_list).tolist()

                parsed_prompts = self.model.parse_template(sub_prompt_list,
                                                           mode='ppl')
                for res, prompt in zip(sub_res, parsed_prompts):
                    sub_ppl_list.append(res)
                    # ``index`` is the global sample index; ``idx`` is only
                    # the batch start, so ice[index] is this sample's ICE.
                    ice_str = self.model.parse_template(ice[index],
                                                        mode='ppl')
                    output_handler.save_prompt_and_ppl(
                        label, prompt.replace(ice_str, ''), prompt, res,
                        index)
                    index += 1
            ppl.append(sub_ppl_list)

        # 6. Get lowest PPL class as predictions
        # zip(*ppl) regroups the per-label score lists into per-sample tuples;
        # ties resolve to the first label in ``labels``.
        sub_predictions = [
            labels[scores.index(min(scores))] for scores in zip(*ppl)
        ]
        output_handler.save_predictions(sub_predictions)

        # 7. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]
|