From 8def69369ae9896bf9186965e48f3be44ee5ba36 Mon Sep 17 00:00:00 2001 From: jnanliu Date: Sun, 23 Feb 2025 03:05:42 +0000 Subject: [PATCH] support dataset repeat and g-pass compute for each evaluator --- .../livemathbench/livemathbench_gen_9befbf.py | 4 +- .../livemathbench_greedy_gen_9befbf.py | 4 +- opencompass/datasets/base.py | 37 ++++- .../datasets/livemathbench/livemathbench.py | 13 +- .../icl_evaluator/icl_base_evaluator.py | 137 +++++++++++++++++- .../icl_evaluator/icl_gpassk_evaluator.py | 10 +- .../icl_inferencer/icl_gen_inferencer.py | 3 + opencompass/tasks/openicl_eval.py | 6 +- opencompass/utils/build.py | 1 - 9 files changed, 189 insertions(+), 26 deletions(-) diff --git a/opencompass/configs/datasets/livemathbench/livemathbench_gen_9befbf.py b/opencompass/configs/datasets/livemathbench/livemathbench_gen_9befbf.py index 3748c022..27b4db56 100644 --- a/opencompass/configs/datasets/livemathbench/livemathbench_gen_9befbf.py +++ b/opencompass/configs/datasets/livemathbench/livemathbench_gen_9befbf.py @@ -9,7 +9,7 @@ livemathbench_dataset = dict( type=LiveMathBenchDataset, path='', k=16, - replication=3, + repeat=3, dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'], dataset_languages=['cn', 'en'], cot=True, @@ -43,7 +43,7 @@ livemathbench_dataset = dict( extract_url=[], extract_model_name='', k=[4, 8, 16], - replication=3, + repeat=3, thresholds=[0.0, 0.25, 0.5, 0.75, 1.0] ) ) diff --git a/opencompass/configs/datasets/livemathbench/livemathbench_greedy_gen_9befbf.py b/opencompass/configs/datasets/livemathbench/livemathbench_greedy_gen_9befbf.py index d8d8b79c..a93c1f47 100644 --- a/opencompass/configs/datasets/livemathbench/livemathbench_greedy_gen_9befbf.py +++ b/opencompass/configs/datasets/livemathbench/livemathbench_greedy_gen_9befbf.py @@ -9,7 +9,7 @@ livemathbench_dataset = dict( type=LiveMathBenchDataset, path='', k=1, - replication=1, + repeat=1, dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'], dataset_languages=['cn', 'en'], cot=True, @@ -43,7 +43,7 @@ livemathbench_dataset = dict( extract_url=[], extract_model_name='', k=[1], - replication=1, + repeat=1, thresholds=[0.0] ) ) diff --git a/opencompass/datasets/base.py b/opencompass/datasets/base.py index 5412ef4c..de839f6c 100644 --- a/opencompass/datasets/base.py +++ b/opencompass/datasets/base.py @@ -1,5 +1,6 @@ from abc import abstractstaticmethod -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List +from copy import deepcopy from datasets import Dataset, DatasetDict @@ -8,8 +9,38 @@ from opencompass.openicl import DatasetReader class BaseDataset: - def __init__(self, reader_cfg: Optional[Dict] = {}, **kwargs): - self.dataset = self.load(**kwargs) + def __init__(self, + reader_cfg: Optional[Dict] = {}, + k: Union[int, List[int]] = 1, + repeat: int = 1, + **kwargs): + abbr = kwargs.pop('abbr', 'dataset') + dataset = self.load(**kwargs) + # maybe duplicate + n = (max(k) if isinstance(k, List) else k) * repeat + if isinstance(dataset, Dataset): + examples = [] + for idx, example in enumerate(dataset): + if 'subdivision' not in example: + example['subdivision'] = abbr + if 'idx' not in example: + example['idx'] = idx + examples.append(example) + examples = sum([deepcopy(examples) for _ in range(n)], []) + self.dataset = Dataset.from_list(examples) + else: + self.dataset = DatasetDict() + for key in dataset: + examples = [] + for idx, example in enumerate(dataset[key]): + if 'subdivision' not in example: + example['subdivision'] = f'{abbr}_{key}' + if 'idx' not in example: + example['idx'] = idx + examples.append(example) + print(abbr, key, len(examples)) + examples = sum([deepcopy(examples) for _ in range(n)], []) + self.dataset[key] = Dataset.from_list(examples) self._init_reader(**reader_cfg) def _init_reader(self, **kwargs): diff --git a/opencompass/datasets/livemathbench/livemathbench.py b/opencompass/datasets/livemathbench/livemathbench.py index d2b4b93b..208af7de 100644 --- a/opencompass/datasets/livemathbench/livemathbench.py +++ b/opencompass/datasets/livemathbench/livemathbench.py @@ -2,7 +2,6 @@ import os import warnings from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor, as_completed -from copy import deepcopy from functools import partial from itertools import product from typing import Any, Callable, Dict, List, Union @@ -31,8 +30,6 @@ class LiveMathBenchDataset(BaseDataset): @staticmethod def load(path: str, - k: Union[int, List[int]], - replication: int, dataset_splits: List[str] = [ 'CNMO', 'CCEE', @@ -104,11 +101,7 @@ class LiveMathBenchDataset(BaseDataset): ('' if 'options' not in example else ' '.join(example['options']))), }) - max_k = k if isinstance(k, int) else max(k) - for idx in range(max_k * replication): - duplicated_example = deepcopy(example) - duplicated_example.update({'replication_idx': idx}) - dataset.append(duplicated_example) + dataset.append(example) return Dataset.from_list(dataset) @@ -127,9 +120,9 @@ class LiveMathBenchEvaluator(GPassKEvaluator): extract_url=[], extract_model_name='', k: Union[int, List[int]] = 16, - replication: int = 3, + repeat: int = 3, thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]): - super().__init__(k, replication, thresholds) + super().__init__(k, repeat, thresholds) if isinstance(url, str): url = [url] diff --git a/opencompass/openicl/icl_evaluator/icl_base_evaluator.py b/opencompass/openicl/icl_evaluator/icl_base_evaluator.py index 0b07cfaa..4161825e 100644 --- a/opencompass/openicl/icl_evaluator/icl_base_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_base_evaluator.py @@ -1,11 +1,146 @@ """Base Evaluator.""" +from typing import Union, List, Dict, Any, Iterable +from collections import OrderedDict +from copy import deepcopy + +import numpy as np +from scipy.stats import hypergeom +from datasets import Dataset + + +def compute_pass_at_k(n, c, k): + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + +def _compute_g_pass_at_k(n, c, k, m): + if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0: + return 0.0 + return hypergeom.sf(m - 1, n, c, k) + +def compute_g_pass_at_k(n, c, k, t): + m = max(int(np.ceil(k * t)), 1) + return _compute_g_pass_at_k(n, c, k, m) + +def compute_mg_pass_at_k(n, c, k): + l, r = int(np.ceil(k * 0.5)), k + + mg_pass_at_k = 0.0 + for i in range(l + 1, r + 1): + mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i) + mg_pass_at_k = 2 * mg_pass_at_k / k + + return mg_pass_at_k class BaseEvaluator: - def __init__(self) -> None: pass + def group(self, n: int, details: List[Dict[str, Any]], test_set: Dataset) -> Dict[str, Any]: + example2replications = {} + for detail, example in zip(details, test_set): + example_abbr = f"{example['subdivision']}_{example['idx']}" + if example_abbr not in example2replications: + example2replications[example_abbr] = [] + example.update({'detail': detail}) + example2replications[example_abbr].append(example) + for _, replications in example2replications.items(): + assert len(replications) == n, print(len(replications), n) + return example2replications + + def reduce(self, details: List[Dict[str, Any]]) -> Dict[str, Any]: + g_passk_details = OrderedDict() + all_subdivisions = set([detail['example_abbr'].split('_')[0] for detail in details]) + all_metrics = list(details[0].keys()) + + for subdivision in sorted(list(all_subdivisions)): + for metric in all_metrics: + if metric in ['predictions', 'example_abbr']: + continue + g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean([ + detail[metric] + for detail in details + if detail['example_abbr'].split('_')[0] == subdivision + ]) + + for metric in all_metrics: + if metric in ['predictions', 'example_abbr']: + continue + g_passk_details[metric] = 100. * np.mean([detail[metric] for detail in details]) + return g_passk_details + + def evaluate(self, k: Union[int, List[int]], + repeat: int, test_set: Dataset, **score_kwargs): + n = (max(k) if isinstance(k, List) else k) * repeat + print(len(score_kwargs['predictions'])) + real_size = len(test_set) // n + all_details = [] + all_results = [] + for i in range(n): + results = self.score(**{ + key: value[i * real_size: (i + 1) * real_size] if isinstance(value, Iterable) else value + for key, value in score_kwargs.items() + }) + details = results.pop('details', None) + if details is not None: + if isinstance(details, Dict): + details = list(details.values()) + all_details.extend(details) + all_results.append(results) + + eval_results = {} + for single_results in all_results: + for key in single_results: + if key not in eval_results: + eval_results[key] = [] + eval_results[key].append(single_results[key]) + for key in deepcopy(eval_results): + if isinstance(eval_results[key][0], float) or isinstance(eval_results[key][0], int): + if n > 1: + eval_results[key + f' ({n // repeat}x{repeat}={n} runs average)'] = np.mean(eval_results[key]) + eval_results.pop(key) + else: + eval_results[key] = np.mean(eval_results[key]) + else: + eval_results[key] = eval_results[key][0] + + grouped_examples = self.group(n, all_details, test_set) + if len(all_details) != 0: + eval_details = [] + for example_abbr, examples in grouped_examples.items(): + detail = { + 'predictions': [], + 'example_abbr': example_abbr + } + + c = 0 + can_calculate = False + for example in examples: + detail['predictions'].append(example['detail']) + # only compute G-Pass@k when details have correct labels + if example['detail'].get('correct', None) is not None: + can_calculate = True + c += int(example['detail']['correct']) + elif example['detail'].get('is_correct', None) is not None: + can_calculate = True + c += int(example['detail']['is_correct']) + + if can_calculate: + thresholds = [0.0, 0.25, 0.5, 0.75, 1.0] + for _k in ([k] if isinstance(k, int) else k): + for threshold in thresholds: + detail[f'G-Pass@{_k}_{threshold}'] = compute_g_pass_at_k( + n=n, c=c, k=_k, t=threshold) + detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(n=n, c=c, k=_k) + + eval_details.append(detail) + + eval_results.update(self.reduce(eval_details)) + eval_results['details'] = eval_details + + return eval_results + def score(self): raise NotImplementedError("Method hasn't been implemented yet") diff --git a/opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py b/opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py index 80a59073..8391a435 100644 --- a/opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_gpassk_evaluator.py @@ -57,10 +57,10 @@ class GPassKEvaluator(BaseEvaluator): integers (e.g., `[4, 8, 16]` computes G-Pass@4, G-Pass@8, and G-Pass@16). - replication (int): Controls the number of generations + repeat (int): Controls the number of generations used to estimate G-Pass@k. The total number of generations is determined by multiplying the - maximum of `k` with `replication`. This parameter + maximum of `k` with `repeat`. This parameter should be a single integer. thresholds (list of float): A list of floating-point @@ -71,7 +71,7 @@ class GPassKEvaluator(BaseEvaluator): def __init__( self, k: Union[int, List[int]] = 16, - replication: int = 3, + repeat: int = 3, thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None: super().__init__() @@ -79,8 +79,8 @@ class GPassKEvaluator(BaseEvaluator): k = [k] self.k = k - self.replication = replication - self.n = max(k) * replication + self.repeat = repeat + self.n = max(k) * repeat self.thresholds = thresholds @property diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py index 6a33b711..a2dfba85 100644 --- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py @@ -104,6 +104,9 @@ class GenInferencer(BaseInferencer): max_seq_len=self.max_seq_len, ice_template=ice_template, prompt_template=prompt_template) + + + print(len(prompt_list)) # 3.1 Fetch and zip prompt & gold answer if output column exists ds_reader = retriever.dataset_reader diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index a797459f..b1905631 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -215,7 +215,9 @@ class OpenICLEvalTask(BaseTask): k: preds[k] for k in signature(icl_evaluator.score).parameters } - result = icl_evaluator.score(**preds) + k = self.dataset_cfg.get('k', 1) + repeat = self.dataset_cfg.get('repeat', 1) + result = icl_evaluator.evaluate(k, repeat, test_set, **preds) # Get model postprocess result model_details = None @@ -223,7 +225,7 @@ class OpenICLEvalTask(BaseTask): if 'model_postprocessor' in self.eval_cfg: model_preds = copy.deepcopy(preds) model_preds['predictions'] = model_pred_strs - model_result = icl_evaluator.score(**model_preds) + model_result = icl_evaluator.evaluate(k, repeat, test_set, **model_preds) for key in model_result: if key == 'details': model_details = model_result[key] diff --git a/opencompass/utils/build.py b/opencompass/utils/build.py index 14a66683..f0973d7f 100644 --- a/opencompass/utils/build.py +++ b/opencompass/utils/build.py @@ -9,7 +9,6 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict): dataset_cfg = copy.deepcopy(dataset_cfg) dataset_cfg.pop('infer_cfg', None) dataset_cfg.pop('eval_cfg', None) - dataset_cfg.pop('abbr', None) return LOAD_DATASET.build(dataset_cfg)