[Feature] Adding support for LLM Compression Evaluation (#1108)
* fixed formatting based on pre-commit tests
* fixed typo in comments; reduced the number of models in the eval config
* fixed a bug in LLMCompressionDataset, where setting samples=None would result in passing test[:None] to load_dataset
* removed unnecessary variable in _format_table_pivot; changed lark_reporter message to English
parent 9c79224b39
commit 35c94d0cde

.gitignore (vendored): 3 additions
@@ -1,6 +1,7 @@
 output_*/
 outputs/
+scripts/
 icl_inference_output/
 .vscode/
 tmp/
@@ -10,6 +11,8 @@ configs/secrets.py
 configs/datasets/log.json
 configs/eval_debug*.py
 configs/viz_*.py
+configs/**/*_bkup.py
+opencompass/**/*_bkup.py
 data
 work_dirs
 outputs

configs/datasets/llm_compression/README.md (new file, 105 lines)

# LLM Compression

## Introduction

The following introduction comes from the abstract of [Compression Represents Intelligence Linearly](https://arxiv.org/abs/2404.09937):

>There is a belief that learning to compress well will lead to intelligence. Recently, language modeling has been shown to be equivalent to compression, which offers a compelling rationale for the success of large language models (LLMs): the development of more advanced language models is essentially enhancing compression which facilitates intelligence. ...our findings suggest that compression efficiency, as an unsupervised metric derived from raw text corpora, serves as a reliable evaluation measure that is linearly associated with the model capabilities. We open-source our compression datasets as well as our data collection pipelines to facilitate future researchers to assess compression properly.

## Official Links

- Paper: [Compression Represents Intelligence Linearly](https://arxiv.org/abs/2404.09937)
- GitHub Repository: [llm-compression-intelligence](https://github.com/hkust-nlp/llm-compression-intelligence)

## Overview and Usage

### Dataset

The dataset, which consists of three external corpora, can be downloaded using the following Python script:

```python
import os.path as osp

from datasets import load_dataset

data_path = "data/llm-compression"

subset_mapping = {
    'arxiv_math': ['arxiv_math'],
    'commoncraw': ['cc'],
    'python': ['python'],
}

for key, value in subset_mapping.items():
    llmc_dataset = load_dataset(r"hkust-nlp/llm-compression", name=value)
    llmc_dataset["test"].to_json(osp.join(data_path, f"{key}.jsonl"))
```

Note: Refer to the original [repository](https://github.com/hkust-nlp/llm-compression-intelligence) for more details on data collection and design.
### Inference

The inference stage (`SWCELossInferencer`) consists of the following key steps:

1. For each candidate model, obtain the encodings of each sample of the dataset using its tokenizer.
2. Concatenate the encodings of all samples into a single array and construct a PyTorch Dataset, where each `__getitem__` call returns a chunk of the array selected by a sliding window (see the sketch after this list). To reproduce the results of the original paper, set `block_size=1900` and `stride=512`.
3. For each batch, calculate the cross entropy loss based on the model logits and targets. The losses within each batch are reduced to a single loss by summation.
4. Output the losses and `total_chr_num` to `BPCEvaluator` for evaluation.
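For intuition, the sliding-window chunking in step 2 can be sketched as follows (a simplified illustration with made-up sizes; the actual logic lives in `SlidingWindowEvalDataset` in `icl_sw_ce_loss_inferencer.py`):

```python
import numpy as np

tokens = np.arange(10_000)  # stand-in for the concatenated token ids of all samples
block_size, stride = 1900, 512

chunks = []
begin = 0
# Same chunk count as SlidingWindowEvalDataset.__len__
for _ in range((len(tokens) - block_size) // stride + 1):
    end = min(begin + block_size, len(tokens))
    chunks.append(tokens[begin:end])  # one __getitem__ item
    begin += stride
```

Each chunk overlaps the previous one by `block_size - stride` tokens; the attention mask produced by the real dataset marks only the non-overlapping tail of each chunk as targets, so no token is scored twice.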
### Evaluation

`BPCEvaluator`: Using the total loss of each batch and the total number of characters in the original dataset from the inference stage, calculate the Bits per Character (BPC) metric for each model:

$$ BPC = \frac{TotalCrossEntropyLoss}{TotalCharacterNumber \times \log(2)} $$
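For a concrete sense of the arithmetic, here is a minimal sketch of the computation performed by `BPCEvaluator` (the numbers are made up for illustration):

```python
import numpy as np

# Summed cross entropy losses (in nats) per sliding-window batch,
# as output by SWCELossInferencer
losses = [1234.5, 1187.0, 1201.3]
total_chr_num = 5000  # characters in the original, untokenized text

# The division by log(2) converts nats to bits, giving bits per character
bpc = sum(losses) / (total_chr_num * np.log(2))
print(f'bpc = {bpc:.4f}')
```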
### Summarization

`LLMCompressionSummarizer` pivots the default summary table so that each row corresponds to a model and each column to a dataset, adds a column with the average BPC over the three corpora, and sorts the rows by ascending average BPC (see the results below).
### Config Files

1. Dataset config: `configs/datasets/llm_compression/llm_compression.py`
2. Evaluation config: `configs/eval_llm_compression.py`
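To launch the evaluation with these configs, a typical invocation (assuming the standard OpenCompass `run.py` entry point) is `python run.py configs/eval_llm_compression.py`.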
## Evaluation Results

```
    metric  version  model            commoncraw  python  arxiv_math  average
0   bpc     af04af   qwen1.5-32b-hf   0.5910      0.2584  0.4080      0.4191
1   bpc     af04af   qwen1.5-14b-hf   0.6459      0.2766  0.4310      0.4512
2   bpc     af04af   qwen-14b-hf      0.6197      0.2849  0.4498      0.4515
3   bpc     af04af   llama-30b-hf     0.5773      0.3212  0.4562      0.4516
4   bpc     af04af   llama-2-13b-hf   0.5807      0.3336  0.4752      0.4632
5   bpc     af04af   qwen1.5-7b-hf    0.6658      0.2935  0.4500      0.4698
6   bpc     af04af   qwen-7b-hf       0.6453      0.3088  0.4830      0.4790
7   bpc     af04af   llama-13b-hf     0.6083      0.3555  0.4865      0.4834
8   bpc     af04af   llama-2-7b-hf    0.6117      0.3536  0.4995      0.4883
9   bpc     af04af   llama-7b-hf      0.6285      0.3794  0.5096      0.5058
10  bpc     af04af   qwen1.5-1.8b-hf  0.7448      0.4029  0.5625      0.5701
11  bpc     af04af   qwen-1.8b-hf     0.7542      0.4175  0.5842      0.5853
12  bpc     af04af   qwen1.5-0.5b-hf  0.8102      0.4520  0.6181      0.6268
```
## FAQ

### I am getting this warning during inference; should I truncate long samples to `max_seq_len` to avoid further errors?

```
Token indices sequence length is longer than the specified maximum sequence length for this model. Running this sequence through the model will result in indexing errors
```

>A: This warning comes from the tokenizer, indicating that the input sequence length exceeds the model's maximum input length; it does not affect the operation of the tokenizer itself. For the loss calculation, as long as the `block_size` of the sliding window is set to less than `max_seq_len`, this warning can be safely ignored.
## Reference

```
@misc{huang2024compression,
      title={Compression Represents Intelligence Linearly},
      author={Yuzhen Huang and Jinghan Zhang and Zifei Shan and Junxian He},
      year={2024},
      eprint={2404.09937},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

configs/datasets/llm_compression/llm_compression.py (new file, 50 lines)
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import SWCELossInferencer
from opencompass.openicl.icl_evaluator import BPCEvaluator
from opencompass.datasets import LLMCompressionDataset


# The three corpora for llm_compression used in the original paper
# See configs/datasets/llm_compression/README.md for more details
subset_mapping = {
    'arxiv_math': ['arxiv_math'],
    'commoncraw': ['cc'],
    'python': ['python'],
}


# Build LLM Compression datasets
llm_compression_datasets = []
for _name in subset_mapping.keys():
    llm_cmp_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template="{content}",
        ),
        # No in-context example, using ZeroRetriever
        retriever=dict(type=ZeroRetriever),
        # Calculates cross entropy loss for each batch based on a sliding context window
        # Setting block_size=1900 and stride=512 according to the original paper
        inferencer=dict(type=SWCELossInferencer, block_size=1900, stride=512),
    )

    # Calculates Bits per Character (BPC) based on the CE loss from the inference stage
    llm_cmp_eval_cfg = dict(evaluator=dict(type=BPCEvaluator))

    llm_compression_datasets.append(
        dict(
            abbr=f"llm_compression-{_name}",
            type=LLMCompressionDataset,
            path="./data/llm-compression",
            name=_name,
            samples=None,  # Set to a small number of samples for testing
            reader_cfg=dict(
                input_columns=["content"],
                output_column=None,
            ),
            infer_cfg=llm_cmp_infer_cfg,
            eval_cfg=llm_cmp_eval_cfg,
        ))

del _name

configs/eval_llm_compression.py (new file, 65 lines)
from mmengine.config import read_base

with read_base():
    # LLM compression datasets
    from .datasets.llm_compression.llm_compression import llm_compression_datasets

    # Model configs
    from .models.qwen.hf_qwen1_5_7b import models as qwen1_5_7b
    from .models.qwen.hf_qwen1_5_14b import models as qwen1_5_14b
    from .models.hf_llama.hf_llama2_7b import models as llama2_7b
    from .models.hf_llama.hf_llama2_13b import models as llama2_13b


from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
from opencompass.summarizers import LLMCompressionSummarizer


# -------------Inference Stage ----------------------------------------
datasets = [*llm_compression_datasets]
workdir = 'outputs/llm_compression'

models = [
    *qwen1_5_7b,
    *qwen1_5_14b,
    *llama2_7b,
    *llama2_13b,
]

# Set custom batch_size and num_gpus for faster loss calculation
# A smaller batch_size should give more precise results, at the cost of slower evaluation
model_cfg = dict(
    batch_size=8,
    run_cfg=dict(num_gpus=4, num_procs=1)
)

for mdl in models:
    mdl.update(model_cfg)


infer = dict(
    # The OpenCompass implementation of BPC currently only supports
    # NaivePartitioner, as the sliding window approach requires the dataset
    # to be loaded sequentially. Using other partitioner types may produce
    # incorrect results.
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLInferTask),
        max_num_workers=256,  # Maximum concurrent evaluation task count
    ),
)


# -------------Evaluation Stage ----------------------------------------
eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLEvalTask),
        max_num_workers=256,
    )
)


# -------------Summarization Stage ----------------------------------------
summarizer = dict(type=LLMCompressionSummarizer)

opencompass/datasets/__init__.py: 1 addition
@@ -60,6 +60,7 @@ from .lambada import *  # noqa: F401, F403
 from .lawbench import *  # noqa: F401, F403
 from .lcsts import *  # noqa: F401, F403
 from .leval import *  # noqa: F401, F403
+from .llm_compression import LLMCompressionDataset  # noqa: F401, F403
 from .longbench import *  # noqa: F401, F403
 from .lveval import *  # noqa: F401, F403
 from .mastermath2024v1 import *  # noqa: F401, F403

opencompass/datasets/llm_compression.py (new file, 36 lines)
import os.path as osp
from typing import List

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class LLMCompressionDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: List[str] = None, samples: int = None):

        # Check if a data file for the given subset exists in the given path
        supported_extensions = ['json', 'jsonl']
        for ext in supported_extensions:
            filename = osp.join(
                path, f'{name}.{ext}')  # name refers to the data subset name

            if osp.exists(filename):
                break
        else:
            raise FileNotFoundError(f'{filename} not found.')

        # Guard against samples=None, which would otherwise pass the invalid
        # split string 'test[:None]' to load_dataset
        samples = 'test' if samples is None else f'test[:{samples}]'

        data_files = {'test': filename}
        dataset = load_dataset('json', data_files=data_files, split=samples)

        # Filter out empty samples
        dataset = dataset.filter(lambda example: len(example['content']) > 0)

        return dataset
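For illustration, a subset can also be loaded directly (assuming the data has been downloaded to `./data/llm-compression` as described in the README; `samples=10` here is an arbitrary value for a quick test):

```python
ds = LLMCompressionDataset.load(path='./data/llm-compression',
                                name='python',
                                samples=10)
print(ds[0]['content'][:100])
```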

opencompass/openicl/icl_evaluator/__init__.py: 1 addition
@@ -1,6 +1,7 @@
 from .icl_agent_evaluator import *  # noqa
 from .icl_aucroc_evaluator import AUCROCEvaluator  # noqa
 from .icl_base_evaluator import BaseEvaluator  # noqa
+from .icl_bpc_evaluator import BPCEvaluator  # noqa
 from .icl_circular_evaluator import CircularEvaluator  # noqa
 from .icl_em_evaluator import EMEvaluator  # noqa
 from .icl_hf_evaluator import *  # noqa

opencompass/openicl/icl_evaluator/icl_bpc_evaluator.py (new file, 32 lines)
from typing import List

import numpy as np

from opencompass.registry import ICL_EVALUATORS

from .icl_base_evaluator import BaseEvaluator


@ICL_EVALUATORS.register_module()
class BPCEvaluator(BaseEvaluator):

    def score(self, loss: List[float], total_chr_num: List[float]):
        """Calculate bits per character based on inference results.

        Args:
            loss (List[float]): Summed cross entropy loss per batch
                (one value per sliding context window).
            total_chr_num (List[float]): Total number of characters
                in the original dataset.

        Returns:
            Dict[str, float]: Bits per Character
        """
        total_loss = sum(loss)

        # The division by log(2) corrects for the constant shift due to the
        # natural log used in the PyTorch implementation of CrossEntropyLoss,
        # converting nats to bits
        bpc = total_loss / (total_chr_num[0] * np.log(2))

        return {'bpc': bpc}
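A minimal standalone use of the evaluator (made-up numbers; in practice the inferencer stores `total_chr_num` once per batch, and only the first entry is read):

```python
evaluator = BPCEvaluator()
print(evaluator.score(loss=[1234.5, 1187.0], total_chr_num=[5000.0, 5000.0]))
# -> {'bpc': ...}
```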

opencompass/openicl/icl_inferencer/__init__.py: 1 addition
@@ -9,4 +9,5 @@ from .icl_mink_percent_inferencer import MinKPercentInferencer  # noqa
 from .icl_ppl_inferencer import PPLInferencer  # noqa
 from .icl_ppl_only_inferencer import PPLOnlyInferencer  # noqa
 from .icl_sc_inferencer import SCInferencer  # noqa
+from .icl_sw_ce_loss_inferencer import SWCELossInferencer  # noqa
 from .icl_tot_inferencer import ToTInferencer  # noqa

opencompass/openicl/icl_inferencer/icl_sw_ce_loss_inferencer.py (new file, 352 lines)
"""Sliding Window Cross Entropy Loss Inferencer."""

import math
import os
from typing import List, Optional, Union

import mmengine
import numpy as np
import torch
from datasets import Dataset as HFDataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class SWCELossInferencer(BaseInferencer):
    """SWCELossInferencer class to calculate cross entropy loss per batch
    based on a sliding context window approach. This inferencer is usually
    used along with BPCEvaluator to calculate a model's Bits per Character
    metric on a given dataset.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed
            by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        save_every (:obj:`int`, optional): Save intermediate results every
            `save_every` iterations.
        block_size (:obj:`int`, optional): Block size (window size) of
            the sliding window on tokens.
        stride (:obj:`int`, optional): Stride (step size) of the
            sliding window on tokens.
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            block_size: Optional[int] = 1900,
            stride: Optional[int] = 512,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.block_size = block_size
        self.stride = stride
        self.save_every = save_every
        self.character_num = 0

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:

        # 1. Preparation for output logs
        output_handler = SWCELossInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        items_dataset = self.get_encoding_from_retriever_indices(
            ice_idx_list,
            retriever,
            max_seq_len=self.max_seq_len,
            prompt_template=prompt_template)

        # 3-1. Fetch and zip prompt & gold answer if output column exists
        ds_reader = retriever.dataset_reader

        assert ds_reader.output_column is None, (
            'SWCELossInferencer supports `output_column=None` only.')

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if os.path.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(
                    tmp_result_dict)  # rewrite tmp_dataset on every run

        # 4. Initialize torch dataset from items hf dataset
        logger.info('Starting dataset building process...')

        # block_size + 1 so that each window yields block_size input tokens
        # plus block_size shifted targets
        eval_dataset = SlidingWindowEvalDataset(
            items_dataset,
            block_size=self.block_size + 1,
            stride=self.stride,
        )

        # 4-1. Construct Dataloader
        dataloader = DataLoader(eval_dataset, self.batch_size, shuffle=False)

        # 5. Calculate total loss in each batch
        logger.info('Starting inference process...')

        device = self.model.model.device
        for ind, datum in enumerate(
                tqdm(dataloader, disable=not self.is_main_process)):

            if ind < index:
                continue

            encodings = datum['input_ids']  # encodings
            attention_mask = datum['attention_mask']

            # 5-1. Loss calculation by local model
            with torch.no_grad():
                if self.batch_size == 1:
                    input_ids = encodings[0:self.block_size].contiguous().to(
                        device)
                    targets = encodings[1:self.block_size +
                                        1].contiguous().long().to(device)
                    attention_mask = attention_mask[1:self.block_size +
                                                    1].contiguous().to(device)
                else:
                    input_ids = encodings[:, 0:self.block_size]\
                        .contiguous().to(device)
                    targets = encodings[:, 1:self.block_size +
                                        1].contiguous().long().to(device)
                    attention_mask = attention_mask[:, 1:self.block_size +
                                                    1].contiguous().to(device)

                logits = self.model.model(input_ids).logits
                loss = self._get_cross_entropy(logits,
                                               targets,
                                               attention_mask=attention_mask)
                loss = loss.cpu().item()

            logger.info(f'loss: {loss:.8f}')

            # 5-2. Save intermediate results
            output_handler.save_results(loss,
                                        datum['total_chr_num'][0].item(),
                                        index)
            index = index + 1

            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if os.path.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return [sample for sample in output_handler.results_dict.values()]

    def get_encoding_from_retriever_indices(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            max_seq_len: Optional[int] = None,
            prompt_template: Optional[PromptTemplate] = None,
            dtype: str = 'auto') -> HFDataset:

        vocab_size = self.model.tokenizer.vocab_size

        if dtype == 'auto':
            if vocab_size is None:
                raise ValueError("vocab_size cannot be None when dtype='auto'")
            if vocab_size < 65500:
                _dtype = np.uint16
            else:
                _dtype = np.int32
        else:
            _dtype = dtype

        item_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            cur_item_dict = {}

            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice='',
                ice_template=None,
                prompt_template=prompt_template)

            cur_item_dict['prompt'] = prompt

            # Get encodings from model tokenizer
            # As long as block_size < max_seq_len, we can safely ignore the
            # warning about token sequence length
            cur_item_dict['encoding'] = np.array(
                self.model.tokenizer.encode(prompt), dtype=_dtype)

            item_list.append(cur_item_dict)

        items = HFDataset.from_list(item_list)

        return items

    def _get_cross_entropy(self,
                           logits: torch.Tensor,
                           targets: torch.Tensor,
                           attention_mask: torch.Tensor = None):
        """Calculate cross entropy based on given logits, targets and
        attention_mask for BPC loss calculation.

        Args:
            logits (torch.Tensor): Model logits.
            targets (torch.Tensor): Target token ids.
            attention_mask (torch.Tensor, optional): Attention mask.
                Defaults to None.

        Returns:
            torch.Tensor: Total cross entropy on the given batch of logits
                and targets, reduced by summation.
        """
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(-1)
            # Masked positions get target -1 and are skipped via ignore_index
            targets = targets.masked_fill(~attention_mask, -1)

        return torch.nn.functional.cross_entropy(logits,
                                                 targets,
                                                 ignore_index=-1,
                                                 reduction='sum')


class SlidingWindowEvalDataset(Dataset):

    def __init__(self,
                 data: HFDataset,
                 block_size: int = 1900,
                 stride: int = 512) -> None:
        """SlidingWindowEvalDataset.

        Args:
            data (HFDataset): HuggingFace dataset containing input samples.
            block_size (int, optional): Sliding context window size.
                Defaults to 1900.
            stride (int, optional): Sliding context window step size.
                Defaults to 512.
        """
        self.block_size = block_size
        self.data = data
        self.stride = stride

        self._prepare()
        self.prev_end_loc = 0
        self.seq_len = len(self.data)
        self.begin_loc = 0

    def _prepare(self):
        """Prepare the evaluation dataset by calculating the total number of
        characters in the original text and concatenating all encodings into
        a single array."""
        self._curr_idx = 0
        self._arr = []

        self._total_chr_num = 0
        for i in range(len(self.data)):
            self._total_chr_num += len(self.data[i]['prompt'])

        logger.info(f'data Dataset before concat: {self.data}')

        self.data = np.concatenate([a['encoding'] for a in self.data], axis=0)

        logger.info(f'data after concat: {self.data}')
        logger.info(f'data after concat: {self.data.shape}')

    def __len__(self):
        return math.floor(
            (len(self.data) - self.block_size) / self.stride + 1)

    def __getitem__(self, item):
        end_loc = min(self.begin_loc + self.block_size, self.seq_len)
        trg_len = end_loc - self.prev_end_loc

        input_ids = self.data[self.begin_loc:end_loc]

        # Mask out the overlap with the previous window so that each token
        # contributes to the loss exactly once
        attention_mask = np.ones((len(input_ids), ), dtype=bool)
        attention_mask[:-trg_len] = False

        self.prev_end_loc = end_loc
        self.begin_loc = self.begin_loc + self.stride

        out_items = dict(
            input_ids=torch.tensor(input_ids),
            attention_mask=torch.tensor(attention_mask, dtype=bool),
            total_chr_num=self._total_chr_num,
        )
        return out_items

    @property
    def total_chr_num(self):
        return self._total_chr_num


class SWCELossInferencerOutputHandler:
    origin_prompt_dict = {}
    output_dict = {}
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, os.path.join(save_dir, filename))

    def save_results(self, loss: float, total_chr_num: int,
                     idx: Union[str, int]) -> None:
        self.results_dict[str(idx)] = {
            'loss': loss,
            'total_chr_num': total_chr_num
        }

opencompass/summarizers/__init__.py: 1 addition
@@ -1,4 +1,5 @@
 # flake8: noqa: F401, E501
 from .circular import CircularSummarizer  # noqa: F401
 from .default import DefaultSummarizer  # noqa: F401
+from .llm_compression import LLMCompressionSummarizer
 from .subjective import *  # noqa: F401

opencompass/summarizers/llm_compression.py (new file, 200 lines)
import getpass
import os.path as osp
from datetime import datetime
from typing import List, Optional

import mmengine
import pandas as pd
from mmengine import ConfigDict

from opencompass.utils import dataset_abbr_from_cfg
from opencompass.utils.prompt import get_prompt_hash

from .default import DefaultSummarizer


class LLMCompressionSummarizer(DefaultSummarizer):

    def __init__(self,
                 config: ConfigDict,
                 dataset_abbrs: Optional[List[str]] = None,
                 summary_groups: List = None,
                 prompt_db=None) -> None:
        summary_groups = [] if summary_groups is None else summary_groups
        super().__init__(config, dataset_abbrs, summary_groups, prompt_db)

    def _format_table(self, parsed_results, dataset_metrics,
                      dataset_eval_mode):
        dataset_abbrs = [
            dataset_abbr_from_cfg(dataset) for dataset in self.dataset_cfgs
        ]
        prompt_version = {
            dataset_abbr_from_cfg(d): get_prompt_hash(d)[:6]
            for d in self.dataset_cfgs
        }

        summarizer_dataset_abbrs = []
        if self.dataset_abbrs is None:
            # display all dataset metrics included in the config
            for dataset_abbr in dataset_abbrs:
                if dataset_abbr in dataset_metrics:
                    for metric in dataset_metrics[dataset_abbr]:
                        summarizer_dataset_abbrs.append((dataset_abbr, metric))
                else:
                    summarizer_dataset_abbrs.append((dataset_abbr, None))
            # along with all possible group metrics
            for dataset_abbr in dataset_metrics:
                for metric in dataset_metrics[dataset_abbr]:
                    if (dataset_abbr, metric) not in summarizer_dataset_abbrs:
                        summarizer_dataset_abbrs.append(
                            (dataset_abbr, metric))
        else:
            # follow the required order
            for item in self.dataset_abbrs:
                if isinstance(item, str):
                    summarizer_dataset_abbrs.append((item, None))
                elif isinstance(item, (list, tuple)):
                    summarizer_dataset_abbrs.append((item[0], item[1]))

        table = []
        header = ['dataset', 'version', 'metric', 'mode'] + self.model_abbrs
        table.append(header)
        for dataset_abbr, metric in summarizer_dataset_abbrs:
            if dataset_abbr not in dataset_metrics:
                table.append([dataset_abbr, '-', '-', '-'] +
                             ['-'] * len(self.model_abbrs))
                continue
            if metric is None:
                metric = dataset_metrics[dataset_abbr][0]
            elif metric in dataset_metrics[dataset_abbr]:
                pass
            else:
                table.append([dataset_abbr, '-', '-', '-'] +
                             ['-'] * len(self.model_abbrs))
                continue

            row = [
                dataset_abbr,
                prompt_version.get(dataset_abbr, '-'), metric,
                dataset_eval_mode.get(dataset_abbr, '-')
            ]
            for model_abbr in self.model_abbrs:
                if dataset_abbr in parsed_results[model_abbr]:
                    row.append(
                        f'{parsed_results[model_abbr][dataset_abbr][metric]:.04f}'  # noqa
                    )
                else:
                    row.append('-')
            table.append(row)
        return table

    def _format_table_pivot(self, table: List[List], decimals: int = 4):
        """Format table as a pandas dataframe and pivot so that columns are
        datasets and rows are models.

        Args:
            table (List[List]): List of lists containing summary table rows
                (including headers)

        Returns:
            pd.DataFrame: Summary dataframe sorted by ascending average BPC
        """
        headers = table.pop(0)
        table_df = pd.DataFrame(table, columns=headers)\
            .drop(columns=['mode'])

        dataset_names = {
            'llm_compression-commoncraw': 'commoncraw',
            'llm_compression-python': 'python',
            'llm_compression-arxiv_math': 'arxiv_math',
        }

        # Pivot model columns to rows
        table_df_long = table_df.melt(
            id_vars=['dataset', 'version', 'metric'], var_name='model')

        # Pivot dataset rows to columns
        table_df_wide = table_df_long\
            .pivot(index=['metric', 'version', 'model'], columns='dataset')\
            .droplevel(0, axis=1)\
            .reset_index()\
            .rename(columns=dataset_names)
        table_df_wide.columns.name = None

        # Calculate average BPC per model
        table_df_wide['average'] = table_df_wide[dataset_names.values()]\
            .apply(pd.to_numeric)\
            .mean(axis=1)\
            .round(decimals)

        table_df_wide = table_df_wide[[
            'metric', 'version', 'model', *dataset_names.values(), 'average'
        ]]

        return table_df_wide.sort_values(by='average')\
            .reset_index(drop=True)

    def _output_df_to_file(self, output_path: str, timestamp: str,
                           table: pd.DataFrame) -> None:
        """Output summary dataframe to file.

        Args:
            output_path (str): Output path
            timestamp (str): Timestamp for file suffix
            table (pd.DataFrame): Input dataframe
        """
        if output_path is None:
            output_path = osp.join(self.work_dir, 'summary')

        output_csv_path = osp.join(output_path,
                                   f'summary_pivot_{timestamp}.csv')

        output_dir = osp.split(output_path)[0]
        mmengine.mkdir_or_exist(output_dir)

        table.to_csv(output_csv_path, encoding='utf-8', index=False)
        self.logger.info(f'write csv to {osp.abspath(output_csv_path)}')

    def summarize(
            self,
            output_path: str = None,
            time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):  # noqa
        """Summarize evaluation results and format output table.

        Args:
            output_path (str, optional): Output path. Defaults to None.
            time_str (str, optional): Timestamp for file suffix. Defaults to
                datetime.now().strftime('%Y%m%d_%H%M%S').
        """
        # pick up results
        raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \
            self._pick_up_results()

        # calculate group metrics
        raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \
            self._calculate_group_metrics(
                raw_results,
                parsed_results,
                dataset_metrics,
                dataset_eval_mode)

        # format table
        table = self._format_table(parsed_results, dataset_metrics,
                                   dataset_eval_mode)

        # convert the list of lists to a pandas dataframe and pivot
        table_df = self._format_table_pivot(table)
        with pd.option_context('display.max_columns', 10):
            print(table_df)

        # format raw txt
        raw_txts = self._format_raw_txt(raw_results)

        # output to .text / .csv files
        self._output_to_file(output_path, time_str, table, raw_txts)
        self._output_df_to_file(output_path, time_str, table_df)

        if self.lark_reporter:
            content = f'Detailed evaluation summary for {getpass.getuser()}'
            content += f' saved to {osp.abspath(output_path)}'
            self.lark_reporter.post(content)
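For intuition, the melt-then-pivot reshaping in `_format_table_pivot` behaves like this toy example (hypothetical rows with two models and one dataset):

```python
import pandas as pd

table = [
    ['dataset', 'version', 'metric', 'model-a', 'model-b'],
    ['llm_compression-python', 'af04af', 'bpc', '0.30', '0.35'],
]
df = pd.DataFrame(table[1:], columns=table[0])
long = df.melt(id_vars=['dataset', 'version', 'metric'], var_name='model')
wide = long.pivot(index=['metric', 'version', 'model'],
                  columns='dataset').droplevel(0, axis=1).reset_index()
print(wide)  # one row per model, one column per dataset
```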