[Feature] Support inference ppl datasets (#1315)

* commit inference ppl datasets

* revised format

* revise

* revise

* revise

* revise

* revise

* revise
This commit is contained in:
Que Haoran 2024-07-22 17:59:30 +08:00 committed by GitHub
parent e9384823f2
commit a244453d9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 662 additions and 0 deletions

View File

@ -0,0 +1,26 @@
# Inference-PPL Datasets
- **Description**: Compute the loss only on the labeled positions, especially used for reasoning corpus.
- **Datasets**: cn-reasoning-val.jsonl (example datasets, inference-ppl can be generalized to more corpus).
# PPL Computation
$$ \text{ppl} = - \frac{1}{n} \sum_{i=0}^n \sum_{c=0}^{vocab\_size} y_{i,c} \log p_{i,c} \tag{1} $$
where Eq. (1) is the normal mean ppl computation formula, for inference-ppl, we only compute the average score based on pre-labeled position.
# Quick Start
```shell
cd opencompass
python run.py configs/eval_inference_ppl.py
```
# Some results
| Model | Result |
| ----------- | ----------- |
| Qwen1.5-7b | 0.59 |
| Qwen1.5-14b | 0.54 |
| Llama2-7b | 0.49 |
| Llama2-13b | 0.43 |

View File

@ -0,0 +1,38 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import InferencePPLOnlyInferencer
from opencompass.openicl.icl_evaluator import AverageInferencePPLEvaluator
from opencompass.datasets import InferencePPLDataset
# Build InferencePPLDataset
inference_ppl_datasets = []
llm_cmp_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{text}',
),
# No in-context example, using ZeroRetriever
retriever=dict(type=ZeroRetriever),
# compute inference-ppl
inferencer=dict(type=InferencePPLOnlyInferencer),
)
# Average the inference-ppl scores
llm_cmp_eval_cfg = dict(evaluator=dict(type=AverageInferencePPLEvaluator))
inference_ppl_datasets.append(
dict(
abbr=f'inference-ppl',
type=InferencePPLDataset,
path='./data/inference_ppl',
name='cn-reasoning-val',
samples=None, # Set small samples for testing
reader_cfg=dict(
input_columns=['text'],
output_column=None,
),
infer_cfg=llm_cmp_infer_cfg,
eval_cfg=llm_cmp_eval_cfg,
))

View File

@ -0,0 +1,62 @@
from mmengine.config import read_base
with read_base():
# Inference PPL datasets
from .datasets.inference_ppl.inference_ppl import inference_ppl_datasets
# Model configs
from .models.qwen.hf_qwen1_5_7b import models as qwen1_5_7b
from .models.qwen.hf_qwen1_5_14b import models as qwen1_5_14b
from .models.hf_llama.hf_llama2_7b import models as llama2_7b
from .models.hf_llama.hf_llama2_13b import models as llama2_13b
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
# -------------Inference Stage ----------------------------------------
datasets = [*inference_ppl_datasets]
workdir = 'outputs/inference_ppl'
models = [
*qwen1_5_7b,
*qwen1_5_14b,
*llama2_7b,
*llama2_13b,
]
# Set custom batch_size and num_gpus for faster loss calculation
# Smaller batch_size should give more precise results, at the cost of worse efficiency
model_cfg = dict(
batch_size=8,
run_cfg=dict(num_gpus=4, num_procs=1)
)
for mdl in models:
mdl.update(model_cfg)
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
task=dict(type=OpenICLInferTask),
max_num_workers=256, # Maximum concurrent evaluation task count
),
)
# -------------Evaluation Stage ----------------------------------------
eval = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
task=dict(type=OpenICLEvalTask),
max_num_workers=256,
)
)

View File

@ -53,6 +53,7 @@ from .humaneval import * # noqa: F401, F403
from .humanevalx import * # noqa: F401, F403
from .hungarian_math import * # noqa: F401, F403
from .IFEval.ifeval import IFEvalDataset, IFEvaluator # noqa: F401, F403
from .inference_ppl import InferencePPLDataset # noqa: F401, F403
from .infinitebench import * # noqa: F401, F403
from .iwslt2017 import * # noqa: F401, F403
from .jigsawmultilingual import * # noqa: F401, F403

View File

@ -0,0 +1,37 @@
import os.path as osp
from typing import List
from datasets import load_dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class InferencePPLDataset(BaseDataset):
@staticmethod
def load(path: str, name: List[str] = None, samples: int = None):
# Check if file exists in the given path
supported_extensions = ['jsonl']
for ext in supported_extensions:
filename = osp.join(
path, f'{name}.{ext}') # name refers to data subset name
if osp.exists(filename):
break
else:
raise FileNotFoundError(f'{filename} not found.')
samples = 'test' if samples is None else f'test[:{samples}]'
data_files = {'test': filename}
dataset = load_dataset('json', data_files=data_files, split=samples)
# Filter out empty samples
dataset = dataset.filter(lambda example: len(example['text']) > 0)
return dataset

