Mirror of https://github.com/open-compass/opencompass.git
[Feature] Support inference ppl datasets (#1315)
* commit inference ppl datasets
* revised format
* revise
* revise
* revise
* revise
* revise
* revise

This commit is contained in: parent e9384823f2, commit a244453d9e
configs/datasets/inference_ppl/README.md (new file, 26 lines)
@@ -0,0 +1,26 @@
# Inference-PPL Datasets

- **Description**: Compute the loss only at pre-labeled positions; this is mainly intended for reasoning corpora.
- **Datasets**: cn-reasoning-val.jsonl (an example dataset; inference-ppl generalizes to other corpora).

# PPL Computation

$$ \text{ppl} = - \frac{1}{n} \sum_{i=1}^{n} \sum_{c=1}^{\text{vocab\_size}} y_{i,c} \log p_{i,c} \tag{1} $$

Eq. (1) is the standard mean-PPL formula; for inference-ppl, the average is taken only over the pre-labeled positions.
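A rough illustration of this masked averaging (not part of the repository; shapes and names are assumptions made for this sketch):

```python
# Minimal sketch: average per-token cross-entropy only over labeled positions.
import torch
import torch.nn.functional as F

def masked_inference_ppl(logits: torch.Tensor,      # [seq_len, vocab_size] next-token logits
                         targets: torch.Tensor,     # [seq_len] gold next-token ids
                         label_mask: torch.Tensor,  # [seq_len] 1.0 at labeled positions, else 0.0
                         ) -> float:
    token_loss = F.cross_entropy(logits, targets, reduction='none')  # per-token loss
    # Eq. (1) restricted to the labeled positions.
    return ((token_loss * label_mask).sum() / label_mask.sum().clamp(min=1)).item()

# Toy usage: 5 positions, only the last three are labeled.
logits = torch.randn(5, 100)
targets = torch.randint(0, 100, (5,))
mask = torch.tensor([0., 0., 1., 1., 1.])
print(masked_inference_ppl(logits, targets, mask))
```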

# Quick Start

```shell
cd opencompass
python run.py configs/eval_inference_ppl.py
```

# Example Results

| Model | Average inference-ppl |
| ----------- | ----------- |
| Qwen1.5-7b | 0.59 |
| Qwen1.5-14b | 0.54 |
| Llama2-7b | 0.49 |
| Llama2-13b | 0.43 |
configs/datasets/inference_ppl/inference_ppl.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import InferencePPLOnlyInferencer
from opencompass.openicl.icl_evaluator import AverageInferencePPLEvaluator

from opencompass.datasets import InferencePPLDataset

# Build InferencePPLDataset
inference_ppl_datasets = []

llm_cmp_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template='{text}',
    ),
    # No in-context examples, so use ZeroRetriever
    retriever=dict(type=ZeroRetriever),
    # Compute inference-ppl only
    inferencer=dict(type=InferencePPLOnlyInferencer),
)

# Average the inference-ppl scores over all labeled tokens
llm_cmp_eval_cfg = dict(evaluator=dict(type=AverageInferencePPLEvaluator))

inference_ppl_datasets.append(
    dict(
        abbr='inference-ppl',
        type=InferencePPLDataset,
        path='./data/inference_ppl',
        name='cn-reasoning-val',
        samples=None,  # Set a small number of samples for quick testing
        reader_cfg=dict(
            input_columns=['text'],
            output_column=None,
        ),
        infer_cfg=llm_cmp_infer_cfg,
        eval_cfg=llm_cmp_eval_cfg,
    ))
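The record schema of `cn-reasoning-val.jsonl` is not shown in this diff; based on the `text` reader column and the `(start, end, 1)` character spans consumed downstream, a record presumably looks like this (hypothetical example):

```python
# Hypothetical record for ./data/inference_ppl/cn-reasoning-val.jsonl
# (field names inferred from reader_cfg and the label handling in this PR).
text = 'Question: 1+1=? Answer: 2'
answer_start = text.index('2', text.index('Answer'))  # character index of the answer
record = {
    'text': text,                                    # fed through the '{text}' template
    'label': [[answer_start, answer_start + 1, 1]],  # character spans to score
}
```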
configs/eval_inference_ppl.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from mmengine.config import read_base

with read_base():
    # Inference-PPL datasets
    from .datasets.inference_ppl.inference_ppl import inference_ppl_datasets

    # Model configs
    from .models.qwen.hf_qwen1_5_7b import models as qwen1_5_7b
    from .models.qwen.hf_qwen1_5_14b import models as qwen1_5_14b
    from .models.hf_llama.hf_llama2_7b import models as llama2_7b
    from .models.hf_llama.hf_llama2_13b import models as llama2_13b


from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask


# ------------- Inference Stage -------------

datasets = [*inference_ppl_datasets]
workdir = 'outputs/inference_ppl'

models = [
    *qwen1_5_7b,
    *qwen1_5_14b,
    *llama2_7b,
    *llama2_13b,
]

# Set a custom batch_size and num_gpus for faster loss calculation.
# A smaller batch_size should give more precise results, at the cost of efficiency.
model_cfg = dict(
    batch_size=8,
    run_cfg=dict(num_gpus=4, num_procs=1),
)

for mdl in models:
    mdl.update(model_cfg)


infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLInferTask),
        max_num_workers=256,  # Maximum number of concurrent inference tasks
    ),
)


# ------------- Evaluation Stage -------------

eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLEvalTask),
        max_num_workers=256,
    ),
)
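For a quicker smoke test, one possible tweak (an editorial suggestion, not part of this commit) is to evaluate a single model and cap the sample count, since `samples` is forwarded to `InferencePPLDataset.load`:

```python
# Hypothetical quick-test tweak at the bottom of configs/eval_inference_ppl.py.
models = [*qwen1_5_7b]      # keep a single model
for d in datasets:
    d['samples'] = 16       # InferencePPLDataset.load then slices test[:16]
```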
@@ -53,6 +53,7 @@ from .humaneval import *  # noqa: F401, F403
from .humanevalx import *  # noqa: F401, F403
from .hungarian_math import *  # noqa: F401, F403
from .IFEval.ifeval import IFEvalDataset, IFEvaluator  # noqa: F401, F403
from .inference_ppl import InferencePPLDataset  # noqa: F401, F403
from .infinitebench import *  # noqa: F401, F403
from .iwslt2017 import *  # noqa: F401, F403
from .jigsawmultilingual import *  # noqa: F401, F403
opencompass/datasets/inference_ppl.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import os.path as osp
from typing import Optional

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class InferencePPLDataset(BaseDataset):

    @staticmethod
    def load(path: str,
             name: Optional[str] = None,
             samples: Optional[int] = None):

        # Check whether the file exists in the given path
        supported_extensions = ['jsonl']
        for ext in supported_extensions:
            filename = osp.join(
                path, f'{name}.{ext}')  # name refers to the data subset name

            if osp.exists(filename):
                break
        else:
            raise FileNotFoundError(f'{filename} not found.')

        # Load the full test split, or only the first `samples` records
        samples = 'test' if samples is None else f'test[:{samples}]'
        data_files = {'test': filename}
        dataset = load_dataset('json', data_files=data_files, split=samples)

        # Filter out empty samples
        dataset = dataset.filter(lambda example: len(example['text']) > 0)

        return dataset
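A quick way to smoke-test the loader (assumes the example jsonl from the README has been placed under `./data/inference_ppl`):

```python
# Hypothetical smoke test for the dataset loader.
from opencompass.datasets import InferencePPLDataset

ds = InferencePPLDataset.load(path='./data/inference_ppl',
                              name='cn-reasoning-val',
                              samples=4)  # only the first 4 records
print(len(ds), ds[0]['text'][:50])
```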
@@ -85,6 +85,28 @@ class BaseModel:
                                  ' ppl-based evaluation yet, try gen-based '
                                  'instead.')

    @abstractmethod
    def get_ppl_tokenwise(
            self,
            inputs: List[str],
            label: List[List[int]],
            mask_length: Optional[List[int]] = None):
        """Get token-wise perplexity scores given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            label (List[List[int]]): Character-level label spans for each
                input, each span given as (start, end, 1).
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            Tuple[List[float], List[int]]: The summed loss over labeled
            tokens and the number of labeled tokens for each input.
        """
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' ppl-based evaluation yet, try gen-based '
                                  'instead.')

    @abstractmethod
    def encode(self, prompt: str) -> torch.Tensor:
        """Encode prompt to tokens. Not necessary for most cases.
@@ -151,6 +173,20 @@ class BaseModel:
        inputs = self.parse_template(templates, mode='ppl')
        return self.get_ppl(inputs, mask_length)

    def get_ppl_tokenwise_from_template(self,
                                        templates: List[PromptType],
                                        label: List[List[int]],
                                        mask_length=None):
        """Get token-wise perplexity given a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            label (List[List[int]]): Character-level label spans for each
                template, each span given as (start, end, 1).
            mask_length (List[int]): A list of mask lengths. If provided, the
                perplexity will be calculated only on the unmasked tokens.
        """
        inputs = self.parse_template(templates, mode='ppl')
        return self.get_ppl_tokenwise(inputs, label, mask_length)

    def generate_from_template(self, templates: List[PromptType],
                               max_out_len: int, **kwargs):
        """Generate completion from a list of templates.
@@ -226,6 +226,165 @@ class HuggingFacewithChatTemplate(BaseModel):
        self.model.eval()
        self.model.generation_config.do_sample = False

    def get_ppl_tokenwise(
            self,
            inputs: List[str],
            label: List[List[int]],
            mask_length: Optional[List[int]] = None):
        """Get inference-ppl per token given a list of inputs and labels.

        Args:
            inputs (List[str]): A list of strings.
            label (List[List[int]]): A list of label-span lists; each span is
                a tuple of (start, end, 1) character indices into the input.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer are
                not needed.

        Returns:
            Tuple[List[float], List[int]]: The summed loss over labeled
            tokens and the number of labeled tokens for each input.
        """
        assert self.tokenizer.pad_token
        import torch
        import torch.nn.functional as F
        pad_token_id = self.tokenizer.pad_token_id
        messages = _convert_base_messages(inputs)

        tokenize_kwargs = dict(
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len,
        )

        self.tokenizer.padding_side = 'right'
        self.tokenizer.truncation_side = 'right'

        tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
        tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
        outputs = self.model(**tokens)[0]

        # Per-token cross-entropy between shifted logits and labels.
        batch_size, seq_len, vocab_size = outputs.shape
        shift_logits = outputs[:, :-1, :].contiguous().float()
        shift_labels = tokens['input_ids'][:, 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
            ignore_index=pad_token_id,
            reduction='none').view(batch_size, seq_len - 1)
        lens = (tokens['input_ids'] != pad_token_id).sum(-1).cpu().numpy()

        if mask_length is not None:
            import numpy as np
            mask = torch.zeros_like(shift_labels)  # [batch, seq_len - 1]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask
            lens -= np.array(mask_length)

        loss = loss.cpu().numpy()

        # Decode each token id back to its text piece and split the original
        # message into characters, so that character-level label spans can be
        # mapped onto token positions.
        decode_messages = [[
            self.tokenizer.decode([input_id]) for input_id in token
        ] for token in tokens['input_ids']]
        char_messages = [[ch for ch in message] for message in messages]

        # Drop the first decoded token so the pieces align with the shifted
        # loss positions.
        for i in range(len(decode_messages)):
            decode_messages[i] = decode_messages[i][1:]

        # Expand each (start, end, 1) span into the set of labeled character
        # indices per sample.
        aggregated_label_list = [[] for _ in range(len(decode_messages))]
        tag_list = [[] for _ in range(len(decode_messages))]

        for tmp_index, label_list in enumerate(label):
            for single_label in label_list:
                left = single_label[0]
                right = single_label[1]
                for i in range(left, right):
                    aggregated_label_list[tmp_index].append(i)

        def align_sequences(seq1, seq2, sep_len):
            """Align decoded token pieces with the original characters.

            seq1: decoded token pieces; one piece may contain several
                characters.
            seq2: the original message split into single characters.
            sep_len: number of byte-fallback ('\ufffd') pieces that together
                encode one character for this tokenizer.
            """
            i, j = 0, 0
            matched_pairs = []
            while i < len(seq1) and j < len(seq2):
                word = seq1[i]
                if len(word) == 0:
                    matched_pairs.append((word, []))
                    i += 1
                    continue

                if '\ufffd' in word:
                    for _ in range(sep_len):
                        matched_pairs.append((word, [j]))
                        i += 1
                    j += 1
                    continue

                char_sequence = ''
                while j < len(seq2) and (char_sequence != word):
                    char_sequence += seq2[j]
                    if char_sequence == word:
                        matched_pairs.append(
                            (word, list(range(j - len(word) + 1, j + 1))))
                        j += 1
                        break
                    elif len(char_sequence) > len(word):
                        if word == char_sequence[-len(word):]:
                            matched_pairs.append(
                                (word, list(range(j - len(word) + 1, j + 1))))
                            j += 1
                            break
                        else:
                            j += 1
                    else:
                        j += 1
                i += 1

            return matched_pairs

        # Number of byte-fallback tokens per character differs by tokenizer.
        if 'qwen' in self.path or 'Qwen' in self.path:
            sep_len = 2
        elif 'Llama-3' in self.path:
            sep_len = 2
        elif 'Yi' in self.path:
            sep_len = 3
        elif 'Llama-2' in self.path:
            sep_len = 3
        elif 'deepseek' in self.path:
            sep_len = 2
        else:
            sep_len = 3

        # Mark the token positions whose characters fall inside a label span.
        matched_pairs_list = [
            align_sequences(decode_messages[i], char_messages[i], sep_len)
            for i in range(len(decode_messages))
        ]
        for match_index, matched_pairs in enumerate(matched_pairs_list):
            for i, (word, indices) in enumerate(matched_pairs):
                for j in indices:
                    if j in aggregated_label_list[match_index]:
                        tag_list[match_index].append(i)
                        break

        # Sum the loss over labeled token positions and count them.
        inference_loss_list = []
        token_len_list = []
        for i in range(len(loss)):
            inference_loss = 0
            token_len = 0
            for j in range(len(loss[i])):
                if j in tag_list[i]:
                    inference_loss += loss[i][j]
                    token_len += 1
            inference_loss_list.append(inference_loss)
            token_len_list.append(token_len)

        return inference_loss_list, token_len_list

    def _get_potential_stop_words(self, path: Optional[str]):
        from transformers import GenerationConfig
        potential_stop_words = []
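To make the alignment step easier to follow, here is a standalone toy version of the same idea (`simple_align` is a hypothetical helper, not the nested `align_sequences` above; it ignores the byte-fallback/`\ufffd` handling):

```python
# Toy illustration of mapping decoded token pieces back to character indices.
# Assumes each decoded piece appears verbatim and in order in the text.
def simple_align(token_pieces, text):
    pairs, j = [], 0
    for piece in token_pieces:
        start = text.index(piece, j)  # next occurrence of this piece
        pairs.append((piece, list(range(start, start + len(piece)))))
        j = start + len(piece)
    return pairs

print(simple_align(['An', 'swer', ':', ' 4'], 'Answer: 4'))
# [('An', [0, 1]), ('swer', [2, 3, 4, 5]), (':', [6]), (' 4', [7, 8])]
```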
@@ -6,6 +6,7 @@ from .icl_circular_evaluator import CircularEvaluator  # noqa
from .icl_em_evaluator import EMEvaluator  # noqa
from .icl_hf_evaluator import *  # noqa
from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator  # noqa
from .icl_misc_evaluator import AverageInferencePPLEvaluator  # noqa
from .icl_misc_evaluator import AverageMinKEvaluator  # noqa
from .icl_misc_evaluator import AveragePPLEvaluator  # noqa
from .icl_plugin_evaluator import TEvalEvaluator  # noqa
@@ -17,3 +17,11 @@ class AverageMinKEvaluator(BaseEvaluator):
    def score(self, mink):
        average_mink = sum(mink) / len(mink)
        return {'average_mink': average_mink}


@ICL_EVALUATORS.register_module()
class AverageInferencePPLEvaluator(BaseEvaluator):

    def score(self, ppl, token_len):
        # Token-weighted average: total summed loss over the total number of
        # labeled tokens.
        average_ppl = sum(ppl) / sum(token_len)
        return {'average_ppl': average_ppl}
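A small sanity check of the evaluator's weighting (made-up values):

```python
from opencompass.openicl.icl_evaluator import AverageInferencePPLEvaluator

# Two samples with summed losses 6.0 and 4.0 over 3 and 2 labeled tokens.
evaluator = AverageInferencePPLEvaluator()
print(evaluator.score(ppl=[6.0, 4.0], token_len=[3, 2]))  # {'average_ppl': 2.0}
```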
@@ -4,6 +4,8 @@ from .icl_base_inferencer import BaseInferencer  # noqa
from .icl_chat_inferencer import ChatInferencer  # noqa
from .icl_clp_inferencer import CLPInferencer  # noqa
from .icl_gen_inferencer import GenInferencer  # noqa
from .icl_inference_ppl_only_inferencer import \
    InferencePPLOnlyInferencer  # noqa
from .icl_ll_inferencer import LLInferencer  # noqa
from .icl_mink_percent_inferencer import MinKPercentInferencer  # noqa
from .icl_ppl_inferencer import PPLInferencer  # noqa
@@ -0,0 +1,239 @@
"""Inference-PPL only Inferencer."""

import os
from typing import List, Optional

import mmengine
import torch
from tqdm import tqdm

from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class InferencePPLOnlyInferencer(BaseInferencer):
    """InferencePPLOnlyInferencer class to calculate inference-PPL only; no
    choice is made. This Inferencer is usually used along with
    AverageInferencePPLEvaluator.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
            the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        save_every (:obj:`int`, optional): Save intermediate results every
            `save_every` iterations. Defaults to 1.
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.save_every = save_every

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        # 1. Preparation for output logs
        output_handler = InferencePPLOnlyInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts and label spans for the testing input
        prompt_list, label_list = self.get_generation_prompt_list_and_label(
            ice_idx_list,
            retriever,
            max_seq_len=self.max_seq_len,
            ice_template=ice_template,
            prompt_template=prompt_template)

        prompt_list = [{
            'prompt': prompt,
            'label': label
        } for prompt, label in zip(prompt_list, label_list)]

        # 3.1 Check the dataset reader: gold answers are not used here
        ds_reader = retriever.dataset_reader

        assert ds_reader.output_column is None, (
            'InferencePPLOnlyInferencer supports `output_column=None` only.')

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if os.path.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)

        # 5. Inference for prompts in each batch
        logger.info('Starting inference process...')
        for datum in tqdm(dataloader, disable=not self.is_main_process):
            entry = [datum_single['prompt'] for datum_single in datum]
            label = [datum_single['label'] for datum_single in datum]

            # 5-1. Inference with the local model
            with torch.no_grad():
                (inference_loss_list,
                 token_len_list) = self.model.get_ppl_tokenwise_from_template(
                     entry, label)

            parsed_entries = self.model.parse_template(entry, mode='gen')
            # 5-2. Save current output
            for prompt, inference_loss, token_len in zip(
                    parsed_entries, inference_loss_list, token_len_list):
                output_handler.save_results(prompt, inference_loss, token_len,
                                            index)
                index += 1

            # 5-3. Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if os.path.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return [
            sample['ppl'] for sample in output_handler.results_dict.values()
        ]

    def get_generation_prompt_list_from_retriever_indices(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            max_seq_len: Optional[int] = None,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)

            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                ice_template=ice_template,
                prompt_template=prompt_template)

            if max_seq_len is not None:
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='gen')
                # Drop in-context examples until the prompt fits
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = self.model.get_token_len_from_template(
                        prompt, mode='gen')
            prompt_list.append(prompt)
        return prompt_list

    def get_generation_prompt_list_and_label(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            max_seq_len: Optional[int] = None,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        label_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)

            prompt, label = retriever.generate_prompt_and_label_for_generate_task(  # noqa
                idx,
                ice,
                ice_template=ice_template,
                prompt_template=prompt_template)

            if max_seq_len is not None:
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='gen')
                # Drop in-context examples until the prompt fits
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt, label = retriever.generate_prompt_and_label_for_generate_task(  # noqa
                        idx,
                        ice,
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = self.model.get_token_len_from_template(
                        prompt, mode='gen')
            prompt_list.append(prompt)
            label_list.append(label)
        return prompt_list, label_list


class InferencePPLOnlyInferencerOutputHandler:
    origin_prompt_dict = {}
    output_dict = {}
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, os.path.join(save_dir, filename))

    def save_results(self, origin_prompt, ppl, token_len, idx):
        self.results_dict[str(idx)] = {
            'origin_prompt': origin_prompt,
            'ppl': ppl,
            'token_len': token_len,
        }
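For reference, each entry written by the output handler to the predictions JSON has roughly this shape (illustrative values only):

```python
# Illustrative structure of one saved record (values are made up).
{
    '0': {
        'origin_prompt': '...',  # parsed prompt text
        'ppl': 12.3,             # summed loss over the labeled tokens of this sample
        'token_len': 7,          # number of labeled tokens that were scored
    },
}
```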
@@ -207,6 +207,59 @@ class BaseRetriever:
        raise NotImplementedError(
            'Leaving prompt as empty is not supported')

    def generate_prompt_and_label_for_generate_task(
            self,
            idx,
            ice,
            gen_field_replace_token='',
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        """Generate the prompt and the label info for one test example in
        generative evaluation with `prompt_template`. If `prompt_template` is
        not provided, the `ice_template` will be used to generate the prompt.
        The token represented by `gen_field_replace_token` will not be
        replaced by the generated text, or it would leak the answer.

        Args:
            idx (`int`): The index of the test example.
            ice (`str`): The in-context example for the test example.
            gen_field_replace_token (`str`): The token of the answer in the
                prompt. Defaults to ''.
            ice_template (`Optional[PromptTemplate]`): The template for
                in-context example. Defaults to None.
            prompt_template (`Optional[PromptTemplate]`): The template for
                prompt. Defaults to None.
        """
        if prompt_template is not None and ice_template is not None:
            if prompt_template.ice_token is not None:
                return prompt_template.generate_item(
                    self.test_ds[idx],
                    output_field=self.dataset_reader.output_column,
                    output_field_replace_token=gen_field_replace_token,
                    ice_field_replace_token=ice), self.test_ds[idx]['label']
            else:
                raise NotImplementedError(
                    'ice_token of prompt_template is not provided')
        elif ice_template is not None and prompt_template is None:
            if ice_template.ice_token is not None:
                return ice_template.generate_item(
                    self.test_ds[idx],
                    output_field=self.dataset_reader.output_column,
                    output_field_replace_token=gen_field_replace_token,
                    ice_field_replace_token=ice), self.test_ds[idx]['label']
            else:
                raise NotImplementedError(
                    'ice_token of ice_template is not provided')
        elif ice_template is None and prompt_template is not None:
            return prompt_template.generate_item(
                self.test_ds[idx],
                output_field=self.dataset_reader.output_column,
                output_field_replace_token=gen_field_replace_token,
                ice_field_replace_token=ice), self.test_ds[idx]['label']
        else:
            raise NotImplementedError(
                'Leaving prompt as empty is not supported')

    def generate_prompt_for_adv_generate_task(
            self,
            idx,