mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support Dataset Repeat and G-Pass Compute for Each Evaluator (#1886)
* support dataset repeat and g-pass compute for each evaluator
* fix pre-commit errors
* delete print
* delete gpassk_evaluator and fix potential errors
* change `repeat` to `n`
* fix `repeat` to `n` in openicl_eval
* update doc for multi-run and g-pass
* update latex equation in doc
* update eng doc for multi-run and g-pass
* update datasets.md
* fix multi-line equation
* fix multi-line equation in zh_cn user_guides
* modify pre-commit-zh-cn
* recover pre-commit and edit math expr in doc
* del [TIP]
* del cite tag in doc
* del extract_model param in livemathbench config
parent 6042b88e58
commit 73c80953c6
English user guide:

@@ -81,3 +81,43 @@ datasets += cmnli_datasets
 Users can choose different abilities, different datasets and different evaluation methods configuration files to build the part of the dataset in the evaluation script according to their needs.
 
 For information on how to start an evaluation task and how to evaluate self-built datasets, please refer to the relevant documents.
+
+### Multiple Evaluations on the Dataset
+
+In the dataset configuration, you can set the parameter `n` to perform multiple evaluations on the same dataset and return the averaged metrics, for example:
+
+```python
+afqmc_datasets = [
+    dict(
+        abbr="afqmc-dev",
+        type=AFQMCDatasetV2,
+        path="./data/CLUE/AFQMC/dev.json",
+        n=10,  # Perform 10 evaluations
+        reader_cfg=afqmc_reader_cfg,
+        infer_cfg=afqmc_infer_cfg,
+        eval_cfg=afqmc_eval_cfg,
+    ),
+]
+```
+
+Additionally, for binary evaluation metrics (such as accuracy, pass-rate, etc.), you can also set the parameter `k` in conjunction with `n` for [G-Pass@k](http://arxiv.org/abs/2412.13147) evaluation. The formula for G-Pass@k is:
+
+```{math}
+\text{G-Pass@}k_\tau=E_{\text{Data}}\left[ \sum_{j=\lceil \tau \cdot k \rceil}^c \frac{{c \choose j} \cdot {n - c \choose k - j}}{{n \choose k}} \right],
+```
+
+where $n$ is the number of evaluation runs and $c$ is the number of those $n$ runs that passed (i.e. were correct). An example configuration is as follows:
+
+```python
+aime2024_datasets = [
+    dict(
+        abbr='aime2024',
+        type=Aime2024Dataset,
+        path='opencompass/aime2024',
+        k=[2, 4],  # Return results for G-Pass@2 and G-Pass@4
+        n=12,  # 12 evaluations
+        ...
+    )
+]
+```
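As a quick numerical cross-check of the G-Pass@k formula above, here is a minimal sketch assuming only `numpy` and `scipy`; the helper name `g_pass_at_k` is chosen for illustration (the commit's own implementation appears in the `icl_base_evaluator.py` hunk further below):

```python
import numpy as np
from math import comb
from scipy.stats import hypergeom


def g_pass_at_k(n: int, c: int, k: int, tau: float) -> float:
    """P(at least ceil(tau * k) of k generations sampled without
    replacement are correct), given c of the n generations are correct."""
    m = max(int(np.ceil(tau * k)), 1)
    if m > min(c, k) or k > n:
        return 0.0
    # survival function of the hypergeometric distribution = tail sum in the doc
    return hypergeom.sf(m - 1, n, c, k)


# brute-force check against the summation in the formula
n, c, k, tau = 12, 6, 4, 0.5
m = int(np.ceil(tau * k))
expected = sum(comb(c, j) * comb(n - c, k - j)
               for j in range(m, min(c, k) + 1)) / comb(n, k)
assert abs(g_pass_at_k(n, c, k, tau) - expected) < 1e-9
```

The survival function `hypergeom.sf(m - 1, n, c, k)` is exactly the tail sum in the formula, which avoids enumerating binomial coefficients for large `n`.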
Chinese (zh_cn) user guide:

@@ -81,3 +81,42 @@ datasets += cmnli_datasets
 Users can choose configuration files for different abilities, different datasets, and different evaluation methods as needed to build the dataset part of the evaluation script.
 
 For how to launch an evaluation task and how to evaluate self-built datasets, please refer to the relevant documents.
+
+### Multiple Evaluations on the Dataset
+
+In the dataset configuration, you can set the parameter `n` to evaluate the same dataset multiple times and return the averaged metrics, for example:
+
+```python
+afqmc_datasets = [
+    dict(
+        abbr="afqmc-dev",
+        type=AFQMCDatasetV2,
+        path="./data/CLUE/AFQMC/dev.json",
+        n=10,  # Perform 10 evaluations
+        reader_cfg=afqmc_reader_cfg,
+        infer_cfg=afqmc_infer_cfg,
+        eval_cfg=afqmc_eval_cfg,
+    ),
+]
+```
+
+Additionally, for binary evaluation metrics (e.g. accuracy, pass-rate), you can also set the parameter `k` together with `n` for [G-Pass@k](http://arxiv.org/abs/2412.13147) evaluation. The G-Pass@k formula is:
+
+```{math}
+\text{G-Pass@}k_\tau=E_{\text{Data}}\left[ \sum_{j=\lceil \tau \cdot k \rceil}^c \frac{{c \choose j} \cdot {n - c \choose k - j}}{{n \choose k}} \right],
+```
+
+where $n$ is the number of evaluation runs and $c$ is the number of those $n$ runs that passed or were correct. An example configuration is as follows:
+
+```python
+aime2024_datasets = [
+    dict(
+        abbr='aime2024',
+        type=Aime2024Dataset,
+        path='opencompass/aime2024',
+        k=[2, 4],  # Return results for G-Pass@2 and G-Pass@4
+        n=12,  # 12 evaluations
+        ...
+    )
+]
+```
LiveMathBench config (G-Pass@k run):

@@ -9,7 +9,7 @@ livemathbench_dataset = dict(
     type=LiveMathBenchDataset,
     path='',
     k=16,
-    replication=3,
+    n=48,
     dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
     dataset_languages=['cn', 'en'],
     cot=True,
@@ -38,13 +38,7 @@ livemathbench_dataset = dict(
         evaluator=dict(
             type=LiveMathBenchEvaluator,
             model_name='',
-            url=[],
-            use_extract_model=False,
-            extract_url=[],
-            extract_model_name='',
-            k=[4, 8, 16],
-            replication=3,
-            thresholds=[0.0, 0.25, 0.5, 0.75, 1.0]
+            url=[]
         )
     )
 )
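Here `n=48` keeps the previous sampling budget: the deleted `GPassKEvaluator` (see the last hunk below) generated `self.n = max(k) * replication` samples per problem, i.e. `16 * 3 = 48`, which the dataset-level `n` now states directly.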
LiveMathBench config (single-run variant):

@@ -9,7 +9,7 @@ livemathbench_dataset = dict(
     type=LiveMathBenchDataset,
     path='',
     k=1,
-    replication=1,
+    n=1,
     dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
     dataset_languages=['cn', 'en'],
     cot=True,
@@ -38,13 +38,7 @@ livemathbench_dataset = dict(
         evaluator=dict(
             type=LiveMathBenchEvaluator,
             model_name='',
-            url=[],
-            use_extract_model=False,
-            extract_url=[],
-            extract_model_name='',
-            k=[1],
-            replication=1,
-            thresholds=[0.0]
+            url=[]
         )
     )
 )
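In this single-run variant, `k=1` and `n=1` reduce evaluation to plain single-pass scoring: the new `BaseEvaluator.evaluate` (below) only computes G-Pass@k metrics when `n > 1` and `k > 1`.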
BaseDataset:

@@ -1,4 +1,5 @@
-from typing import Dict, Optional, Union
+from copy import deepcopy
+from typing import Dict, List, Optional, Union
 
 from datasets import Dataset, DatasetDict
 
@@ -7,8 +8,39 @@ from opencompass.openicl import DatasetReader
 
 class BaseDataset:
 
-    def __init__(self, reader_cfg: Optional[Dict] = {}, **kwargs):
-        self.dataset = self.load(**kwargs)
+    def __init__(self,
+                 reader_cfg: Optional[Dict] = {},
+                 k: Union[int, List[int]] = 1,
+                 n: int = 1,
+                 **kwargs):
+        abbr = kwargs.pop('abbr', 'dataset')
+        dataset = self.load(**kwargs)
+        # duplicate each example n times (no-op when n == 1)
+        assert (max(k) if isinstance(k, List) else
+                k) <= n, 'Maximum value of `k` must be less than or equal to `n`'
+        if isinstance(dataset, Dataset):
+            examples = []
+            for idx, example in enumerate(dataset):
+                if 'subdivision' not in example:
+                    example['subdivision'] = abbr
+                if 'idx' not in example:
+                    example['idx'] = idx
+                examples.append(example)
+            examples = sum([deepcopy(examples) for _ in range(n)], [])
+            self.dataset = Dataset.from_list(examples)
+        else:
+            self.dataset = DatasetDict()
+            for key in dataset:
+                examples = []
+                for idx, example in enumerate(dataset[key]):
+                    if 'subdivision' not in example:
+                        example['subdivision'] = f'{abbr}_{key}'
+                    if 'idx' not in example:
+                        example['idx'] = idx
+                    examples.append(example)
+                print(abbr, key, len(examples))
+                examples = sum([deepcopy(examples) for _ in range(n)], [])
+                self.dataset[key] = Dataset.from_list(examples)
         self._init_reader(**reader_cfg)
 
     def _init_reader(self, **kwargs):
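A minimal sketch of the repetition behavior this constructor now implements (standalone, using only `datasets`; the example rows are hypothetical), which also mirrors how `BaseEvaluator.evaluate` later slices run `i` back out:

```python
from copy import deepcopy
from datasets import Dataset

# two hypothetical examples, tagged the way BaseDataset tags them
examples = [{'question': 'q0', 'subdivision': 'demo', 'idx': 0},
            {'question': 'q1', 'subdivision': 'demo', 'idx': 1}]
n = 3
# n back-to-back copies: run i occupies rows [i * 2, (i + 1) * 2)
repeated = sum([deepcopy(examples) for _ in range(n)], [])
ds = Dataset.from_list(repeated)
print(len(ds), ds[0]['idx'], ds[2]['idx'])  # 6 0 0
```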
LiveMathBench dataset and evaluator:

@@ -1,11 +1,9 @@
 import os
 import warnings
-from collections import OrderedDict
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from copy import deepcopy
 from functools import partial
 from itertools import product
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List
 
 import jsonlines
 import mmengine
@@ -14,7 +12,7 @@ from datasets import Dataset, load_dataset
 
 from opencompass.datasets.math import MATHAgentEvaluator, math_postprocess_v2
 from opencompass.models import OpenAISDK
-from opencompass.openicl.icl_evaluator import GPassKEvaluator
+from opencompass.openicl.icl_evaluator import BaseEvaluator
 from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
     dump_results_dict
 from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS
@@ -31,8 +29,6 @@ class LiveMathBenchDataset(BaseDataset):
 
     @staticmethod
     def load(path: str,
-             k: Union[int, List[int]],
-             replication: int,
              dataset_splits: List[str] = [
                  'CNMO',
                  'CCEE',
@@ -104,17 +100,13 @@ class LiveMathBenchDataset(BaseDataset):
                     ('' if 'options' not in example else
                      ' '.join(example['options']))),
                 })
-                max_k = k if isinstance(k, int) else max(k)
-                for idx in range(max_k * replication):
-                    duplicated_example = deepcopy(example)
-                    duplicated_example.update({'replication_idx': idx})
-                    dataset.append(duplicated_example)
+                dataset.append(example)
 
         return Dataset.from_list(dataset)
 
 
 @ICL_EVALUATORS.register_module()
-class LiveMathBenchEvaluator(GPassKEvaluator):
+class LiveMathBenchEvaluator(BaseEvaluator):
     api_meta_template = dict(round=[
         dict(role='HUMAN', api_role='HUMAN'),
         dict(role='BOT', api_role='BOT', generate=True),
@@ -125,11 +117,8 @@ class LiveMathBenchEvaluator(GPassKEvaluator):
                  url,
                  use_extract_model=False,
                  extract_url=[],
-                 extract_model_name='',
-                 k: Union[int, List[int]] = 16,
-                 replication: int = 3,
-                 thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]):
-        super().__init__(k, replication, thresholds)
+                 extract_model_name=''):
+        super().__init__()
 
         if isinstance(url, str):
             url = [url]
@@ -310,55 +299,18 @@ class LiveMathBenchEvaluator(GPassKEvaluator):
     def preprocess(self, predictions, references, test_set):
         return self.judge(predictions, references, test_set)
 
-    def group(self, predictions, labels, test_set):
-        example2replications = {}
-        for example, label, prediction in zip(test_set, labels, predictions):
-            example_abbr = f"{example['subdivision']}_{example['idx']}"
-            if example_abbr not in example2replications:
-                example2replications[example_abbr] = []
-            example.update({'prediction': prediction, 'label': label})
-            example2replications[example_abbr].append(example)
-        for _, replications in example2replications.items():
-            assert len(replications) == self.n, print(len(replications),
-                                                      self.n)
-        return example2replications
-
-    def reduce(self, details) -> Dict[str, Any]:
-        """Aggregate the overall metrics.
-
-        Return:
-            A dict contains overall metrics, like:
-            {'details': details for each example, 'G-Pass@16': xxx}
-        """
-        g_passk_details = OrderedDict()
-        g_passk_details['details'] = details
-
-        all_dataset = set([detail['subdivision'] for detail in details])
-
-        for k in self.k:
-            for subdivision in sorted(list(all_dataset)):
-                for threshold in self.thresholds:
-                    g_passk_details[
-                        f'{subdivision}/G-Pass@{k}_{threshold}'] = \
-                        100. * np.mean(
-                        [
-                            detail[f'G-Pass@{k}_{threshold}']
-                            for detail in details
-                            if detail['subdivision'] == subdivision
-                        ])
-                g_passk_details[f'{subdivision}/mG-Pass@{k}'] = 100. * np.mean(
-                    [
-                        detail[f'mG-Pass@{k}'] for detail in details
-                        if detail['subdivision'] == subdivision
-                    ])
-
-            for threshold in self.thresholds:
-                g_passk_details[f'G-Pass@{k}_{threshold}'] = 100. * np.mean(
-                    [detail[f'G-Pass@{k}_{threshold}'] for detail in details])
-            g_passk_details[f'mG-Pass@{k}'] = 100. * np.mean(
-                [detail[f'mG-Pass@{k}'] for detail in details])
-
-        return g_passk_details
+    def score(self, predictions, references, test_set) -> Dict[str, Any]:
+        labels = self.preprocess(predictions, references, test_set)
+        results = {'accuracy': 100 * np.mean(labels), 'details': []}
+        for pred, ref, label in zip(predictions, references, labels):
+            results['details'].append({
+                'pred': pred,
+                'ref': ref,
+                'correct': label
+            })
+
+        return results
 
 
 class LiveMathBenchOutputHandler:
icl_evaluator package exports:

@@ -4,7 +4,6 @@ from .icl_base_evaluator import BaseEvaluator  # noqa
 from .icl_bpc_evaluator import BPCEvaluator  # noqa
 from .icl_circular_evaluator import CircularEvaluator  # noqa
 from .icl_em_evaluator import EMEvaluator  # noqa
-from .icl_gpassk_evaluator import GPassKEvaluator  # noqa
 from .icl_hf_evaluator import *  # noqa
 from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator  # noqa
 from .icl_misc_evaluator import AverageInferencePPLEvaluator  # noqa
BaseEvaluator (icl_base_evaluator.py):

@@ -1,4 +1,39 @@
 """Base Evaluator."""
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Any, Dict, Iterable, List, Union
+
+import numpy as np
+from datasets import Dataset
+from scipy.stats import hypergeom
+
+
+def compute_pass_at_k(n, c, k):
+    if n - c < k:
+        return 1.0
+    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
+
+
+def _compute_g_pass_at_k(n, c, k, m):
+    if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
+        return 0.0
+    return hypergeom.sf(m - 1, n, c, k)
+
+
+def compute_g_pass_at_k(n, c, k, t):
+    m = max(int(np.ceil(k * t)), 1)
+    return _compute_g_pass_at_k(n, c, k, m)
+
+
+def compute_mg_pass_at_k(n, c, k):
+    l, r = int(np.ceil(k * 0.5)), k
+
+    mg_pass_at_k = 0.0
+    for i in range(l + 1, r + 1):
+        mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
+    mg_pass_at_k = 2 * mg_pass_at_k / k
+
+    return mg_pass_at_k
+
+
 class BaseEvaluator:
@@ -6,6 +41,130 @@ class BaseEvaluator:
     def __init__(self) -> None:
         pass
 
+    @property
+    def output_dir(self):
+        # please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
+        return self._out_dir
+
+    def group(self, n: int, details: List[Dict[str, Any]],
+              test_set: Dataset) -> Dict[str, Any]:
+        example2replications = {}
+        for detail, example in zip(details, test_set):
+            example_abbr = f"{example['subdivision']}_{example['idx']}"
+            if example_abbr not in example2replications:
+                example2replications[example_abbr] = []
+            example.update({'detail': detail})
+            example2replications[example_abbr].append(example)
+        for _, replications in example2replications.items():
+            assert len(replications) == n, \
+                f'expected {n} replications, got {len(replications)}'
+        return example2replications
+
+    def reduce(self, details: List[Dict[str, Any]]) -> Dict[str, Any]:
+        g_passk_details = OrderedDict()
+        all_subdivisions = set(
+            [detail['example_abbr'].split('_')[0] for detail in details])
+        all_metrics = list(details[0].keys())
+
+        for subdivision in sorted(list(all_subdivisions)):
+            for metric in all_metrics:
+                if metric in ['predictions', 'example_abbr']:
+                    continue
+                g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean([
+                    detail[metric] for detail in details
+                    if detail['example_abbr'].split('_')[0] == subdivision
+                ])
+
+        for metric in all_metrics:
+            if metric in ['predictions', 'example_abbr']:
+                continue
+            g_passk_details[metric] = 100. * np.mean(
+                [detail[metric] for detail in details])
+        return g_passk_details
+
+    def evaluate(self, k: Union[int, List[int]], n: int,
+                 original_dataset: Dataset, **score_kwargs):
+        real_size = len(original_dataset) // n
+        # `k` may be an int or a list of ints; G-Pass@k needs the largest
+        max_k = max(k) if isinstance(k, list) else k
+        all_details = []
+        all_results = []
+        for i in range(n):
+
+            def select_fn(i, real_size, x):
+                if isinstance(x, Dataset):
+                    return x.select(range(i * real_size, (i + 1) * real_size))
+                elif isinstance(x, Iterable):
+                    return x[i * real_size:(i + 1) * real_size]
+                else:
+                    return x
+
+            results = self.score(
+                **{
+                    key: select_fn(i, real_size, value)
+                    for key, value in score_kwargs.items()
+                })
+            details = results.pop('details', None)
+            if details is not None:
+                if isinstance(details, Dict):
+                    details = list(details.values())
+                all_details.extend(details)
+            all_results.append(results)
+
+        eval_results = {}
+        for single_results in all_results:
+            for key in single_results:
+                if key not in eval_results:
+                    eval_results[key] = []
+                eval_results[key].append(single_results[key])
+        for key in deepcopy(eval_results):
+            if isinstance(eval_results[key][0], float) or isinstance(
+                    eval_results[key][0], int):
+                if n > 1:
+                    eval_results[key + f' ({n} runs average)'] = np.mean(
+                        eval_results[key])
+                    eval_results.pop(key)
+                else:
+                    eval_results[key] = np.mean(eval_results[key])
+            else:
+                eval_results[key] = eval_results[key][0]
+
+        grouped_examples = self.group(n, all_details, original_dataset)
+        can_calculate = False
+        if len(all_details) != 0:
+            eval_details = []
+            for example_abbr, examples in grouped_examples.items():
+                detail = {'predictions': [], 'example_abbr': example_abbr}
+
+                c = 0
+                for example in examples:
+                    detail['predictions'].append(example['detail'])
+                    # only compute G-Pass@k when details have correct labels
+                    if example['detail'].get('correct', None) is not None:
+                        can_calculate = True
+                        c += int(example['detail']['correct'])
+                    elif example['detail'].get('is_correct', None) is not None:
+                        can_calculate = True
+                        c += int(example['detail']['is_correct'])
+
+                if can_calculate and n > 1 and max_k > 1:
+                    thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
+                    for _k in ([k] if isinstance(k, int) else k):
+                        for threshold in thresholds:
+                            g_pass = compute_g_pass_at_k(n=n,
+                                                         c=c,
+                                                         k=_k,
+                                                         t=threshold)
+                            detail[f'G-Pass@{_k}_{threshold}'] = g_pass
+                        detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(n=n,
+                                                                       c=c,
+                                                                       k=_k)
+
+                eval_details.append(detail)
+
+            if can_calculate and n > 1 and max_k > 1:
+                eval_results.update(self.reduce(eval_details))
+            eval_results['details'] = eval_details
+
+        return eval_results
+
     def score(self):
         raise NotImplementedError("Method hasn't been implemented yet")
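For reference, the `compute_mg_pass_at_k` helper above integrates G-Pass@k over the thresholds beyond 0.5; in the notation of the documentation hunk it evaluates:

```{math}
\text{mG-Pass@}k = \frac{2}{k} \sum_{i=\lceil 0.5 \cdot k \rceil + 1}^{k} \text{G-Pass@}k_{i/k}
```

since `_compute_g_pass_at_k(n, c, k, i)` with `m = i` is exactly G-Pass@k at threshold $\tau = i/k$.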
icl_gpassk_evaluator.py (file deleted):

@@ -1,163 +0,0 @@
-from abc import abstractmethod
-from typing import Any, Dict, List, Union
-
-import numpy as np
-from scipy.stats import hypergeom
-
-from opencompass.registry import ICL_EVALUATORS
-
-from .icl_base_evaluator import BaseEvaluator
-
-
-def compute_pass_at_k(n, c, k):
-    if n - c < k:
-        return 1.0
-    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
-
-
-def _compute_g_pass_at_k(n, c, k, m):
-    if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
-        return 0.0
-    return hypergeom.sf(m - 1, n, c, k)
-
-
-def compute_g_pass_at_k(n, c, k, t):
-    m = max(int(np.ceil(k * t)), 1)
-    return _compute_g_pass_at_k(n, c, k, m)
-
-
-def compute_mg_pass_at_k(n, c, k):
-    l, r = int(np.ceil(k * 0.5)), k
-
-    mg_pass_at_k = 0.0
-    for i in range(l + 1, r + 1):
-        mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
-    mg_pass_at_k = 2 * mg_pass_at_k / k
-
-    return mg_pass_at_k
-
-
-@ICL_EVALUATORS.register_module()
-class GPassKEvaluator(BaseEvaluator):
-    """Evaluator for computing the G-Pass@k Metric.
-
-    This evaluator performs the following steps:
-    1. Invokes task-specific `preprocess` on predictions to
-    assign a consistency label to each prediction and its
-    corresponding reference.
-    2. Calculates metrics for each input example based on
-    these labels.
-    3. Aggregates the overall metrics through a task-specific
-    `postprocess`.
-
-    Args:
-        k (int or list of int): Number of predictions to be
-        considered in G-Pass@k. It can be a single integer
-        (e.g., `k=16` computes G-Pass@16) or a list of
-        integers (e.g., `[4, 8, 16]` computes G-Pass@4,
-        G-Pass@8, and G-Pass@16).
-
-        replication (int): Controls the number of generations
-        used to estimate G-Pass@k. The total number of
-        generations is determined by multiplying the
-        maximum of `k` with `replication`. This parameter
-        should be a single integer.
-
-        thresholds (list of float): A list of floating-point
-        numbers that define the thresholds for the G-Pass@k
-        metric.
-    """
-
-    def __init__(
-            self,
-            k: Union[int, List[int]] = 16,
-            replication: int = 3,
-            thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None:
-        super().__init__()
-
-        if isinstance(k, int):
-            k = [k]
-
-        self.k = k
-        self.replication = replication
-        self.n = max(k) * replication
-        self.thresholds = thresholds
-
-    @property
-    def output_dir(self):
-        # please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
-        return self._out_dir
-
-    @abstractmethod
-    def preprocess(self, predictions, references, test_set) -> None:
-        """Perform operations on predictions before computing metrics, for
-        example, do answer_extraction and model_judge in mathematical reasoning
-        task.
-
-        Return:
-            labels: A list contains the label which indicates whether
-            prediction is consistency with reference at each position.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def group(self, predictions, labels, test_set) -> Dict[str, Any]:
-        """Group the predictions and references.
-
-        Return:
-            A dict contains the grouped predictions and references.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def reduce(self, details) -> Dict[str, Any]:
-        """Aggregate the overall metrics.
-
-        Return:
-            A dict contains overall metrics, like:
-            {'details': details for each example, 'G-Pass@16': xxx}
-        """
-        raise NotImplementedError
-
-    def score(self, predictions, references, test_set) -> Dict[str, Any]:
-        """Compute G-Pass@k metrics.
-
-        Return:
-            A dict contains metrics for each dataset sample and
-            overall metrics reduced by `self.reduce`, like:
-            {'details': details for each example, 'G-Pass@16': xxx}
-        """
-        labels = self.preprocess(predictions, references, test_set)
-        grouped_examples = self.group(predictions, labels, test_set)
-
-        details = []
-        total_pass_num, count = 0, 0
-        for example_abbr, examples in grouped_examples.items():
-            detail = {
-                k: v
-                for k, v in examples[0].items()
-                if k not in ['prediction', 'label']
-            }
-            detail.update({
-                'predictions': [{
-                    'prediction': example['prediction'],
-                    'label': example['label']
-                } for example in examples],
-            })
-
-            current_example_labels = [e['label'] for e in examples]
-            c = int(np.sum(current_example_labels))
-
-            for k in self.k:
-                for threshold in self.thresholds:
-                    detail[f'G-Pass@{k}_{threshold}'] = compute_g_pass_at_k(
-                        n=self.n, c=c, k=k, t=threshold)
-                detail[f'mG-Pass@{k}'] = compute_mg_pass_at_k(n=self.n,
-                                                              c=c,
-                                                              k=k)
-            count += self.n
-            total_pass_num += c
-
-            details.append(detail)
-
-        return self.reduce(details)
OpenICLEvalTask (tasks/openicl_eval.py):

@@ -240,7 +240,10 @@ class OpenICLEvalTask(BaseTask):
                 k: preds[k]
                 for k in signature(icl_evaluator.score).parameters
             }
-            result = icl_evaluator.score(**preds)
+            k = self.dataset_cfg.get('k', 1)
+            n = self.dataset_cfg.get('n', 1)
+            result = icl_evaluator.evaluate(k, n, copy.deepcopy(test_set),
+                                            **preds)
 
             # Get model postprocess result
             model_details = None
@@ -248,7 +251,9 @@ class OpenICLEvalTask(BaseTask):
             if 'model_postprocessor' in self.eval_cfg:
                 model_preds = copy.deepcopy(preds)
                 model_preds['predictions'] = model_pred_strs
-                model_result = icl_evaluator.score(**model_preds)
+                model_result = icl_evaluator.evaluate(k, n,
+                                                      copy.deepcopy(test_set),
+                                                      **model_preds)
                 for key in model_result:
                     if key == 'details':
                         model_details = model_result[key]
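Putting the pieces together, here is a minimal self-contained sketch of the new flow; `DummyEvaluator` and its data are hypothetical, and it assumes the `BaseEvaluator` from this commit is importable:

```python
from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator


class DummyEvaluator(BaseEvaluator):
    """Per-run scorer: returns binary `correct` labels in `details`."""

    def score(self, predictions, references):
        details = [{'correct': p == r} for p, r in zip(predictions, references)]
        acc = 100.0 * sum(d['correct'] for d in details) / len(details)
        return {'accuracy': acc, 'details': details}


n, k = 4, 2
# mimic BaseDataset's n-fold repeat of a 3-example dataset
rows = [{'subdivision': 'demo', 'idx': i} for i in range(3)]
test_set = Dataset.from_list(rows * n)
preds = ['a', 'b', 'c'] * n
refs = ['a', 'x', 'c'] * n

result = DummyEvaluator().evaluate(k, n, test_set,
                                   predictions=preds, references=refs)
print(result['accuracy (4 runs average)'])  # ~66.7
print(result['demo/G-Pass@2_0.5'])          # per-subdivision G-Pass metric
```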
build_dataset_from_cfg:

@@ -9,7 +9,6 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict):
     dataset_cfg = copy.deepcopy(dataset_cfg)
     dataset_cfg.pop('infer_cfg', None)
     dataset_cfg.pop('eval_cfg', None)
-    dataset_cfg.pop('abbr', None)
     return LOAD_DATASET.build(dataset_cfg)
 
 