View File

@ -85,6 +85,28 @@ class BaseModel:
' ppl-based evaluation yet, try gen-based '
'instead.')
@abstractmethod
def get_ppl_tokenwise(
self,
inputs: List[str],
mask_length: Optional[List[int]] = None) -> List[float]:
"""Get tokenwise perplexity scores given a list of inputs.
Args:
inputs (List[str]): A list of strings.
mask_length (Optional[List[int]]): A list of mask lengths. If
provided, the perplexity scores will be calculated with the
first mask_length[i] tokens masked out. It's okay to skip
its implementation if advanced features in PPLInfernecer is
not needed.
Returns:
List[float]: A list of perplexity scores.
"""
raise NotImplementedError(f'{self.__class__.__name__} does not support'
' ppl-based evaluation yet, try gen-based '
'instead.')
@abstractmethod
def encode(self, prompt: str) -> torch.Tensor:
"""Encode prompt to tokens. Not necessary for most cases.
@ -151,6 +173,20 @@ class BaseModel:
inputs = self.parse_template(templates, mode='ppl')
return self.get_ppl(inputs, mask_length)
def get_ppl_tokenwise_from_template(self,
templates: List[PromptType],
label: List[List[int]],
mask_length=None):
"""Get token-wise perplexity given a list of templates.
Args:
templates (List[PromptType]): A list of templates.
mask_length (List[int]): A list of mask lengths. If provided, the
perplexity will be calculated only on the unmasked tokens.
"""
inputs = self.parse_template(templates, mode='ppl')
return self.get_ppl_tokenwise(inputs, label, mask_length)
def generate_from_template(self, templates: List[PromptType],
max_out_len: int, **kwargs):
"""Generate completion from a list of templates.

View File

@ -226,6 +226,165 @@ class HuggingFacewithChatTemplate(BaseModel):
self.model.eval()
self.model.generation_config.do_sample = False
def get_ppl_tokenwise(self, inputs: List[str], label: List[List[int]], mask_length: Optional[List[int]] = None) -> List[float]:
"""Get inference-ppl per token given a list of inputs and label.
Args:
inputs (List[str]): A list of strings.
label (List[List[int]]): A list of list of label, each label is a tuple of (start, end, 1)
mask_length (Optional[List[int]]): A list of mask lengths. If
provided, the perplexity scores will be calculated with the
first mask_length[i] tokens masked out. It's okay to skip
its implementation if advanced features in PPLInfernecer is
not needed.
Returns:
List[float]: A list of perplexity scores.
"""
assert self.tokenizer.pad_token
import torch
import torch.nn.functional as F
pad_token_id = self.tokenizer.pad_token_id
messages = _convert_base_messages(inputs)
tokenize_kwargs = dict(
return_tensors='pt',
padding=True,
truncation=True,
add_special_tokens=True,
max_length=self.max_seq_len,
)
self.tokenizer.padding_side = 'right'
self.tokenizer.truncation_side = 'right'
tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
outputs = self.model(**tokens)[0]
batch_size, seq_len, vocab_size = outputs.shape
shift_logits = outputs[:, :-1, :].contiguous().float()
shift_labels = tokens['input_ids'][:, 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, vocab_size),
shift_labels.view(-1),
ignore_index=pad_token_id,
reduction='none').view(batch_size, seq_len - 1)
lens = (tokens['input_ids'] != pad_token_id).sum(-1).cpu().numpy()
if mask_length is not None:
import numpy as np
mask = torch.zeros_like(shift_labels) # [batch,seqlen]
for i in range(len(mask)):
for j in range(mask_length[i] - 1, len(mask[i])):
mask[i][j] = 1
loss = loss * mask
lens -= np.array(mask_length)
loss = loss.cpu().numpy()
decode_messages = [[self.tokenizer.decode([input_id]) for input_id in token] for token in tokens['input_ids']]
char_messages = [[ch for ch in message] for message in messages]
# shifted to align label and loss
for i in range(len(decode_messages)):
decode_messages[i] = decode_messages[i][1:]
aggregated_label_list = [[] for _ in range(len(decode_messages))]
tag_list = [[] for _ in range(len(decode_messages))]
for tmp_index, label_list in enumerate(label):
for single_label in label_list:
left = single_label[0]
right = single_label[1]
for i in range(left, right):
aggregated_label_list[tmp_index].append(i)
def align_sequences(seq1, seq2, sep_len):
"""
seq1: decoded sequence from token, one token may contain multiple characters
seq2: original separate character sequence
"""
i, j = 0, 0
matched_pairs = []
while i < len(seq1) and j < len(seq2):
word = seq1[i]
if len(word) == 0:
matched_pairs.append((word, []))
i += 1
continue
if '\ufffd' in word:
for _ in range(sep_len):
matched_pairs.append((word, [j]))
i += 1
j += 1
continue
char_sequence = ''
while j < len(seq2) and (char_sequence != word):
char_sequence += seq2[j]
if char_sequence == word:
matched_pairs.append((word, [k for k in range(j - len(word) + 1, j+1)]))
j += 1
break
elif len(char_sequence) > len(word):
if word == char_sequence[-len(word):]:
matched_pairs.append((word, [k for k in range(j - len(word) + 1, j+1)]))
j += 1
break
else:
j += 1
else:
j += 1
i += 1
return matched_pairs
if 'qwen' in self.path or 'Qwen' in self.path:
sep_len = 2
elif 'Llama-3' in self.path:
sep_len = 2
elif 'Yi' in self.path:
sep_len = 3
elif 'Llama-2' in self.path:
sep_len = 3
elif 'deepseek' in self.path:
sep_len = 2
else:
sep_len = 3
matched_pairs_list = [align_sequences(decode_messages[i], char_messages[i], sep_len) for i in range(len(decode_messages))]
for match_index, matched_pairs in enumerate(matched_pairs_list):
for i, (word, indices) in enumerate(matched_pairs):
for j in indices:
if j in aggregated_label_list[match_index]:
tag_list[match_index].append(i)
break
inference_loss_list = []
token_len_list = []
for i in range(len(loss)):
inference_loss = 0
token_len = 0
for j in range(len(loss[i])):
if j in tag_list[i]:
inference_loss += loss[i][j]
print(loss[i][j])
token_len += 1
inference_loss_list.append(inference_loss)
token_len_list.append(token_len)
return inference_loss_list, token_len_list
def _get_potential_stop_words(self, path: Optional[str]):
from transformers import GenerationConfig
potential_stop_words = []

