[Feature] Support Dataset Repeat and G-Pass Compute for Each Evaluator (#1886)

* support dataset repeat and g-pass compute for each evaluator

* fix pre-commit errors

* delete print

* delete gpassk_evaluator and fix potential errors

* change `repeat` to `n`

* fix `repeat` to `n` in openicl_eval

* update doc for multi-run and g-pass

* update latex equation in doc

* update eng doc for multi-run and g-pass

* update datasets.md

* update datasets.md

* fix multi-line equation

* fix multi-line equation

* fix multi-line equation

* fix multi-line equation

* fix multi-line equation

* fix multi-line equation

* fix multi-line equation in zh_cn user_guides

* modify pre-commit-zh-cn

* recover pre-commit and edit math expr in doc

* del [TIP]

* del cite tag in doc

* del extract_model param in livemathbench config
Junnan Liu 2025-02-26 19:43:12 +08:00 committed by GitHub
parent 6042b88e58
commit 73c80953c6
11 changed files with 300 additions and 250 deletions

View File

@@ -81,3 +81,43 @@ datasets += cmnli_datasets
Users can choose different abilities, different datasets and different evaluation methods configuration files to build the part of the dataset in the evaluation script according to their needs.
For information on how to start an evaluation task and how to evaluate self-built datasets, please refer to the relevant documents.
### Multiple Evaluations on the Dataset
In the dataset configuration, you can set the parameter `n` to perform multiple evaluations on the same dataset and return the average metrics, for example:
```python
afqmc_datasets = [
dict(
abbr="afqmc-dev",
type=AFQMCDatasetV2,
path="./data/CLUE/AFQMC/dev.json",
n=10, # Perform 10 evaluations
reader_cfg=afqmc_reader_cfg,
infer_cfg=afqmc_infer_cfg,
eval_cfg=afqmc_eval_cfg,
),
]
```
Additionally, for binary evaluation metrics (such as accuracy, pass-rate, etc.), you can also set the parameter `k` in conjunction with `n` for [G-Pass@k](http://arxiv.org/abs/2412.13147) evaluation. The formula for G-Pass@k is:
```{math}
\text{G-Pass@}k_\tau=E_{\text{Data}}\left[ \sum_{j=\lceil \tau \cdot k \rceil}^c \frac{{c \choose j} \cdot {n - c \choose k - j}}{{n \choose k}} \right],
```
where $n$ is the number of evaluations and $c$ is the number of the $n$ runs that pass or are correct. An example configuration is as follows:
```python
aime2024_datasets = [
dict(
abbr='aime2024',
type=Aime2024Dataset,
path='opencompass/aime2024',
k=[2, 4], # Return results for G-Pass@2 and G-Pass@4
n=12, # 12 evaluations
...
)
]
```
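For a quick numerical check of the formula, here is a minimal standalone sketch that mirrors the hypergeometric-tail computation used by the `compute_g_pass_at_k` helper this PR adds to the base evaluator; the function name and the sample numbers below are illustrative only:
```python
import numpy as np
from scipy.stats import hypergeom

def g_pass_at_k(n: int, c: int, k: int, tau: float) -> float:
    # Probability that at least ceil(tau * k) generations are correct when
    # drawing k of the n generations (c of which are correct) without
    # replacement, i.e. a hypergeometric survival function.
    m = max(int(np.ceil(k * tau)), 1)
    if m > min(c, k) or k > n or c < 0 or n <= 0:
        return 0.0
    return hypergeom.sf(m - 1, n, c, k)

# e.g. 9 correct answers out of n=12 runs, G-Pass@4 at threshold 0.5
print(round(g_pass_at_k(n=12, c=9, k=4, tau=0.5), 4))
```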

View File

@@ -81,3 +81,42 @@ datasets += cmnli_datasets
Users can choose different abilities, different datasets and different evaluation methods configuration files to build the part of the dataset in the evaluation script according to their needs.
For information on how to start an evaluation task and how to evaluate self-built datasets, please refer to the relevant documents.
### Multiple Evaluations on the Dataset
In the dataset configuration, you can set the parameter `n` to perform multiple evaluations on the same dataset and return the average metrics, for example:
```python
afqmc_datasets = [
dict(
abbr="afqmc-dev",
type=AFQMCDatasetV2,
path="./data/CLUE/AFQMC/dev.json",
n=10, # Perform 10 evaluations
reader_cfg=afqmc_reader_cfg,
infer_cfg=afqmc_infer_cfg,
eval_cfg=afqmc_eval_cfg,
),
]
```
Additionally, for binary evaluation metrics (such as accuracy, pass-rate, etc.), you can also set the parameter `k` in conjunction with `n` for [G-Pass@k](http://arxiv.org/abs/2412.13147) evaluation. The formula for G-Pass@k is:
```{math}
\text{G-Pass@}k_\tau=E_{\text{Data}}\left[ \sum_{j=\lceil \tau \cdot k \rceil}^c \frac{{c \choose j} \cdot {n - c \choose k - j}}{{n \choose k}} \right],
```
where $n$ is the number of evaluations and $c$ is the number of the $n$ runs that pass or are correct. An example configuration is as follows:
```python
aime2024_datasets = [
dict(
abbr='aime2024',
type=Aime2024Dataset,
path='opencompass/aime2024',
k=[2, 4], # Return results for G-Pass@2 and G-Pass@4
n=12, # 12 evaluations
...
)
]
```

View File

@@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset,
path='',
k=16,
-replication=3,
n=48,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'],
cot=True,
@@ -38,13 +38,7 @@ livemathbench_dataset = dict(
evaluator=dict(
type=LiveMathBenchEvaluator,
model_name='',
-url=[],
-use_extract_model=False,
-extract_url=[],
-extract_model_name='',
-k=[4, 8, 16],
-replication=3,
-thresholds=[0.0, 0.25, 0.5, 0.75, 1.0]
url=[]
)
)
)

View File

@@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset,
path='',
k=1,
-replication=1,
n=1,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'],
cot=True,
@@ -38,13 +38,7 @@ livemathbench_dataset = dict(
evaluator=dict(
type=LiveMathBenchEvaluator,
model_name='',
-url=[],
-use_extract_model=False,
-extract_url=[],
-extract_model_name='',
-k=[1],
-replication=1,
-thresholds=[0.0]
url=[]
)
)
)

View File

@@ -1,4 +1,5 @@
-from typing import Dict, Optional, Union
from copy import deepcopy
from typing import Dict, List, Optional, Union
from datasets import Dataset, DatasetDict
@@ -7,8 +8,39 @@ from opencompass.openicl import DatasetReader
class BaseDataset:
-def __init__(self, reader_cfg: Optional[Dict] = {}, **kwargs):
-self.dataset = self.load(**kwargs)
def __init__(self,
reader_cfg: Optional[Dict] = {},
k: Union[int, List[int]] = 1,
n: int = 1,
**kwargs):
abbr = kwargs.pop('abbr', 'dataset')
dataset = self.load(**kwargs)
# maybe duplicate
assert (max(k) if isinstance(k, List) else
k) <= n, 'Maximum value of `k` must less than or equal to `n`'
if isinstance(dataset, Dataset):
examples = []
for idx, example in enumerate(dataset):
if 'subdivision' not in example:
example['subdivision'] = abbr
if 'idx' not in example:
example['idx'] = idx
examples.append(example)
examples = sum([deepcopy(examples) for _ in range(n)], [])
self.dataset = Dataset.from_list(examples)
else:
self.dataset = DatasetDict()
for key in dataset:
examples = []
for idx, example in enumerate(dataset[key]):
if 'subdivision' not in example:
example['subdivision'] = f'{abbr}_{key}'
if 'idx' not in example:
example['idx'] = idx
examples.append(example)
print(abbr, key, len(examples))
examples = sum([deepcopy(examples) for _ in range(n)], [])
self.dataset[key] = Dataset.from_list(examples)
self._init_reader(**reader_cfg)
def _init_reader(self, **kwargs):
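A minimal, self-contained sketch of the repetition behavior introduced above (the dataset name and field values here are illustrative, not from the PR): each example is tagged with `subdivision` and `idx`, then the whole list is duplicated `n` times before being wrapped in a `Dataset`.
```python
from copy import deepcopy

from datasets import Dataset

examples = [{'question': 'q1'}, {'question': 'q2'}]
n = 3  # evaluate the dataset three times

for idx, example in enumerate(examples):
    example.setdefault('subdivision', 'demo-dataset')
    example.setdefault('idx', idx)

repeated = sum([deepcopy(examples) for _ in range(n)], [])
dataset = Dataset.from_list(repeated)
print(len(dataset))  # 6 rows: every example appears n times
```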

View File

@@ -1,11 +1,9 @@
import os
import warnings
-from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
-from copy import deepcopy
from functools import partial
from itertools import product
-from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List
import jsonlines
import mmengine
@@ -14,7 +12,7 @@ from datasets import Dataset, load_dataset
from opencompass.datasets.math import MATHAgentEvaluator, math_postprocess_v2
from opencompass.models import OpenAISDK
-from opencompass.openicl.icl_evaluator import GPassKEvaluator
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
dump_results_dict
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS
@@ -31,8 +29,6 @@ class LiveMathBenchDataset(BaseDataset):
@staticmethod
def load(path: str,
-k: Union[int, List[int]],
-replication: int,
dataset_splits: List[str] = [
'CNMO',
'CCEE',
@@ -104,17 +100,13 @@ class LiveMathBenchDataset(BaseDataset):
('' if 'options' not in example else
' '.join(example['options']))),
})
-max_k = k if isinstance(k, int) else max(k)
-for idx in range(max_k * replication):
-duplicated_example = deepcopy(example)
-duplicated_example.update({'replication_idx': idx})
-dataset.append(duplicated_example)
dataset.append(example)
return Dataset.from_list(dataset)
@ICL_EVALUATORS.register_module()
-class LiveMathBenchEvaluator(GPassKEvaluator):
class LiveMathBenchEvaluator(BaseEvaluator):
api_meta_template = dict(round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
@@ -125,11 +117,8 @@ class LiveMathBenchEvaluator(GPassKEvaluator):
url,
use_extract_model=False,
extract_url=[],
-extract_model_name='',
-k: Union[int, List[int]] = 16,
-replication: int = 3,
-thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]):
-super().__init__(k, replication, thresholds)
extract_model_name=''):
super().__init__()
if isinstance(url, str):
url = [url]
@@ -310,55 +299,18 @@ class LiveMathBenchEvaluator(GPassKEvaluator):
def preprocess(self, predictions, references, test_set):
return self.judge(predictions, references, test_set)
-def group(self, predictions, labels, test_set):
-example2replications = {}
-for example, label, prediction in zip(test_set, labels, predictions):
-example_abbr = f"{example['subdivision']}_{example['idx']}"
-if example_abbr not in example2replications:
-example2replications[example_abbr] = []
-example.update({'prediction': prediction, 'label': label})
-example2replications[example_abbr].append(example)
-for _, replications in example2replications.items():
-assert len(replications) == self.n, print(len(replications),
-self.n)
-return example2replications
-def reduce(self, details) -> Dict[str, Any]:
-"""Aggregate the overall metrics.
-Return:
-A dict contains overall metrics, like:
-{'details': details for each example, 'G-Pass@16': xxx}
-"""
-g_passk_details = OrderedDict()
-g_passk_details['details'] = details
-all_dataset = set([detail['subdivision'] for detail in details])
-for k in self.k:
-for subdivision in sorted(list(all_dataset)):
-for threshold in self.thresholds:
-g_passk_details[
-f'{subdivision}/G-Pass@{k}_{threshold}'] = \
-100. * np.mean(
-[
-detail[f'G-Pass@{k}_{threshold}']
-for detail in details
-if detail['subdivision'] == subdivision
-])
-g_passk_details[f'{subdivision}/mG-Pass@{k}'] = 100. * np.mean(
-[
-detail[f'mG-Pass@{k}'] for detail in details
-if detail['subdivision'] == subdivision
-])
-for threshold in self.thresholds:
-g_passk_details[f'G-Pass@{k}_{threshold}'] = 100. * np.mean(
-[detail[f'G-Pass@{k}_{threshold}'] for detail in details])
-g_passk_details[f'mG-Pass@{k}'] = 100. * np.mean(
-[detail[f'mG-Pass@{k}'] for detail in details])
-return g_passk_details
def score(self, predictions, references, test_set) -> Dict[str, Any]:
labels = self.preprocess(predictions, references, test_set)
results = {'accuracy': 100 * np.mean(labels), 'details': []}
for pred, ref, label in zip(predictions, references, labels):
results['details'].append({
'pred': pred,
'ref': ref,
'correct': label
})
return results
class LiveMathBenchOutputHandler:

View File

@@ -4,7 +4,6 @@ from .icl_base_evaluator import BaseEvaluator # noqa
from .icl_bpc_evaluator import BPCEvaluator # noqa
from .icl_circular_evaluator import CircularEvaluator # noqa
from .icl_em_evaluator import EMEvaluator # noqa
-from .icl_gpassk_evaluator import GPassKEvaluator # noqa
from .icl_hf_evaluator import * # noqa
from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator # noqa
from .icl_misc_evaluator import AverageInferencePPLEvaluator # noqa

View File

@@ -1,4 +1,39 @@
"""Base Evaluator."""
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Union
import numpy as np
from datasets import Dataset
from scipy.stats import hypergeom
def compute_pass_at_k(n, c, k):
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
def _compute_g_pass_at_k(n, c, k, m):
if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
return 0.0
return hypergeom.sf(m - 1, n, c, k)
def compute_g_pass_at_k(n, c, k, t):
m = max(int(np.ceil(k * t)), 1)
return _compute_g_pass_at_k(n, c, k, m)
def compute_mg_pass_at_k(n, c, k):
l, r = int(np.ceil(k * 0.5)), k
mg_pass_at_k = 0.0
for i in range(l + 1, r + 1):
mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
mg_pass_at_k = 2 * mg_pass_at_k / k
return mg_pass_at_k
class BaseEvaluator:
@@ -6,6 +41,130 @@ class BaseEvaluator:
def __init__(self) -> None:
pass
@property
def output_dir(self):
# please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
return self._out_dir
def group(self, n: int, details: List[Dict[str, Any]],
test_set: Dataset) -> Dict[str, Any]:
example2replications = {}
for detail, example in zip(details, test_set):
example_abbr = f"{example['subdivision']}_{example['idx']}"
if example_abbr not in example2replications:
example2replications[example_abbr] = []
example.update({'detail': detail})
example2replications[example_abbr].append(example)
for _, replications in example2replications.items():
assert len(replications) == n, print(len(replications), n)
return example2replications
def reduce(self, details: List[Dict[str, Any]]) -> Dict[str, Any]:
g_passk_details = OrderedDict()
all_subdivisions = set(
[detail['example_abbr'].split('_')[0] for detail in details])
all_metrics = list(details[0].keys())
for subdivision in sorted(list(all_subdivisions)):
for metric in all_metrics:
if metric in ['predictions', 'example_abbr']:
continue
g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean([
detail[metric] for detail in details
if detail['example_abbr'].split('_')[0] == subdivision
])
for metric in all_metrics:
if metric in ['predictions', 'example_abbr']:
continue
g_passk_details[metric] = 100. * np.mean(
[detail[metric] for detail in details])
return g_passk_details
def evaluate(self, k: Union[int, List[int]], n: int,
original_dataset: Dataset, **score_kwargs):
real_size = len(original_dataset) // n
all_details = []
all_results = []
for i in range(n):
def select_fn(i, real_size, x):
if isinstance(x, Dataset):
return x.select(range(i * real_size, (i + 1) * real_size))
elif isinstance(x, Iterable):
return x[i * real_size:(i + 1) * real_size]
else:
return x
results = self.score(
**{
key: select_fn(i, real_size, value)
for key, value in score_kwargs.items()
})
details = results.pop('details', None)
if details is not None:
if isinstance(details, Dict):
details = list(details.values())
all_details.extend(details)
all_results.append(results)
eval_results = {}
for single_results in all_results:
for key in single_results:
if key not in eval_results:
eval_results[key] = []
eval_results[key].append(single_results[key])
for key in deepcopy(eval_results):
if isinstance(eval_results[key][0], float) or isinstance(
eval_results[key][0], int):
if n > 1:
eval_results[key + f' ({n} runs average)'] = np.mean(
eval_results[key])
eval_results.pop(key)
else:
eval_results[key] = np.mean(eval_results[key])
else:
eval_results[key] = eval_results[key][0]
grouped_examples = self.group(n, all_details, original_dataset)
can_calculate = False
if len(all_details) != 0:
eval_details = []
for example_abbr, examples in grouped_examples.items():
detail = {'predictions': [], 'example_abbr': example_abbr}
c = 0
for example in examples:
detail['predictions'].append(example['detail'])
# only compute G-Pass@k when details have correct labels
if example['detail'].get('correct', None) is not None:
can_calculate = True
c += int(example['detail']['correct'])
elif example['detail'].get('is_correct', None) is not None:
can_calculate = True
c += int(example['detail']['is_correct'])
if can_calculate and n > 1 and k > 1:
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
for _k in ([k] if isinstance(k, int) else k):
for threshold in thresholds:
g_pass = compute_g_pass_at_k(n=n,
c=c,
k=_k,
t=threshold)
detail[f'G-Pass@{_k}_{threshold}'] = g_pass
detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(n=n,
c=c,
k=_k)
eval_details.append(detail)
if can_calculate and n > 1 and k > 1:
eval_results.update(self.reduce(eval_details))
eval_results['details'] = eval_details
return eval_results
def score(self):
raise NotImplementedError("Method hasn't been implemented yet")
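To illustrate how the new `evaluate` entry point consumes an n-times-repeated test set, here is a hypothetical toy evaluator (`ExactMatchToy` is not part of this PR); it assumes the repeated layout produced by `BaseDataset` above, with `subdivision` and `idx` fields on every row:
```python
from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator


class ExactMatchToy(BaseEvaluator):
    """Toy evaluator: per-run accuracy plus per-sample correctness details."""

    def score(self, predictions, references):
        details = [{'correct': p == r} for p, r in zip(predictions, references)]
        acc = 100 * sum(d['correct'] for d in details) / len(details)
        return {'accuracy': acc, 'details': details}


n, k = 4, 2
# three examples, repeated n times (as BaseDataset now does)
test_set = Dataset.from_list(
    [{'subdivision': 'toy', 'idx': i} for i in range(3)] * n)
predictions = ['a', 'b', 'c'] * n  # one prediction per repeated row
references = ['a', 'x', 'c'] * n

results = ExactMatchToy().evaluate(k, n, test_set,
                                   predictions=predictions,
                                   references=references)
# results contains 'accuracy (4 runs average)', per-example G-Pass@2_* and
# mG-Pass@2 scores, aggregated 'toy/...' metrics, and grouped 'details'.
```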

View File

@@ -1,163 +0,0 @@
from abc import abstractmethod
from typing import Any, Dict, List, Union
import numpy as np
from scipy.stats import hypergeom
from opencompass.registry import ICL_EVALUATORS
from .icl_base_evaluator import BaseEvaluator
def compute_pass_at_k(n, c, k):
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
def _compute_g_pass_at_k(n, c, k, m):
if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
return 0.0
return hypergeom.sf(m - 1, n, c, k)
def compute_g_pass_at_k(n, c, k, t):
m = max(int(np.ceil(k * t)), 1)
return _compute_g_pass_at_k(n, c, k, m)
def compute_mg_pass_at_k(n, c, k):
l, r = int(np.ceil(k * 0.5)), k
mg_pass_at_k = 0.0
for i in range(l + 1, r + 1):
mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
mg_pass_at_k = 2 * mg_pass_at_k / k
return mg_pass_at_k
@ICL_EVALUATORS.register_module()
class GPassKEvaluator(BaseEvaluator):
"""Evaluator for computing the G-Pass@k Metric.
This evaluator performs the following steps:
1. Invokes task-specific `preprocess` on predictions to
assign a consistency label to each prediction and its
corresponding reference.
2. Calculates metrics for each input example based on
these labels.
3. Aggregates the overall metrics through a task-specific
`postprocess`.
Args:
k (int or list of int): Number of predictions to be
considered in G-Pass@k. It can be a single integer
(e.g., `k=16` computes G-Pass@16) or a list of
integers (e.g., `[4, 8, 16]` computes G-Pass@4,
G-Pass@8, and G-Pass@16).
replication (int): Controls the number of generations
used to estimate G-Pass@k. The total number of
generations is determined by multiplying the
maximum of `k` with `replication`. This parameter
should be a single integer.
thresholds (list of float): A list of floating-point
numbers that define the thresholds for the G-Pass@k
metric.
"""
def __init__(
self,
k: Union[int, List[int]] = 16,
replication: int = 3,
thresholds: List[float] = [0.0, 0.25, 0.5, 0.75, 1.0]) -> None:
super().__init__()
if isinstance(k, int):
k = [k]
self.k = k
self.replication = replication
self.n = max(k) * replication
self.thresholds = thresholds
@property
def output_dir(self):
# please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
return self._out_dir
@abstractmethod
def preprocess(self, predictions, references, test_set) -> None:
"""Perform operations on predictions before computing metrics, for
example, do answer_extraction and model_judge in mathematical reasoning
task.
Return:
labels: A list contains the label which indicates whether
prediction is consistency with reference at each position.
"""
raise NotImplementedError
@abstractmethod
def group(self, predictions, labels, test_set) -> Dict[str, Any]:
"""Group the predictions and references.
Return:
A dict contains the grouped predictions and references.
"""
raise NotImplementedError
@abstractmethod
def reduce(self, details) -> Dict[str, Any]:
"""Aggregate the overall metrics.
Return:
A dict contains overall metrics, like:
{'details': details for each example, 'G-Pass@16': xxx}
"""
raise NotImplementedError
def score(self, predictions, references, test_set) -> Dict[str, Any]:
"""Compute G-Pass@k metrics.
Return:
A dict contains metrics for each dataset sample and
overall metrics reduced by `self.reduce`, like:
{'details': details for each example, 'G-Pass@16': xxx}
"""
labels = self.preprocess(predictions, references, test_set)
grouped_examples = self.group(predictions, labels, test_set)
details = []
total_pass_num, count = 0, 0
for example_abbr, examples in grouped_examples.items():
detail = {
k: v
for k, v in examples[0].items()
if k not in ['prediction', 'label']
}
detail.update({
'predictions': [{
'prediction': example['prediction'],
'label': example['label']
} for example in examples],
})
current_example_labels = [e['label'] for e in examples]
c = int(np.sum(current_example_labels))
for k in self.k:
for threshold in self.thresholds:
detail[f'G-Pass@{k}_{threshold}'] = compute_g_pass_at_k(
n=self.n, c=c, k=k, t=threshold)
detail[f'mG-Pass@{k}'] = compute_mg_pass_at_k(n=self.n,
c=c,
k=k)
count += self.n
total_pass_num += c
details.append(detail)
return self.reduce(details)

View File

@@ -240,7 +240,10 @@ class OpenICLEvalTask(BaseTask):
k: preds[k]
for k in signature(icl_evaluator.score).parameters
}
-result = icl_evaluator.score(**preds)
k = self.dataset_cfg.get('k', 1)
n = self.dataset_cfg.get('n', 1)
result = icl_evaluator.evaluate(k, n, copy.deepcopy(test_set),
**preds)
# Get model postprocess result
model_details = None
@@ -248,7 +251,9 @@ class OpenICLEvalTask(BaseTask):
if 'model_postprocessor' in self.eval_cfg:
model_preds = copy.deepcopy(preds)
model_preds['predictions'] = model_pred_strs
-model_result = icl_evaluator.score(**model_preds)
model_result = icl_evaluator.evaluate(k, n,
copy.deepcopy(test_set),
**model_preds)
for key in model_result:
if key == 'details':
model_details = model_result[key]

View File

@@ -9,7 +9,6 @@ def build_dataset_from_cfg(dataset_cfg: ConfigDict):
dataset_cfg = copy.deepcopy(dataset_cfg)
dataset_cfg.pop('infer_cfg', None)
dataset_cfg.pop('eval_cfg', None)
-dataset_cfg.pop('abbr', None)
return LOAD_DATASET.build(dataset_cfg)