support dataset repeat and g-pass compute for each evaluator

This commit is contained in:
jnanliu 2025-02-23 03:05:42 +00:00
parent 046b6f75c6
commit 8def69369a
9 changed files with 189 additions and 26 deletions

View File

@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset, type=LiveMathBenchDataset,
path='', path='',
k=16, k=16,
replication=3, repeat=3,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'], dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'], dataset_languages=['cn', 'en'],
cot=True, cot=True,
@ -43,7 +43,7 @@ livemathbench_dataset = dict(
extract_url=[], extract_url=[],
extract_model_name='', extract_model_name='',
k=[4, 8, 16], k=[4, 8, 16],
replication=3, repeat=3,
thresholds=[0.0, 0.25, 0.5, 0.75, 1.0] thresholds=[0.0, 0.25, 0.5, 0.75, 1.0]
) )
) )

View File

@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset, type=LiveMathBenchDataset,
path='', path='',
k=1, k=1,
replication=1, repeat=1,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'], dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'], dataset_languages=['cn', 'en'],
cot=True, cot=True,
@ -43,7 +43,7 @@ livemathbench_dataset = dict(
extract_url=[], extract_url=[],
extract_model_name='', extract_model_name='',
k=[1], k=[1],
replication=1, repeat=1,
thresholds=[0.0] thresholds=[0.0]
) )
) )

View File

@ -1,5 +1,6 @@
from abc import abstractstaticmethod from abc import abstractstaticmethod
from typing import Dict, Optional, Union from typing import Dict, Optional, Union, List
from copy import deepcopy
from datasets import Dataset, DatasetDict from datasets import Dataset, DatasetDict
@ -8,8 +9,38 @@ from opencompass.openicl import DatasetReader
class BaseDataset: class BaseDataset:
def __init__(self,
             reader_cfg: Optional[Dict] = None,
             k: Union[int, List[int]] = 1,
             repeat: int = 1,
             **kwargs):
    """Load the dataset, tag every example and replicate it for sampling.

    Args:
        reader_cfg: Keyword arguments forwarded to ``_init_reader``.
            ``None`` is treated as an empty configuration.
        k: A single ``k`` or a collection of ``k`` values; the largest
            one determines how many copies are needed.
        repeat: Multiplier applied to the largest ``k``. Every example
            ends up ``max(k) * repeat`` times in the stored dataset.
        **kwargs: ``abbr`` (default ``'dataset'``) is consumed here; the
            remaining items are passed to ``self.load``.
    """
    abbr = kwargs.pop('abbr', 'dataset')
    dataset = self.load(**kwargs)
    n = (k if isinstance(k, int) else max(k)) * repeat

    def replicate(split, subdivision):
        # Tag each example so that its copies can be regrouped later.
        examples = []
        for idx, example in enumerate(split):
            example.setdefault('subdivision', subdivision)
            example.setdefault('idx', idx)
            examples.append(example)
        # Copies are laid out pass by pass (the whole split, n times),
        # matching the order produced by concatenating n list copies.
        return Dataset.from_list([
            deepcopy(example) for _ in range(n) for example in examples
        ])

    if isinstance(dataset, Dataset):
        self.dataset = replicate(dataset, abbr)
    else:
        self.dataset = DatasetDict()
        for key in dataset:
            self.dataset[key] = replicate(dataset[key], f'{abbr}_{key}')
    self._init_reader(**({} if reader_cfg is None else reader_cfg))
def _init_reader(self, **kwargs): def _init_reader(self, **kwargs):

View File