View File

@ -6,6 +6,7 @@ from .icl_circular_evaluator import CircularEvaluator # noqa
from .icl_em_evaluator import EMEvaluator # noqa
from .icl_hf_evaluator import * # noqa
from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator # noqa
from .icl_misc_evaluator import AverageInferencePPLEvaluator # noqa
from .icl_misc_evaluator import AverageMinKEvaluator # noqa
from .icl_misc_evaluator import AveragePPLEvaluator # noqa
from .icl_plugin_evaluator import TEvalEvaluator # noqa

View File

@ -17,3 +17,11 @@ class AverageMinKEvaluator(BaseEvaluator):
def score(self, mink):
average_mink = sum(mink) / len(mink)
return {'average_mink': average_mink}
@ICL_EVALUATORS.register_module()
class AverageInferencePPLEvaluator(BaseEvaluator):
def score(self, ppl, token_len):
average_ppl = sum(ppl) / sum(token_len)
return {'average_ppl': average_ppl}

View File

@ -4,6 +4,8 @@ from .icl_base_inferencer import BaseInferencer # noqa
from .icl_chat_inferencer import ChatInferencer # noqa
from .icl_clp_inferencer import CLPInferencer # noqa
from .icl_gen_inferencer import GenInferencer # noqa
from .icl_inference_ppl_only_inferencer import \
InferencePPLOnlyInferencer # noqa
from .icl_ll_inferencer import LLInferencer # noqa
from .icl_mink_percent_inferencer import MinKPercentInferencer # noqa
from .icl_ppl_inferencer import PPLInferencer # noqa

View File

