[Feature] Adding support for LLM Compression Evaluation (#1108)

* fixed formatting based on pre-commit tests

* fixed typo in comments; reduced the number of models in the eval config

* fixed a bug in LLMCompressionDataset, where setting samples=None would result in passing test[:None] to load_dataset

* removed unnecessary variable in _format_table_pivot; changed lark_reporter message to English
Author: Alexander Lam, 2024-04-30 10:51:01 +08:00 (committed by GitHub)
Commit: 35c94d0cde (parent: 9c79224b39)
12 changed files with 847 additions and 0 deletions

.gitignore

@@ -1,6 +1,7 @@
 output_*/
 outputs/
+scripts/
 icl_inference_output/
 .vscode/
 tmp/
@@ -10,6 +11,8 @@ configs/secrets.py
 configs/datasets/log.json
 configs/eval_debug*.py
 configs/viz_*.py
+configs/**/*_bkup.py
+opencompass/**/*_bkup.py
 data
 work_dirs
 outputs

configs/datasets/llm_compression/README.md

@@ -0,0 +1,105 @@
# LLM Compression
## Introduction
The following introduction comes from the abstract of [Compression Represents Intelligence Linearly](https://arxiv.org/abs/2404.09937):
>There is a belief that learning to compress well will lead to intelligence. Recently, language modeling has been shown to be equivalent to compression, which offers a compelling rationale for the success of large language models (LLMs): the development of more advanced language models is essentially enhancing compression which facilitates intelligence. ...our findings suggest that compression efficiency, as an unsupervised metric derived from raw text corpora, serves as a reliable evaluation measure that is linearly associated with the model capabilities. We open-source our compression datasets as well as our data collection pipelines to facilitate future researchers to assess compression properly.
## Official Links
- Paper: [Compression Represents Intelligence Linearly](https://arxiv.org/abs/2404.09937)
- GitHub Repository: [llm-compression-intelligence](https://github.com/hkust-nlp/llm-compression-intelligence)
## Overview and Usage
### Dataset
The dataset, which consists of three external corpora, can be downloaded using the following Python script:
```python
import os
import os.path as osp

from datasets import load_dataset

data_path = "data/llm-compression"
os.makedirs(data_path, exist_ok=True)  # ensure the output directory exists
subset_mapping = {
'arxiv_math': ['arxiv_math'],
'commoncraw': ['cc'],
'python': ['python'],
}
for key, value in subset_mapping.items():
llmc_dataset = load_dataset(r"hkust-nlp/llm-compression", name=value)
llmc_dataset["test"].to_json(osp.join(data_path, f"{key}.jsonl"))
```
Note: The script writes one `{subset}.jsonl` file per corpus under `data/llm-compression`, which matches the default `path` used by the dataset config. Refer to the original [repository](https://github.com/hkust-nlp/llm-compression-intelligence) for more details on data collection and design.
### Inference
The inference stage (`SWCELossInferencer`) consists of the following key steps:
1. For each candidate model, obtain the encodings of each sample of the dataset using the model's tokenizer.
2. Concatenate the encodings of all samples into a single array and construct a PyTorch Dataset, where each item returned by `__getitem__` is a chunk of the array selected by a sliding window. To reproduce results from the original paper, set `block_size=1900` and `stride=512`.
3. For each batch, calculate the cross entropy loss based on the model logits and targets. The losses within each batch are reduced to a single loss by summation.
4. Output the losses and `total_chr_num` to `BPCEvaluator` for evaluation (see the sketch below).
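
A minimal sketch of steps 2-3, assuming a Hugging Face-style causal LM whose forward pass returns `.logits`; the helper name `sliding_window_total_loss` and its variables are illustrative and not part of the OpenCompass API:

```python
import numpy as np
import torch
import torch.nn.functional as F


def sliding_window_total_loss(model, encodings: np.ndarray,
                              block_size: int = 1900,
                              stride: int = 512) -> float:
    """Sum the cross entropy (in nats) over sliding windows of `encodings`."""
    total_loss, prev_end = 0.0, 0
    for begin in range(0, max(len(encodings) - block_size, 1), stride):
        end = min(begin + block_size + 1, len(encodings))
        chunk = torch.tensor(encodings[begin:end], dtype=torch.long).unsqueeze(0)
        input_ids, targets = chunk[:, :-1], chunk[:, 1:].clone()
        # Only score tokens not already covered by the previous window
        targets[:, :-(end - prev_end)] = -100
        with torch.no_grad():
            logits = model(input_ids).logits
        total_loss += F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                      targets.reshape(-1),
                                      ignore_index=-100,
                                      reduction='sum').item()
        prev_end = end
    return total_loss
```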
### Evaluation
`BPCEvaluator`: Using the total loss for each batch and the total number of characters in the original dataset from the inference stage, calculate the Bits per Character (BPC) metric for each model:
$$ BPC = \frac{\text{TotalCrossEntropyLoss}}{\text{TotalCharacterNumber} \times \ln(2)} $$
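
For example, a hypothetical corpus of 1,000,000 characters with a summed cross entropy of 350,000 nats gives roughly 0.50 bits per character:

```python
import math

total_ce_loss_nats = 350_000.0  # hypothetical summed cross entropy (natural log)
total_characters = 1_000_000    # hypothetical character count of the corpus

bpc = total_ce_loss_nats / (total_characters * math.log(2))
print(f"{bpc:.4f}")  # ~0.5049
```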
### Summarization
`LLMCompressionSummarizer` collects the BPC scores from the evaluation stage, pivots them into a table with one row per model and one column per corpus, appends a per-model average, and writes the result to a `summary_pivot_*.csv` file sorted by ascending average BPC.
### Config Files
1. Dataset config: `configs/datasets/llm_compression/llm_compression.py`
2. Evaluation config: `configs/eval_llm_compression.py`
## Evaluation Results
```
metric version model commoncraw python arxiv_math average
0 bpc af04af qwen1.5-32b-hf 0.5910 0.2584 0.4080 0.4191
1 bpc af04af qwen1.5-14b-hf 0.6459 0.2766 0.4310 0.4512
2 bpc af04af qwen-14b-hf 0.6197 0.2849 0.4498 0.4515
3 bpc af04af llama-30b-hf 0.5773 0.3212 0.4562 0.4516
4 bpc af04af llama-2-13b-hf 0.5807 0.3336 0.4752 0.4632
5 bpc af04af qwen1.5-7b-hf 0.6658 0.2935 0.4500 0.4698
6 bpc af04af qwen-7b-hf 0.6453 0.3088 0.4830 0.4790
7 bpc af04af llama-13b-hf 0.6083 0.3555 0.4865 0.4834
8 bpc af04af llama-2-7b-hf 0.6117 0.3536 0.4995 0.4883
9 bpc af04af llama-7b-hf 0.6285 0.3794 0.5096 0.5058
10 bpc af04af qwen1.5-1.8b-hf 0.7448 0.4029 0.5625 0.5701
11 bpc af04af qwen-1.8b-hf 0.7542 0.4175 0.5842 0.5853
12 bpc af04af qwen1.5-0.5b-hf 0.8102 0.4520 0.6181 0.6268
```
## FAQ
### I am getting this warning during inference, should I truncate long samples to `max_seq_len` to avoid further errors?
```
Token indices sequence length is longer than the specified maximum sequence length for this model. Running this sequence through the model will result in indexing errors
```
>A: This warning comes from the tokenizer and indicates that the tokenized sequence is longer than the model's maximum input length; it does not affect the tokenizer's output. For loss calculation, the model only ever sees windows of `block_size` tokens, so as long as the sliding window's `block_size` is no greater than `max_seq_len`, this warning can be safely ignored.
## Reference
```
@misc{huang2024compression,
title={Compression Represents Intelligence Linearly},
author={Yuzhen Huang and Jinghan Zhang and Zifei Shan and Junxian He},
year={2024},
eprint={2404.09937},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

configs/datasets/llm_compression/llm_compression.py

@@ -0,0 +1,50 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import SWCELossInferencer
from opencompass.openicl.icl_evaluator import BPCEvaluator
from opencompass.datasets import LLMCompressionDataset
# The three corpora for llm_compression used in the original paper
# See configs/datasets/llm_compression/README.md for more details
subset_mapping = {
'arxiv_math': ['arxiv_math'],
'commoncraw': ['cc'],
'python': ['python'],
}
# Build LLM Compression datasets
llm_compression_datasets = []
for _name in subset_mapping.keys():
llm_cmp_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="{content}",
),
# No in-context example, using ZeroRetriever
retriever=dict(type=ZeroRetriever),
# Calculates cross entropy loss for each batch based on a sliding context window
# Setting block_size=1900 and stride=512 according to the original paper
inferencer=dict(type=SWCELossInferencer, block_size=1900, stride=512),
)
# Calculates Bits per Character (BPC) based on the CE loss from the inference stage
llm_cmp_eval_cfg = dict(evaluator=dict(type=BPCEvaluator))
llm_compression_datasets.append(
dict(
abbr=f"llm_compression-{_name}",
type=LLMCompressionDataset,
path="./data/llm-compression",
name=_name,
samples=None,  # set to a small integer for quick testing; None loads the full test split
reader_cfg=dict(
input_columns=["content"],
output_column=None,
),
infer_cfg=llm_cmp_infer_cfg,
eval_cfg=llm_cmp_eval_cfg,
))
del _name

configs/eval_llm_compression.py

@@ -0,0 +1,65 @@
from mmengine.config import read_base
with read_base():
# LLM compression datasets
from .datasets.llm_compression.llm_compression import llm_compression_datasets
# Model configs
from .models.qwen.hf_qwen1_5_7b import models as qwen1_5_7b
from .models.qwen.hf_qwen1_5_14b import models as qwen1_5_14b
from .models.hf_llama.hf_llama2_7b import models as llama2_7b
from .models.hf_llama.hf_llama2_13b import models as llama2_13b
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
from opencompass.summarizers import LLMCompressionSummarizer
# -------------Inference Stage ----------------------------------------
datasets = [*llm_compression_datasets]
workdir = 'outputs/llm_compression'
models = [
*qwen1_5_7b,
*qwen1_5_14b,
*llama2_7b,
*llama2_13b,
]
# Set custom batch_size and num_gpus for faster loss calculation
# A smaller batch_size may give slightly more precise results, at the cost of slower inference
model_cfg = dict(
batch_size=8,
run_cfg=dict(num_gpus=4, num_procs=1)
)
for mdl in models:
mdl.update(model_cfg)
infer = dict(
# The OpenCompass implementation of BPC currently only supports NaivePartitioner,
# as the sliding window approach requires the dataset to be loaded sequentially.
# Using other partitioner types may produce incorrect results.
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
task=dict(type=OpenICLInferTask),
max_num_workers=256,  # Maximum number of concurrent inference tasks
),
)
# -------------Evaluation Stage ----------------------------------------
eval = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
task=dict(type=OpenICLEvalTask),
max_num_workers=256,
)
)
# -------------Summarization Stage ----------------------------------------
summarizer = dict(type=LLMCompressionSummarizer)
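# Illustrative launch command (standard OpenCompass entry point, assuming the
# command is run from the repository root):
#   python run.py configs/eval_llm_compression.py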

opencompass/datasets/__init__.py

@@ -60,6 +60,7 @@ from .lambada import *  # noqa: F401, F403
 from .lawbench import *  # noqa: F401, F403
 from .lcsts import *  # noqa: F401, F403
 from .leval import *  # noqa: F401, F403
+from .llm_compression import LLMCompressionDataset  # noqa: F401, F403
 from .longbench import *  # noqa: F401, F403
 from .lveval import *  # noqa: F401, F403
 from .mastermath2024v1 import *  # noqa: F401, F403

opencompass/datasets/llm_compression.py

@@ -0,0 +1,36 @@
import os.path as osp
from typing import List
from datasets import load_dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class LLMCompressionDataset(BaseDataset):
@staticmethod
def load(path: str, name: List[str] = None, samples: int = None):
# Check if file exists in the given path
supported_extensions = ['json', 'jsonl']
for ext in supported_extensions:
filename = osp.join(
path, f'{name}.{ext}') # name refers to data subset name
if osp.exists(filename):
break
else:
raise FileNotFoundError(f'{filename} not found.')
samples = 'test' if samples is None else f'test[:{samples}]'
data_files = {'test': filename}
dataset = load_dataset('json', data_files=data_files, split=samples)
# Filter out empty samples
dataset = dataset.filter(lambda example: len(example['content']) > 0)
return dataset
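# Example (illustrative) of how the loader above could be invoked directly,
# assuming the download script has produced data/llm-compression/python.jsonl:
#   ds = LLMCompressionDataset.load(
#       path='./data/llm-compression', name='python', samples=10)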

opencompass/openicl/icl_evaluator/__init__.py

@@ -1,6 +1,7 @@
 from .icl_agent_evaluator import *  # noqa
 from .icl_aucroc_evaluator import AUCROCEvaluator  # noqa
 from .icl_base_evaluator import BaseEvaluator  # noqa
+from .icl_bpc_evaluator import BPCEvaluator  # noqa
 from .icl_circular_evaluator import CircularEvaluator  # noqa
 from .icl_em_evaluator import EMEvaluator  # noqa
 from .icl_hf_evaluator import *  # noqa

opencompass/openicl/icl_evaluator/icl_bpc_evaluator.py

@@ -0,0 +1,32 @@
from typing import List
import numpy as np
from opencompass.registry import ICL_EVALUATORS
from .icl_base_evaluator import BaseEvaluator
@ICL_EVALUATORS.register_module()
class BPCEvaluator(BaseEvaluator):
def score(self, loss: List[float], total_chr_num: List[float]):
"""Calculate bits per character based on inference results.
Args:
loss (List[float]): CrossEntropyLoss per batch x sliding
context window
total_chr_num (List[float]): Total number of characters in the
    original dataset, repeated for each batch (only the first
    element is used).
Returns:
Dict[str, float]: Bits per Character
"""
total_loss = sum(loss)
# Multiplying by log(2) to correct for the constant shift
# due to natural log used in the PyTorch implementation
# of CrossEntropyLoss
bpc = total_loss / (total_chr_num[0] * np.log(2))
return {'bpc': bpc}
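# Example (illustrative): per-batch summed losses in nats over a corpus of
# 1,000,000 characters; the character count is repeated for every batch.
#   BPCEvaluator().score(loss=[180_000.0, 170_000.0],
#                        total_chr_num=[1_000_000, 1_000_000])
#   -> {'bpc': 0.5049...}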

opencompass/openicl/icl_inferencer/__init__.py

@@ -9,4 +9,5 @@ from .icl_mink_percent_inferencer import MinKPercentInferencer  # noqa
 from .icl_ppl_inferencer import PPLInferencer  # noqa
 from .icl_ppl_only_inferencer import PPLOnlyInferencer  # noqa
 from .icl_sc_inferencer import SCInferencer  # noqa
+from .icl_sw_ce_loss_inferencer import SWCELossInferencer  # noqa
 from .icl_tot_inferencer import ToTInferencer  # noqa

opencompass/openicl/icl_inferencer/icl_sw_ce_loss_inferencer.py

@@ -0,0 +1,352 @@
"""Sliding Window Cross Entropy Loss Inferencer."""
import math
import os
from typing import List, Optional, Tuple, Union
import mmengine
import numpy as np
import torch
from datasets import Dataset as HFDataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class SWCELossInferencer(BaseInferencer):
"""SWCELossInferencer class to calculate cross entropy loss per batch based
on a sliding context window approach. This Inferencer is usually used along
with BPCEvaluator to calculate a model's Bits per Character (BPC) metric on a
given dataset.
Attributes:
model (:obj:`BaseModel`, optional): The module to inference.
max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
the LM.
batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
save_every (:obj:`int`, optional): Save intermediate results every `save_every` iterations.
block_size (:obj:`int`, optional): Block size (window size) of
the sliding window on tokens
stride (:obj:`int`, optional): Stride (step size) of the
sliding window on tokens
"""
def __init__(
self,
model: BaseModel,
max_seq_len: Optional[int] = None,
batch_size: Optional[int] = 1,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
block_size: Optional[int] = 1900,
stride: Optional[int] = 512,
**kwargs) -> None:
super().__init__(
model=model,
max_seq_len=max_seq_len,
batch_size=batch_size,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.block_size = block_size
self.stride = stride
self.save_every = save_every
self.character_num = 0
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = SWCELossInferencerOutputHandler()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
items_dataset = self.get_encoding_from_retriever_indices(
ice_idx_list,
retriever,
max_seq_len=self.max_seq_len,
prompt_template=prompt_template)
# 3-1. Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader
assert ds_reader.output_column is None, (
'SWCELossInferencer supports `output_column=None` only.')
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if os.path.exists(tmp_json_filepath):
# TODO: move resume to output handler
try:
tmp_result_dict = mmengine.load(tmp_json_filepath)
except Exception:
pass
else:
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)  # resume from the last saved batch
# 4. Initialize torch dataset from items hf dataset
logger.info('Starting dataset building process...')
eval_dataset = SlidingWindowEvalDataset(
items_dataset,
block_size=self.block_size + 1,
stride=self.stride,
)
# 4-1. Construct Dataloader
dataloader = DataLoader(eval_dataset, self.batch_size, shuffle=False)
# 5. Calculate total loss in each batch
logger.info('Starting inference process...')
device = self.model.model.device
for ind, datum in enumerate(
tqdm(dataloader, disable=not self.is_main_process)):
if ind < index:
continue
encodings = datum['input_ids'] # encodings
attention_mask = datum['attention_mask']
# 5-1. Loss calculation by local model
with torch.no_grad():
if self.batch_size == 1:
input_ids = encodings[0:self.block_size].contiguous().to(
device)
targets = encodings[1:self.block_size +
1].contiguous().long().to(device)
attention_mask = attention_mask[1:self.block_size +
1].contiguous().to(device)
else:
input_ids = encodings[:,
0:self.block_size].contiguous().to(
device)
targets = encodings[:, 1:self.block_size +
1].contiguous().long().to(device)
attention_mask = attention_mask[:, 1:self.block_size +
1].contiguous().to(device)
logits = self.model.model(input_ids).logits
loss = self._get_cross_entropy(logits,
targets,
attention_mask=attention_mask)
loss = loss.cpu().item()
logger.info(f'loss: {loss:.8f}')
# 5-2. Save intermediate results
output_handler.save_results(loss, datum['total_chr_num'][0].item(),
index)
index = index + 1
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if os.path.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return [sample for sample in output_handler.results_dict.values()]
def get_encoding_from_retriever_indices(
self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
max_seq_len: Optional[int] = None,
prompt_template: Optional[PromptTemplate] = None,
dtype: str = 'auto') -> Tuple[List, List]:
vocab_size = self.model.tokenizer.vocab_size
if dtype == 'auto':
if vocab_size is None:
raise ValueError("vocab_size cannot be None when dtype='auto'")
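# Use a 16-bit dtype when all token ids fit, halving the memory
# footprint of the concatenated encoding array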
if vocab_size is not None and vocab_size < 65500:
_dtype = np.uint16
else:
_dtype = np.int32
else:
_dtype = dtype
item_list = []
for idx, ice_idx in enumerate(ice_idx_list):
cur_item_dict = {}
prompt = retriever.generate_prompt_for_generate_task(
idx,
ice='',
ice_template=None,
prompt_template=prompt_template)
cur_item_dict['prompt'] = prompt
# Get encodings from model tokenizer
# As long as block_size <= max_seq_len, we can safely ignore the
# tokenizer warning about the token sequence length
cur_item_dict['encoding'] = np.array(
self.model.tokenizer.encode(prompt), dtype=_dtype)
item_list.append(cur_item_dict)
items = HFDataset.from_list(item_list)
return items
def _get_cross_entropy(self,
logits: torch.Tensor,
targets: torch.Tensor,
attention_mask: torch.Tensor = None):
"""Calculate cross entropy based on given logits, targets and
attention_mask for BPC loss calculation.
Args:
logits (torch.Tensor): Model logits
targets (torch.Tensor): Targets
attention_mask (torch.Tensor, optional): Attention mask.
Defaults to None.
Returns:
torch.Tensor: Total cross entropy on the given batch of logits and
targets reduced by summation
"""
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
if attention_mask is not None:
attention_mask = attention_mask.reshape(-1)
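# Positions already scored by the previous window have attention_mask
# False; setting their targets to -1 excludes them via ignore_index so
# overlapping tokens are not double counted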
targets = targets.masked_fill(~attention_mask, -1)
return torch.nn.functional.cross_entropy(logits,
targets,
ignore_index=-1,
reduction='sum')
class SlidingWindowEvalDataset(Dataset):
def __init__(self,
data: HFDataset,
block_size: int = 1900,
stride: int = 512) -> None:
"""SlidingWindowEvalDataset.
Args:
data (HFDataset): HuggingFace dataset containing input samples
block_size (int, optional): Sliding context window size.
Defaults to 1900.
stride (int, optional): Sliding context window step size.
Defaults to 512.
"""
self.block_size = block_size
self.data = data
self.stride = stride
self._prepare()
self.prev_end_loc = 0
self.seq_len = len(self.data)
self.begin_loc = 0
def _prepare(self):
"""Prepare evaluation dataset by calculating total number of characters
and from original text and concatenating encodings into a single
array."""
self._curr_idx = 0
self._arr = []
self._total_chr_num = 0
for i in range(len(self.data)):
self._total_chr_num += len(self.data[i]['prompt'])
logger.info(f'data Dataset before concat: {self.data}')
self.data = np.concatenate([a['encoding'] for a in self.data], axis=0)
logger.info(f'data after concat: {self.data}')
logger.info(f'data shape after concat: {self.data.shape}')
def __len__(self):
return math.floor((len(self.data) - self.block_size) / self.stride + 1)
def __getitem__(self, item):
end_loc = min(self.begin_loc + self.block_size, self.seq_len)
trg_len = end_loc - self.prev_end_loc
input_ids = self.data[self.begin_loc:end_loc]
attention_mask = np.ones((len(input_ids), ), dtype=bool)
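# Only the last trg_len tokens are new in this window; earlier positions
# overlap with the previous window and are masked out of the loss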
attention_mask[:-trg_len] = False
self.prev_end_loc = end_loc
self.begin_loc = self.begin_loc + self.stride
out_items = dict(
input_ids=torch.tensor(input_ids),
attention_mask=torch.tensor(attention_mask, dtype=bool),
total_chr_num=self._total_chr_num,
)
return out_items
@property
def total_chr_num(self):
return self._total_chr_num
class SWCELossInferencerOutputHandler:
origin_prompt_dict = {}
output_dict = {}
results_dict = {}
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, os.path.join(save_dir, filename))
def save_results(self, loss: float, total_chr_num: int,
idx: Union[str, int]) -> None:
self.results_dict[str(idx)] = {
'loss': loss,
'total_chr_num': total_chr_num
}

opencompass/summarizers/__init__.py

@@ -1,4 +1,5 @@
 # flake8: noqa: F401, E501
 from .circular import CircularSummarizer  # noqa: F401
 from .default import DefaultSummarizer  # noqa: F401
+from .llm_compression import LLMCompressionSummarizer
 from .subjective import *  # noqa: F401

opencompass/summarizers/llm_compression.py

@@ -0,0 +1,200 @@
import getpass
import os.path as osp
from datetime import datetime
from typing import List, Optional
import mmengine
import pandas as pd
from mmengine import ConfigDict
from opencompass.utils import dataset_abbr_from_cfg
from opencompass.utils.prompt import get_prompt_hash
from .default import DefaultSummarizer
class LLMCompressionSummarizer(DefaultSummarizer):
def __init__(self,
config: ConfigDict,
dataset_abbrs: Optional[List[str]] = None,
summary_groups: List = None,
prompt_db=None) -> None:
summary_groups = [] if summary_groups is None else summary_groups
super().__init__(config, dataset_abbrs, summary_groups, prompt_db)
def _format_table(self, parsed_results, dataset_metrics,
dataset_eval_mode):
dataset_abbrs = [
dataset_abbr_from_cfg(dataset) for dataset in self.dataset_cfgs
]
prompt_version = {
dataset_abbr_from_cfg(d): get_prompt_hash(d)[:6]
for d in self.dataset_cfgs
}
summarizer_dataset_abbrs = []
if self.dataset_abbrs is None:
# display all dataset metrics included in the config
for dataset_abbr in dataset_abbrs:
if dataset_abbr in dataset_metrics:
for metric in dataset_metrics[dataset_abbr]:
summarizer_dataset_abbrs.append((dataset_abbr, metric))
else:
summarizer_dataset_abbrs.append((dataset_abbr, None))
# along with all possible group metrics
for dataset_abbr in dataset_metrics:
for metric in dataset_metrics[dataset_abbr]:
if (dataset_abbr, metric) not in summarizer_dataset_abbrs:
summarizer_dataset_abbrs.append((dataset_abbr, metric))
else:
# follow the required order
for item in self.dataset_abbrs:
if isinstance(item, str):
summarizer_dataset_abbrs.append((item, None))
elif isinstance(item, (list, tuple)):
summarizer_dataset_abbrs.append((item[0], item[1]))
table = []
header = ['dataset', 'version', 'metric', 'mode'] + self.model_abbrs
table.append(header)
for dataset_abbr, metric in summarizer_dataset_abbrs:
if dataset_abbr not in dataset_metrics:
table.append([dataset_abbr, '-', '-', '-'] +
['-'] * len(self.model_abbrs))
continue
if metric is None:
metric = dataset_metrics[dataset_abbr][0]
elif metric in dataset_metrics[dataset_abbr]:
pass
else:
table.append([dataset_abbr, '-', '-', '-'] +
['-'] * len(self.model_abbrs))
continue
row = [
dataset_abbr,
prompt_version.get(dataset_abbr, '-'), metric,
dataset_eval_mode.get(dataset_abbr, '-')
]
for model_abbr in self.model_abbrs:
if dataset_abbr in parsed_results[model_abbr]:
row.append(
f'{parsed_results[model_abbr][dataset_abbr][metric]:.04f}' # noqa
)
else:
row.append('-')
table.append(row)
return table
def _format_table_pivot(self, table: List[List], decimals: int = 4):
"""Format table as a pandas dataframe and pivot so that columns are
datasets and rows are models.
Args:
table (List[List]): List of lists containing summary table rows
    (including headers)
decimals (int, optional): Number of decimal places for the average
    column. Defaults to 4.
Returns:
pd.DataFrame: Summary dataframe sorted by ascending average BPC
"""
headers = table.pop(0)
table_df = pd.DataFrame(table, columns=headers)\
.drop(columns=['mode'])
dataset_names = {
'llm_compression-commoncraw': 'commoncraw',
'llm_compression-python': 'python',
'llm_compression-arxiv_math': 'arxiv_math',
}
# Pivot model columns to rows
table_df_long = table_df.melt(id_vars=['dataset', 'version', 'metric'],
var_name='model')
# Pivot dataset rows to columns
table_df_wide = table_df_long\
.pivot(index=['metric', 'version', 'model'], columns='dataset')\
.droplevel(0, axis=1)\
.reset_index()\
.rename(columns=dataset_names)
table_df_wide.columns.name = None
# Calculate average BPC per model
table_df_wide['average'] = table_df_wide[dataset_names.values()]\
.apply(pd.to_numeric)\
.mean(axis=1)\
.round(decimals)
table_df_wide = table_df_wide[[
'metric', 'version', 'model', *dataset_names.values(), 'average'
]]
return table_df_wide.sort_values(by='average')\
.reset_index(drop=True)
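# Illustrative shape of the pivoted table (one row per model, one BPC
# column per corpus, sorted by ascending average):
#   metric version model           commoncraw python arxiv_math average
#   bpc    af04af  qwen1.5-32b-hf  0.5910     0.2584 0.4080     0.4191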
def _output_df_to_file(self, output_path: str, timestamp: str,
table: pd.DataFrame) -> None:
"""Output summary dataframe to file.
Args:
output_path (str): Output path
timestamp (str): Timestamp for file suffix
table (pd.DataFrame): Input dataframe
"""
if output_path is None:
output_path = osp.join(self.work_dir, 'summary')
output_csv_path = osp.join(output_path,
f'summary_pivot_{timestamp}.csv')
output_dir = osp.split(output_path)[0]
mmengine.mkdir_or_exist(output_dir)
table.to_csv(output_csv_path, encoding='utf-8', index=False)
self.logger.info(f'write csv to {osp.abspath(output_csv_path)}')
def summarize(
self,
output_path: str = None,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')): # noqa
"""Summarize evaluation results and format output table.
Args:
output_path (str, optional): Output path. Defaults to None.
time_str (str, optional): Timestamp for file suffix. Defaults to
datetime.now().strftime('%Y%m%d_%H%M%S').
"""
# pick up results
raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \
self._pick_up_results()
# calculate group metrics
raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \
self._calculate_group_metrics(
raw_results,
parsed_results,
dataset_metrics,
dataset_eval_mode)
# format table
table = self._format_table(parsed_results, dataset_metrics,
dataset_eval_mode)
# convert to list of lists to pandas dataframe and pivot
table_df = self._format_table_pivot(table)
with pd.option_context('display.max_columns', 10):
print(table_df)
# format raw txt
raw_txts = self._format_raw_txt(raw_results)
# output to .text / .csv files
self._output_to_file(output_path, time_str, table, raw_txts)
self._output_df_to_file(output_path, time_str, table_df)
if self.lark_reporter:
content = f'Detailed evaluation summary for {getpass.getuser()}'
content += f' saved to {osp.abspath(output_path)}'
self.lark_reporter.post(content)