@ -2,7 +2,6 @@ import os
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from functools import partial from functools import partial
from itertools import product from itertools import product
from typing import Any, Callable, Dict, List, Union from typing import Any, Callable, Dict, List, Union
@ -31,8 +30,6 @@ class LiveMathBenchDataset(BaseDataset):
@staticmethod @staticmethod
def load(path: str, def load(path: str,
k: Union[int, List[int]],
replication: int,
dataset_splits: List[str] = [ dataset_splits: List[str] = [
'CNMO', 'CNMO',
'CCEE', 'CCEE',
@ -104,11 +101,7 @@ class LiveMathBenchDataset(BaseDataset):
('' if 'options' not in example else ('' if 'options' not in example else
' '.join(example['options']))), ' '.join(example['options']))),
}) })
max_k = k if isinstance(k, int) else max(k) dataset.append(example)
for idx in range(max_k * replication):
duplicated_example = deepcopy(example)
duplicated_example.update({'replication_idx': idx})
dataset.append(duplicated_example)
return Dataset.from_list(dataset) return Dataset.from_list(dataset)
@ -127,9 +120,9 @@ class LiveMathBenchEvaluator(GPassKEvaluator):
extract_url=[], extract_url=[],
extract_model_name='', extract_model_name='',
k: Union[int, List[int]] = 16, k: Union[int, List[int]] = 16,
replication: int = 3, repeat: int = 3,
thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]): thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]):
super().__init__(k, replication, thresholds) super().__init__(k, repeat, thresholds)
if isinstance(url, str): if isinstance(url, str):
url = [url] url = [url]

View File

@ -1,11 +1,146 @@
"""Base Evaluator.""" """Base Evaluator."""
from typing import Union, List, Dict, Any, Iterable
from collections import OrderedDict
from copy import deepcopy
import numpy as np
from scipy.stats import hypergeom
from datasets import Dataset
def compute_pass_at_k(n, c, k):
    """Return pass@k for ``c`` correct generations out of ``n``.

    When fewer than ``k`` generations are incorrect the result is 1.0;
    otherwise it is ``1 - prod(1 - k / i)`` for ``i`` running over
    ``n - c + 1 .. n``.
    """
    incorrect = n - c
    if incorrect >= k:
        denominators = np.arange(incorrect + 1, n + 1)
        all_wrong = np.prod(1.0 - k / denominators)
        return 1.0 - all_wrong
    return 1.0
def _compute_g_pass_at_k(n, c, k, m):
if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
return 0.0
return hypergeom.sf(m - 1, n, c, k)
def compute_g_pass_at_k(n, c, k, t):
    """Return G-Pass@k for threshold ``t``.

    The threshold is converted to a required number of correct draws,
    ``ceil(k * t)``, clamped to at least one, and the result is
    delegated to ``_compute_g_pass_at_k``.
    """
    required = int(np.ceil(k * t))
    if required < 1:
        required = 1
    return _compute_g_pass_at_k(n, c, k, required)
def compute_mg_pass_at_k(n, c, k):
    """Return mG-Pass@k.

    Sums ``_compute_g_pass_at_k`` over required-correct counts from
    ``ceil(k / 2) + 1`` through ``k`` and scales the sum by ``2 / k``.
    """
    lower = int(np.ceil(k * 0.5))
    total = 0.0
    required = lower + 1
    while required <= k:
        total += _compute_g_pass_at_k(n, c, k, required)
        required += 1
    return 2 * total / k
class BaseEvaluator: class BaseEvaluator:
def __init__(self) -> None: def __init__(self) -> None:
pass pass
def group(self, n: int, details: List[Dict[str, Any]],
          test_set: 'Dataset') -> Dict[str, Any]:
    """Bucket per-sample details by the example they were generated for.

    ``details`` and ``test_set`` are walked in lockstep. Each test-set
    row gets its detail attached under the ``'detail'`` key and is
    filed under ``"<subdivision>_<idx>"``.

    Args:
        n: Number of replications expected for every example.
        details: One detail record per row of ``test_set``.
        test_set: Rows providing the ``'subdivision'`` and ``'idx'`` keys.

    Returns:
        Mapping from example abbreviation to its list of rows, in
        first-seen order.

    Raises:
        AssertionError: If any example does not have exactly ``n`` rows.
    """
    example2replications = {}
    for detail, example in zip(details, test_set):
        example_abbr = f"{example['subdivision']}_{example['idx']}"
        example.update({'detail': detail})
        example2replications.setdefault(example_abbr, []).append(example)
    for example_abbr, replications in example2replications.items():
        # Raised explicitly so the check survives ``python -O`` and the
        # message names the offending example instead of being ``None``.
        if len(replications) != n:
            raise AssertionError(
                f'example {example_abbr!r} has {len(replications)} '
                f'replications, expected {n}')
    return example2replications