@ -0,0 +1,239 @@
"""PPL Inferencer."""
import os
from typing import List, Optional
import mmengine
import torch
from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class InferencePPLOnlyInferencer(BaseInferencer):
"""InferencePPLOnlyInferencer class to calculate Inference-PPL only, no
choice is made. This Inferencer is usually used along with
AverageInferencePPLEvaluator.
Attributes:
model (:obj:`BaseModel`, optional): The module to inference.
max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
the LM.
batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
save_every (:obj:`int`, optional): Save intermediate results every
"""
def __init__(
self,
model: BaseModel,
max_seq_len: Optional[int] = None,
batch_size: Optional[int] = 1,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
**kwargs) -> None:
super().__init__(
model=model,
max_seq_len=max_seq_len,
batch_size=batch_size,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.save_every = save_every
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = InferencePPLOnlyInferencerOutputHandler()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list, label_list = self.get_generation_prompt_list_and_label(
ice_idx_list,
retriever,
max_seq_len=self.max_seq_len,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_list = [{
'prompt': prompt,
'label': label
} for prompt, label in zip(prompt_list, label_list)]
# 3.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
assert ds_reader.output_column is None, (
'InferencePPLOnlyInferencer supports `output_column=None` only.')
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if os.path.exists(tmp_json_filepath):
# TODO: move resume to output handler
try:
tmp_result_dict = mmengine.load(tmp_json_filepath)
except Exception:
pass
else:
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for datum in tqdm(dataloader, disable=not self.is_main_process):
entry = [datum_single['prompt'] for datum_single in datum]
label = [datum_single['label'] for datum_single in datum]
# 5-1. Inference with local model
with torch.no_grad():
(inference_loss_list,
token_len_list) = self.model.get_ppl_tokenwise_from_template(
entry, label)
parsed_entries = self.model.parse_template(entry, mode='gen')
# 5-3. Save current output
for prompt, inference_loss, token_len, in zip(
parsed_entries, inference_loss_list, token_len_list):
output_handler.save_results(prompt, inference_loss, token_len,
index)
index = index + 1
# 5-4. Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if os.path.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return [
sample['ppl'] for sample in output_handler.results_dict.values()
]
def get_generation_prompt_list_from_retriever_indices(
self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
max_seq_len: Optional[int] = None,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
for idx, ice_idx in enumerate(ice_idx_list):
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice,
ice_template=ice_template,
prompt_template=prompt_template)
if max_seq_len is not None:
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
ice_idx = ice_idx[:-1]
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
prompt_list.append(prompt)
return prompt_list
def get_generation_prompt_list_and_label(
self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
max_seq_len: Optional[int] = None,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
label_list = []
for idx, ice_idx in enumerate(ice_idx_list):
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt, label = retriever.generate_prompt_and_label_for_generate_task( # noqa
idx,
ice,
ice_template=ice_template,
prompt_template=prompt_template)
if max_seq_len is not None:
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
ice_idx = ice_idx[:-1]
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
prompt, label = retriever.generate_prompt_for_generate_task( # noqa
idx,
ice,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
prompt_list.append(prompt)
label_list.append(label)
return prompt_list, label_list
class InferencePPLOnlyInferencerOutputHandler:
origin_prompt_dict = {}
output_dict = {}
results_dict = {}
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, os.path.join(save_dir, filename))
def save_results(self, origin_prompt, ppl, token_len, idx):
self.results_dict[str(idx)] = {
'origin_prompt': origin_prompt,
'ppl': ppl,
'token_len': token_len,
}

View File

@ -207,6 +207,59 @@ class BaseRetriever:
raise NotImplementedError(
'Leaving prompt as empty is not supported')
def generate_prompt_and_label_for_generate_task(
self,
idx,
ice,
gen_field_replace_token='',
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
"""Generate the prompt and the label info for one test example in
generative evaluation with `prompt_template`. If `prompt_template` is
not provided, the `ice_template` will be used to generate the prompt.
The token represented by `gen_field_replace_token` will not be replaced
by the generated text, or it will leaks the answer.
Args:
idx (`int`): The index of the test example.
ice (`str`): The in-context example for the test example.
gen_field_replace_token (`str`): The token of the answer in the
prompt. Defaults to ''.
ice_template (`Optional[PromptTemplate]`): The template for
in-context example. Defaults to None.
prompt_template (`Optional[PromptTemplate]`): The template for
prompt. Defaults to None.
"""
if prompt_template is not None and ice_template is not None:
if prompt_template.ice_token is not None:
return prompt_template.generate_item(
self.test_ds[idx],
output_field=self.dataset_reader.output_column,
output_field_replace_token=gen_field_replace_token,
ice_field_replace_token=ice), self.test_ds[idx]['label']
else:
raise NotImplementedError(
'ice_token of prompt_template is not provided')
elif ice_template is not None and prompt_template is None:
if ice_template.ice_token is not None:
return ice_template.generate_item(
self.test_ds[idx],
output_field=self.dataset_reader.output_column,
output_field_replace_token=gen_field_replace_token,
ice_field_replace_token=ice), self.test_ds[idx]['label']
else:
raise NotImplementedError(
'ice_token of ice_template is not provided')
elif ice_template is None and prompt_template is not None:
return prompt_template.generate_item(
self.test_ds[idx],
output_field=self.dataset_reader.output_column,
output_field_replace_token=gen_field_replace_token,
ice_field_replace_token=ice), self.test_ds[idx]['label']
else:
raise NotImplementedError(
'Leaving prompt as empty is not supported')
def generate_prompt_for_adv_generate_task(
self,
idx,