change repeat to n

This commit is contained in:
jnanliu 2025-02-24 08:11:27 +00:00
parent 2349fcff2c
commit b0330ef1c6
4 changed files with 10 additions and 12 deletions

View File

@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset, type=LiveMathBenchDataset,
path='', path='',
k=16, k=16,
repeat=3, n=48,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'], dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'], dataset_languages=['cn', 'en'],
cot=True, cot=True,

View File

@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset, type=LiveMathBenchDataset,
path='', path='',
k=1, k=1,
repeat=1, n=1,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'], dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'], dataset_languages=['cn', 'en'],
cot=True, cot=True,

View File

@ -12,12 +12,13 @@ class BaseDataset:
def __init__(self, def __init__(self,
reader_cfg: Optional[Dict] = {}, reader_cfg: Optional[Dict] = {},
k: Union[int, List[int]] = 1, k: Union[int, List[int]] = 1,
repeat: int = 1, n: int = 1,
**kwargs): **kwargs):
abbr = kwargs.pop('abbr', 'dataset') abbr = kwargs.pop('abbr', 'dataset')
dataset = self.load(**kwargs) dataset = self.load(**kwargs)
# maybe duplicate # maybe duplicate
n = (max(k) if isinstance(k, List) else k) * repeat assert (max(k) if isinstance(k, List) else
k) <= n, 'Maximum value of `k` must less than or equal to `n`'
if isinstance(dataset, Dataset): if isinstance(dataset, Dataset):
examples = [] examples = []
for idx, example in enumerate(dataset): for idx, example in enumerate(dataset):

View File

@ -81,9 +81,8 @@ class BaseEvaluator:
[detail[metric] for detail in details]) [detail[metric] for detail in details])
return g_passk_details return g_passk_details
def evaluate(self, k: Union[int, List[int]], repeat: int, def evaluate(self, k: Union[int, List[int]], n: int,
original_dataset: Dataset, **score_kwargs): original_dataset: Dataset, **score_kwargs):
n = (max(k) if isinstance(k, List) else k) * repeat
real_size = len(original_dataset) // n real_size = len(original_dataset) // n
all_details = [] all_details = []
all_results = [] all_results = []
@ -119,10 +118,8 @@ class BaseEvaluator:
if isinstance(eval_results[key][0], float) or isinstance( if isinstance(eval_results[key][0], float) or isinstance(
eval_results[key][0], int): eval_results[key][0], int):
if n > 1: if n > 1:
m = n // repeat eval_results[key + f' ({n} runs average)'] = np.mean(
eval_results[ eval_results[key])
key + f' ({m}x{repeat}={n} runs average)'] = np.mean(
eval_results[key])
eval_results.pop(key) eval_results.pop(key)
else: else:
eval_results[key] = np.mean(eval_results[key]) eval_results[key] = np.mean(eval_results[key])
@ -147,7 +144,7 @@ class BaseEvaluator:
can_calculate = True can_calculate = True
c += int(example['detail']['is_correct']) c += int(example['detail']['is_correct'])
if can_calculate and n > 1: if can_calculate and n > 1 and k > 1:
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0] thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
for _k in ([k] if isinstance(k, int) else k): for _k in ([k] if isinstance(k, int) else k):
for threshold in thresholds: for threshold in thresholds:
@ -162,7 +159,7 @@ class BaseEvaluator:
eval_details.append(detail) eval_details.append(detail)
if can_calculate and n > 1: if can_calculate and n > 1 and k > 1:
eval_results.update(self.reduce(eval_details)) eval_results.update(self.reduce(eval_details))
eval_results['details'] = eval_details eval_results['details'] = eval_details