mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
support dataset repeat and g-pass compute for each evaluator
This commit is contained in:
parent
046b6f75c6
commit
8def69369a
@ -9,7 +9,7 @@ livemathbench_dataset = dict(
|
||||
type=LiveMathBenchDataset,
|
||||
path='',
|
||||
k=16,
|
||||
replication=3,
|
||||
repeat=3,
|
||||
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
|
||||
dataset_languages=['cn', 'en'],
|
||||
cot=True,
|
||||
@ -43,7 +43,7 @@ livemathbench_dataset = dict(
|
||||
extract_url=[],
|
||||
extract_model_name='',
|
||||
k=[4, 8, 16],
|
||||
replication=3,
|
||||
repeat=3,
|
||||
thresholds=[0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
)
|
||||
)
|
||||
|
@ -9,7 +9,7 @@ livemathbench_dataset = dict(
|
||||
type=LiveMathBenchDataset,
|
||||
path='',
|
||||
k=1,
|
||||
replication=1,
|
||||
repeat=1,
|
||||
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
|
||||
dataset_languages=['cn', 'en'],
|
||||
cot=True,
|
||||
@ -43,7 +43,7 @@ livemathbench_dataset = dict(
|
||||
extract_url=[],
|
||||
extract_model_name='',
|
||||
k=[1],
|
||||
replication=1,
|
||||
repeat=1,
|
||||
thresholds=[0.0]
|
||||
)
|
||||
)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from abc import abstractstaticmethod
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Union, List
|
||||
from copy import deepcopy
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
@ -8,8 +9,38 @@ from opencompass.openicl import DatasetReader
|
||||
|
||||
class BaseDataset:
|
||||
|
||||
def __init__(self, reader_cfg: Optional[Dict] = {}, **kwargs):
|
||||
self.dataset = self.load(**kwargs)
|
||||
def __init__(self,
|
||||
reader_cfg: Optional[Dict] = {},
|
||||
k: Union[int, List[int]] = 1,
|
||||
repeat: int = 1,
|
||||
**kwargs):
|
||||
abbr = kwargs.pop('abbr', 'dataset')
|
||||
dataset = self.load(**kwargs)
|
||||
# maybe duplicate
|
||||
n = (max(k) if isinstance(k, List) else k) * repeat
|
||||
if isinstance(dataset, Dataset):
|
||||
examples = []
|
||||
for idx, example in enumerate(dataset):
|
||||
if 'subdivision' not in example:
|
||||
example['subdivision'] = abbr
|
||||
if 'idx' not in example:
|
||||
example['idx'] = idx
|
||||
examples.append(example)
|
||||
examples = sum([deepcopy(examples) for _ in range(n)], [])
|
||||
self.dataset = Dataset.from_list(examples)
|
||||
else:
|
||||
self.dataset = DatasetDict()
|
||||
for key in dataset:
|
||||
examples = []
|
||||
for idx, example in enumerate(dataset[key]):
|
||||
if 'subdivision' not in example:
|
||||
example['subdivision'] = f'{abbr}_{key}'
|
||||
if 'idx' not in example:
|
||||
example['idx'] = idx
|
||||
examples.append(example)
|
||||
print(abbr, key, len(examples))
|
||||
examples = sum([deepcopy(examples) for _ in range(n)], [])
|
||||
self.dataset[key] = Dataset.from_list(examples)
|
||||
self._init_reader(**reader_cfg)
|
||||
|
||||
def _init_reader(self, **kwargs):
|
||||
|
@ -2,7 +2,6 @@ import os
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
@ -31,8 +30,6 @@ class LiveMathBenchDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str,
|
||||
k: Union[int, List[int]],
|
||||
replication: int,
|
||||
dataset_splits: List[str] = [
|
||||
'CNMO',
|
||||
'CCEE',
|
||||
@ -104,11 +101,7 @@ class LiveMathBenchDataset(BaseDataset):
|
||||
('' if 'options' not in example else
|
||||
' '.join(example['options']))),
|
||||
})
|
||||
max_k = k if isinstance(k, int) else max(k)
|
||||
for idx in range(max_k * replication):
|
||||
duplicated_example = deepcopy(example)
|
||||
duplicated_example.update({'replication_idx': idx})
|
||||
dataset.append(duplicated_example)
|
||||
dataset.append(example)
|
||||
|
||||
return Dataset.from_list(dataset)
|
||||
|
||||
@ -127,9 +120,9 @@ class LiveMathBenchEvaluator(GPassKEvaluator):
|
||||
extract_url=[],
|
||||
extract_model_name='',
|
||||
k: Union[int, List[int]] = 16,
|
||||
replication: int = 3,
|
||||
repeat: int = 3,
|
||||
thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]):
|
||||
super().__init__(k, replication, thresholds)
|
||||
super().__init__(k, repeat, thresholds)
|
||||
|
||||
if isinstance(url, str):
|
||||
url = [url]
|
||||
|
@ -1,11 +1,146 @@
|
||||
"""Base Evaluator."""
|
||||
from typing import Union, List, Dict, Any, Iterable
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import hypergeom
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
def compute_pass_at_k(n, c, k):
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
|
||||
|
||||
def _compute_g_pass_at_k(n, c, k, m):
|
||||
if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
|
||||
return 0.0
|
||||
return hypergeom.sf(m - 1, n, c, k)
|
||||
|
||||
def compute_g_pass_at_k(n, c, k, t):
|
||||
m = max(int(np.ceil(k * t)), 1)
|
||||
return _compute_g_pass_at_k(n, c, k, m)
|
||||
|
||||
def compute_mg_pass_at_k(n, c, k):
|
||||
l, r = int(np.ceil(k * 0.5)), k
|
||||
|
||||
mg_pass_at_k = 0.0
|
||||
for i in range(l + 1, r + 1):
|
||||
mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
|
||||
mg_pass_at_k = 2 * mg_pass_at_k / k
|
||||
|
||||
return mg_pass_at_k
|
||||
|
||||
|
||||
class BaseEvaluator:
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def group(self, n: int, details: List[Dict[str, Any]], test_set: Dataset) -> Dict[str, Any]:
|
||||
example2replications = {}
|
||||
for detail, example in zip(details, test_set):
|
||||
example_abbr = f"{example['subdivision']}_{example['idx']}"
|
||||
if example_abbr not in example2replications:
|
||||
example2replications[example_abbr] = []
|
||||
example.update({'detail': detail})
|
||||
example2replications[example_abbr].append(example)
|
||||
for _, replications in example2replications.items():
|
||||
assert len(replications) == n, print(len(replications), n)
|
||||
return example2replications
|
||||
|
||||
def reduce(self, details: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
g_passk_details = OrderedDict()
|
||||
all_subdivisions = set([detail['example_abbr'].split('_')[0] for detail in details])
|
||||
all_metrics = list(details[0].keys())
|
||||
|
||||
for subdivision in sorted(list(all_subdivisions)):
|
||||
for metric in all_metrics:
|
||||
if metric in ['predictions', 'example_abbr']:
|
||||
continue
|
||||
g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean([
|
||||
detail[metric]
|
||||
for detail in details
|
||||
if detail['example_abbr'].split('_')[0] == subdivision
|
||||
])
|
||||
|
||||
for metric in all_metrics:
|
||||
if metric in ['predictions', 'example_abbr']:
|
||||
continue
|
||||
g_passk_details[metric] = 100. * np.mean([detail[metric] for detail in details])
|
||||
return g_passk_details
|
||||
|
||||
def evaluate(self, k: Union[int, List[int]],
|
||||
repeat: int, test_set: Dataset, **score_kwargs):
|
||||
n = (max(k) if isinstance(k, List) else k) * repeat
|
||||
print(len(score_kwargs['predictions']))
|
||||
real_size = len(test_set) // n
|
||||
all_details = []
|
||||
all_results = []
|
||||
for i in range(n):
|
||||
results = self.score(**{
|
||||
key: value[i * real_size: (i + 1) * real_size] if isinstance(value, Iterable) else value
|
||||
for key, value in score_kwargs.items()
|
||||
})
|
||||
details = results.pop('details', None)
|
||||
if details is not None:
|
||||
if isinstance(details, Dict):
|
||||
details = list(details.values())
|
||||
all_details.extend(details)
|
||||
all_results.append(results)
|
||||
|
||||
eval_results = {}
|
||||
for single_results in all_results:
|
||||
for key in single_results:
|
||||
if key not in eval_results:
|
||||
eval_results[key] = []
|
||||
eval_results[key].append(single_results[key])
|
||||
for key in deepcopy(eval_results):
|
||||
if isinstance(eval_results[key][0], float) or isinstance(eval_results[key][0], int):
|
||||
if n > 1:
|
||||
eval_results[key + f' ({n // repeat}x{repeat}={n} runs average)'] = np.mean(eval_results[key])
|
||||
eval_results.pop(key)
|
||||
else:
|
||||
eval_results[key] = np.mean(eval_results[key])
|
||||
else:
|
||||
eval_results[key] = eval_results[key][0]
|
||||
|
||||
grouped_examples = self.group(n, all_details, test_set)
|
||||
if len(all_details) != 0:
|
||||
eval_details = []
|
||||
for example_abbr, examples in grouped_examples.items():
|
||||
detail = {
|
||||
'predictions': [],
|
||||
'example_abbr': example_abbr
|
||||
}
|
||||
|
||||
c = 0
|
||||
can_calculate = False
|
||||
for example in examples:
|
||||
detail['predictions'].append(example['detail'])
|
||||
# only compute G-Pass@k when details have correct labels
|
||||
if example['detail'].get('correct', None) is not None:
|
||||
can_calculate = True
|
||||
c += int(example['detail']['correct'])
|
||||
elif example['detail'].get('is_correct', None) is not None:
|
||||
can_calculate = True
|
||||
c += int(example['detail']['is_correct'])
|
||||
|
||||
if can_calculate:
|
||||
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
for _k in ([k] if isinstance(k, int) else k):
|
||||
for threshold in thresholds:
|
||||
detail[f'G-Pass@{_k}_{threshold}'] = compute_g_pass_at_k(
|
||||
n=n, c=c, k=_k, t=threshold)
|
||||
detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(n=n, c=c, k=_k)
|
||||
|
||||
eval_details.append(detail)
|
||||
|
||||
eval_results.update(self.reduce(eval_details))
|
||||
eval_results['details'] = eval_details
|
||||
|
||||
return eval_results
|
||||
|
||||
def score(self):
|
||||
raise NotImplementedError("Method hasn't been implemented yet")
|
||||
|
||||
|
@ -57,10 +57,10 @@ class GPassKEvaluator(BaseEvaluator):
|
||||
integers (e.g., `[4, 8, 16]` computes G-Pass@4,
|
||||
G-Pass@8, and G-Pass@16).
|
||||
|
||||
replication (int): Controls the number of generations
|
||||
repeat (int): Controls the number of generations
|
||||
used to estimate G-Pass@k. The total number of
|
||||
generations is determined by multiplying the
|
||||
maximum of `k` with `replication`. This parameter
|
||||
maximum of `k` with `repeat`. This parameter
|
||||
should be a single integer.
|
||||
|
||||
thresholds (list of float): A list of floating-point
|
||||
@ -71,7 +71,7 @@ class GPassKEvaluator(BaseEvaluator):
|
||||
def __init__(
|
||||
self,
|
||||
k: Union[int, List[int]] = 16,
|
||||
replication: int = 3,
|
||||
repeat: int = 3,
|
||||
thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -79,8 +79,8 @@ class GPassKEvaluator(BaseEvaluator):
|
||||
k = [k]
|
||||
|
||||
self.k = k
|
||||
self.replication = replication
|
||||
self.n = max(k) * replication
|
||||
self.repeat = repeat
|
||||
self.n = max(k) * repeat
|
||||
self.thresholds = thresholds
|
||||
|
||||
@property
|
||||
|
@ -105,6 +105,9 @@ class GenInferencer(BaseInferencer):
|
||||
ice_template=ice_template,
|
||||
prompt_template=prompt_template)
|
||||
|
||||
|
||||
print(len(prompt_list))
|
||||
|
||||
# 3.1 Fetch and zip prompt & gold answer if output column exists
|
||||
ds_reader = retriever.dataset_reader
|
||||
if ds_reader.output_column:
|
||||
|
@ -215,7 +215,9 @@ class OpenICLEvalTask(BaseTask):
|
||||
k: preds[k]
|
||||
for k in signature(icl_evaluator.score).parameters
|
||||
}
|
||||
result = icl_evaluator.score(**preds)
|
||||
k = self.dataset_cfg.get('k', 1)
|
||||
repeat = self.dataset_cfg.get('repeat', 1)
|
||||
result = icl_evaluator.evaluate(k, repeat, test_set, **preds)
|
||||
|
||||
# Get model postprocess result
|
||||
model_details = None
|
||||
@ -223,7 +225,7 @@ class OpenICLEvalTask(BaseTask):
|
||||
if 'model_postprocessor' in self.eval_cfg:
|
||||
model_preds = copy.deepcopy(preds)
|
||||
model_preds['predictions'] = model_pred_strs
|
||||
model_result = icl_evaluator.score(**model_preds)
|
||||
model_result = icl_evaluator.evaluate(k, repeat, test_set, **model_preds)
|
||||
for key in model_result:
|
||||
if key == 'details':
|
||||
model_details = model_result[key]
|
||||
|
@ -9,7 +9,6 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict):
|
||||
dataset_cfg = copy.deepcopy(dataset_cfg)
|
||||
dataset_cfg.pop('infer_cfg', None)
|
||||
dataset_cfg.pop('eval_cfg', None)
|
||||
dataset_cfg.pop('abbr', None)
|
||||
return LOAD_DATASET.build(dataset_cfg)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user