OpenCompass/opencompass/openicl/icl_inferencer/icl_sw_ce_loss_inferencer.py
"""Sliding Window Cross Entropy Loss Inferencer."""
import math
import os
from typing import List, Optional, Tuple, Union
import mmengine
import numpy as np
import torch
from datasets import Dataset as HFDataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict
logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class SWCELossInferencer(BaseInferencer):
    """SWCELossInferencer class to calculate cross entropy loss per batch
    based on a sliding context window approach. This inferencer is usually
    used along with BPCEvaluator to calculate a model's Bits per Character
    metric on a given dataset.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed
            by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        save_every (:obj:`int`, optional): Save intermediate results every
            `save_every` iterations. Defaults to 1.
        block_size (:obj:`int`, optional): Block size (window size) of
            the sliding window on tokens.
        stride (:obj:`int`, optional): Stride (step size) of the
            sliding window on tokens.
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            block_size: Optional[int] = 1900,
            stride: Optional[int] = 512,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )
        self.block_size = block_size
        self.stride = stride
        self.save_every = save_every
        self.character_num = 0

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        # 1. Preparation for output logs
        output_handler = SWCELossInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        items_dataset = self.get_encoding_from_retriever_indices(
            ice_idx_list,
            retriever,
            max_seq_len=self.max_seq_len,
            prompt_template=prompt_template)

        # 3-1. Fetch and zip prompt & gold answer if output column exists
        ds_reader = retriever.dataset_reader
        assert ds_reader.output_column is None, (
            'SWCELossInferencer supports `output_column=None` only.')

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if os.path.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                # rewrite tmp_dataset on every run
                index = len(tmp_result_dict)

        # 4. Initialize torch dataset from items hf dataset
        logger.info('Starting dataset building process...')
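        # Each sliding window carries block_size + 1 tokens so that the
        # first block_size tokens can act as inputs and the same tokens
        # shifted by one position can act as next-token targets.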
        eval_dataset = SlidingWindowEvalDataset(
            items_dataset,
            block_size=self.block_size + 1,
            stride=self.stride,
        )

        # 4-1. Construct Dataloader
        dataloader = DataLoader(eval_dataset, self.batch_size, shuffle=False)

        # 5. Calculate total loss in each batch
        logger.info('Starting inference process...')
        device = self.model.model.device
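        # The loop below skips windows that were already scored in a
        # previous interrupted run; `index` was restored above from the
        # temporary results file.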
        for ind, datum in enumerate(
                tqdm(dataloader, disable=not self.is_main_process)):
            if ind < index:
                continue

            encodings = datum['input_ids']
            attention_mask = datum['attention_mask']

            # 5-1. Loss calculation by local model
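            # Within each window, the first block_size tokens are fed to the
            # model and the same tokens shifted one position to the right
            # serve as targets for next-token prediction.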
            with torch.no_grad():
                # Slice along the token dimension (dim 1); this handles any
                # batch size, including the default batch_size=1.
                input_ids = encodings[:, 0:self.block_size].contiguous().to(
                    device)
                targets = encodings[:, 1:self.block_size +
                                    1].contiguous().long().to(device)
                attention_mask = attention_mask[:, 1:self.block_size +
                                                1].contiguous().to(device)

                logits = self.model.model(input_ids).logits
                loss = self._get_cross_entropy(logits,
                                               targets,
                                               attention_mask=attention_mask)

            loss = loss.cpu().item()
            logger.info(f'loss: {loss:.8f}')

            # 5-2. Save intermediate results
            output_handler.save_results(loss,
                                        datum['total_chr_num'][0].item(),
                                        index)
            index = index + 1

            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if os.path.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return [sample for sample in output_handler.results_dict.values()]

    def get_encoding_from_retriever_indices(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            max_seq_len: Optional[int] = None,
            prompt_template: Optional[PromptTemplate] = None,
            dtype: str = 'auto') -> HFDataset:
        vocab_size = self.model.tokenizer.vocab_size
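
        # Token ids are stored in the smallest integer type that can hold
        # the tokenizer vocabulary, which roughly halves the memory needed
        # for the concatenated corpus when the vocab fits in 16 bits.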
        if dtype == 'auto':
            if vocab_size is None:
                raise ValueError(
                    "vocab_size cannot be None when dtype='auto'")
            if vocab_size is not None and vocab_size < 65500:
                _dtype = np.uint16
            else:
                _dtype = np.int32
        else:
            _dtype = dtype

        item_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            cur_item_dict = {}

            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice='',
                ice_template=None,
                prompt_template=prompt_template)
            cur_item_dict['prompt'] = prompt

            # Get encodings from model tokenizer
            # As long as block_size > max_seq_len, we can safely ignore the
            # warning about token sequence length
            cur_item_dict['encoding'] = np.array(
                self.model.tokenizer.encode(prompt), dtype=_dtype)
            item_list.append(cur_item_dict)

        items = HFDataset.from_list(item_list)
        return items

    def _get_cross_entropy(self,
                           logits: torch.Tensor,
                           targets: torch.Tensor,
                           attention_mask: torch.Tensor = None):
        """Calculate cross entropy based on given logits, targets and
        attention_mask for BPC loss calculation.

        Args:
            logits (torch.Tensor): Model output logits.
            targets (torch.Tensor): Target token ids.
            attention_mask (torch.Tensor, optional): Attention mask.
                Defaults to None.

        Returns:
            torch.Tensor: Total cross entropy on the given batch of logits
                and targets, reduced by summation.
        """
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)
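
        # Tokens that overlap with the previous window arrive with
        # attention_mask == False; they are mapped to the ignore_index so
        # the summed loss counts each token of the stream exactly once.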
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(-1)
            targets = targets.masked_fill(~attention_mask, -1)

        return torch.nn.functional.cross_entropy(logits,
                                                 targets,
                                                 ignore_index=-1,
                                                 reduction='sum')


class SlidingWindowEvalDataset(Dataset):

    def __init__(self,
                 data: HFDataset,
                 block_size: int = 1900,
                 stride: int = 512) -> None:
        """SlidingWindowEvalDataset.

        Args:
            data (HFDataset): HuggingFace dataset containing input samples.
            block_size (int, optional): Sliding context window size.
                Defaults to 1900.
            stride (int, optional): Sliding context window step size.
                Defaults to 512.
        """
        self.block_size = block_size
        self.data = data
        self.stride = stride

        self._prepare()
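
        # After _prepare(), self.data is a single 1-D array of token ids,
        # so len(self.data) below is the total token count of the corpus.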
        self.prev_end_loc = 0
        self.seq_len = len(self.data)
        self.begin_loc = 0

    def _prepare(self):
        """Prepare the evaluation dataset by counting the total number of
        characters in the original text and concatenating all encodings into
        a single array."""
        self._curr_idx = 0
        self._arr = []
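
        # The raw character count of all prompts is kept so that
        # BPCEvaluator can later convert the summed token-level loss into
        # bits per character.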
        self._total_chr_num = 0
        for i in range(len(self.data)):
            self._total_chr_num += len(self.data[i]['prompt'])

        logger.info(f'data Dataset before concat: {self.data}')
        self.data = np.concatenate([a['encoding'] for a in self.data], axis=0)
        logger.info(f'data after concat: {self.data}')
        logger.info(f'data shape after concat: {self.data.shape}')

    def __len__(self):
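        # Number of complete windows that fit into the concatenated token
        # stream when advancing `stride` tokens at a time.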
        return math.floor(
            (len(self.data) - self.block_size) / self.stride + 1)

    def __getitem__(self, item):
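        # Windows are generated left to right over the concatenated stream;
        # `begin_loc` and `prev_end_loc` are advanced as a side effect, so
        # this dataset assumes sequential access (shuffle=False, a single
        # worker). Tokens already covered by the previous window get
        # attention_mask == False and are excluded from the loss downstream.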
        end_loc = min(self.begin_loc + self.block_size, self.seq_len)
        trg_len = end_loc - self.prev_end_loc

        input_ids = self.data[self.begin_loc:end_loc]
        attention_mask = np.ones((len(input_ids), ), dtype=bool)
        attention_mask[:-trg_len] = False

        self.prev_end_loc = end_loc
        self.begin_loc = self.begin_loc + self.stride

        out_items = dict(
            input_ids=torch.tensor(input_ids),
            attention_mask=torch.tensor(attention_mask, dtype=bool),
            total_chr_num=self._total_chr_num,
        )
        return out_items

    @property
    def total_chr_num(self):
        return self._total_chr_num


class SWCELossInferencerOutputHandler:
    origin_prompt_dict = {}
    output_dict = {}
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, os.path.join(save_dir, filename))

    def save_results(self, loss: float, total_chr_num: int,
                     idx: Union[str, int]) -> None:
        self.results_dict[str(idx)] = {
            'loss': loss,
            'total_chr_num': total_chr_num
        }