def reduce(self, details: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Average per-example metrics per subdivision and overall.

    Every key of a detail other than ``'predictions'`` and
    ``'example_abbr'`` is treated as a metric. Means are scaled by 100.

    Args:
        details: Per-example records; each carries ``'example_abbr'`` of
            the form ``"<subdivision>_<idx>"``.

    Returns:
        An ``OrderedDict`` holding ``"<subdivision>/<metric>"`` entries
        (subdivisions in sorted order) followed by the overall
        ``"<metric>"`` entries. Empty when ``details`` is empty.
    """
    g_passk_details = OrderedDict()
    if not details:
        return g_passk_details

    non_metric_keys = ('predictions', 'example_abbr')
    # Collect metric names from every detail: the first detail may lack
    # metrics that later ones provide.
    all_metrics = []
    for detail in details:
        for name in detail:
            if name not in non_metric_keys and name not in all_metrics:
                all_metrics.append(name)

    # The abbreviation is "<subdivision>_<idx>" and the subdivision may
    # itself contain underscores, so only the last segment is stripped.
    by_subdivision = {}
    for detail in details:
        subdivision = detail['example_abbr'].rsplit('_', 1)[0]
        by_subdivision.setdefault(subdivision, []).append(detail)

    for subdivision in sorted(by_subdivision):
        for metric in all_metrics:
            values = [
                detail[metric] for detail in by_subdivision[subdivision]
                if metric in detail
            ]
            if values:
                g_passk_details[f'{subdivision}/{metric}'] = (
                    100 * np.mean(values))

    for metric in all_metrics:
        values = [detail[metric] for detail in details if metric in detail]
        g_passk_details[metric] = 100. * np.mean(values)
    return g_passk_details
def evaluate(self, k: Union[int, List[int]], repeat: int,
             test_set: 'Dataset', **score_kwargs):
    """Score every replication separately and aggregate the results.

    ``test_set`` is expected to hold ``n = max(k) * repeat`` consecutive
    passes over the real data. Sliceable keyword arguments are cut into
    ``n`` equal chunks and ``self.score`` is called once per chunk.

    Args:
        k: A single ``k`` or a collection of ``k`` values.
        repeat: Multiplier applied to the largest ``k``.
        test_set: The replicated test set; also used to regroup details.
        **score_kwargs: Arguments for ``self.score``. Iterable values
            other than ``str``/``bytes`` are sliced per replication;
            everything else is passed through unchanged.

    Returns:
        A dict of aggregated results. Numeric scores are averaged over
        the runs (and relabelled with the run count when ``n > 1``);
        other values are taken from the first run. When ``score``
        returns ``'details'``, the output of ``self.reduce`` and a
        ``'details'`` list with one entry per example are added.
    """
    max_k = k if isinstance(k, int) else max(k)
    n = max_k * repeat
    real_size = len(test_set) // n

    all_details = []
    all_results = []
    for i in range(n):
        start, stop = i * real_size, (i + 1) * real_size
        results = self.score(**{
            # Strings are iterable but are scalar arguments here, so
            # they must not be cut into pieces.
            key: (value[start:stop]
                  if isinstance(value, Iterable)
                  and not isinstance(value, (str, bytes)) else value)
            for key, value in score_kwargs.items()
        })
        details = results.pop('details', None)
        if details is not None:
            if isinstance(details, dict):
                details = list(details.values())
            all_details.extend(details)
        all_results.append(results)

    # Collect each score across runs, then collapse to a single value.
    eval_results = {}
    for single_results in all_results:
        for key, value in single_results.items():
            eval_results.setdefault(key, []).append(value)
    for key in list(eval_results):
        values = eval_results[key]
        if isinstance(values[0], (int, float, np.number)):
            if n > 1:
                eval_results.pop(key)
                label = f'{key} ({max_k}x{repeat}={n} runs average)'
                eval_results[label] = np.mean(values)
            else:
                eval_results[key] = np.mean(values)
        else:
            eval_results[key] = values[0]

    if all_details:
        grouped_examples = self.group(n, all_details, test_set)
        thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
        k_values = [k] if isinstance(k, int) else k
        eval_details = []
        for example_abbr, examples in grouped_examples.items():
            detail = {'predictions': [], 'example_abbr': example_abbr}
            c = 0
            can_calculate = False
            for example in examples:
                prediction = example['detail']
                detail['predictions'].append(prediction)
                # only compute G-Pass@k when details have correct labels
                if prediction.get('correct', None) is not None:
                    can_calculate = True
                    c += int(prediction['correct'])
                elif prediction.get('is_correct', None) is not None:
                    can_calculate = True
                    c += int(prediction['is_correct'])
            if can_calculate:
                for _k in k_values:
                    for threshold in thresholds:
                        detail[f'G-Pass@{_k}_{threshold}'] = (
                            compute_g_pass_at_k(
                                n=n, c=c, k=_k, t=threshold))
                    detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(
                        n=n, c=c, k=_k)
            eval_details.append(detail)
        eval_results.update(self.reduce(eval_details))
        eval_results['details'] = eval_details
    return eval_results
def score(self): def score(self):
raise NotImplementedError("Method hasn't been implemented yet") raise NotImplementedError("Method hasn't been implemented yet")

View File

@ -57,10 +57,10 @@ class GPassKEvaluator(BaseEvaluator):
integers (e.g., `[4, 8, 16]` computes G-Pass@4, integers (e.g., `[4, 8, 16]` computes G-Pass@4,
G-Pass@8, and G-Pass@16). G-Pass@8, and G-Pass@16).
replication (int): Controls the number of generations repeat (int): Controls the number of generations
used to estimate G-Pass@k. The total number of used to estimate G-Pass@k. The total number of
generations is determined by multiplying the generations is determined by multiplying the
maximum of `k` with `replication`. This parameter maximum of `k` with `repeat`. This parameter
should be a single integer. should be a single integer.
thresholds (list of float): A list of floating-point thresholds (list of float): A list of floating-point
@ -71,7 +71,7 @@ class GPassKEvaluator(BaseEvaluator):
def __init__( def __init__(
self, self,
k: Union[int, List[int]] = 16, k: Union[int, List[int]] = 16,
replication: int = 3, repeat: int = 3,
thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None: thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None:
super().__init__() super().__init__()
@ -79,8 +79,8 @@ class GPassKEvaluator(BaseEvaluator):
k = [k] k = [k]
self.k = k self.k = k
self.replication = replication self.repeat = repeat
self.n = max(k) * replication self.n = max(k) * repeat
self.thresholds = thresholds self.thresholds = thresholds
@property @property

View File

@ -104,6 +104,9 @@ class GenInferencer(BaseInferencer):
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
ice_template=ice_template, ice_template=ice_template,
prompt_template=prompt_template) prompt_template=prompt_template)
print(len(prompt_list))
# 3.1 Fetch and zip prompt & gold answer if output column exists # 3.1 Fetch and zip prompt & gold answer if output column exists
ds_reader = retriever.dataset_reader ds_reader = retriever.dataset_reader

View File

@ -215,7 +215,9 @@ class OpenICLEvalTask(BaseTask):
k: preds[k] k: preds[k]
for k in signature(icl_evaluator.score).parameters for k in signature(icl_evaluator.score).parameters
} }
result = icl_evaluator.score(**preds) k = self.dataset_cfg.get('k', 1)
repeat = self.dataset_cfg.get('repeat', 1)
result = icl_evaluator.evaluate(k, repeat, test_set, **preds)
# Get model postprocess result # Get model postprocess result
model_details = None model_details = None
@ -223,7 +225,7 @@ class OpenICLEvalTask(BaseTask):
if 'model_postprocessor' in self.eval_cfg: if 'model_postprocessor' in self.eval_cfg:
model_preds = copy.deepcopy(preds) model_preds = copy.deepcopy(preds)
model_preds['predictions'] = model_pred_strs model_preds['predictions'] = model_pred_strs
model_result = icl_evaluator.score(**model_preds) model_result = icl_evaluator.evaluate(k, repeat, test_set, **model_preds)
for key in model_result: for key in model_result:
if key == 'details': if key == 'details':
model_details = model_result[key] model_details = model_result[key]

View File

@ -9,7 +9,6 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict):
dataset_cfg = copy.deepcopy(dataset_cfg) dataset_cfg = copy.deepcopy(dataset_cfg)
dataset_cfg.pop('infer_cfg', None) dataset_cfg.pop('infer_cfg', None)
dataset_cfg.pop('eval_cfg', None) dataset_cfg.pop('eval_cfg', None)
dataset_cfg.pop('abbr', None)
return LOAD_DATASET.build(dataset_cfg) return LOAD_DATASET.build(dataset_cfg)