change repeat to n

This commit is contained in:
jnanliu 2025-02-24 08:11:27 +00:00
parent 2349fcff2c
commit b0330ef1c6
4 changed files with 10 additions and 12 deletions

View File

@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset,
path='',
k=16,
repeat=3,
n=48,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'],
cot=True,

View File

@ -9,7 +9,7 @@ livemathbench_dataset = dict(
type=LiveMathBenchDataset,
path='',
k=1,
repeat=1,
n=1,
dataset_splits=['CNMO', 'CCEE', 'AMC', 'WLPMC'],
dataset_languages=['cn', 'en'],
cot=True,

View File

@ -12,12 +12,13 @@ class BaseDataset:
def __init__(self,
reader_cfg: Optional[Dict] = {},
k: Union[int, List[int]] = 1,
repeat: int = 1,
n: int = 1,
**kwargs):
abbr = kwargs.pop('abbr', 'dataset')
dataset = self.load(**kwargs)
# maybe duplicate
n = (max(k) if isinstance(k, List) else k) * repeat
assert (max(k) if isinstance(k, List) else
k) <= n, 'Maximum value of `k` must less than or equal to `n`'
if isinstance(dataset, Dataset):
examples = []
for idx, example in enumerate(dataset):

View File

@ -81,9 +81,8 @@ class BaseEvaluator:
[detail[metric] for detail in details])
return g_passk_details
def evaluate(self, k: Union[int, List[int]], repeat: int,
def evaluate(self, k: Union[int, List[int]], n: int,
original_dataset: Dataset, **score_kwargs):
n = (max(k) if isinstance(k, List) else k) * repeat
real_size = len(original_dataset) // n
all_details = []
all_results = []
@ -119,10 +118,8 @@ class BaseEvaluator:
if isinstance(eval_results[key][0], float) or isinstance(
eval_results[key][0], int):
if n > 1:
m = n // repeat
eval_results[
key + f' ({m}x{repeat}={n} runs average)'] = np.mean(
eval_results[key])
eval_results[key + f' ({n} runs average)'] = np.mean(
eval_results[key])
eval_results.pop(key)
else:
eval_results[key] = np.mean(eval_results[key])
@ -147,7 +144,7 @@ class BaseEvaluator:
can_calculate = True
c += int(example['detail']['is_correct'])
if can_calculate and n > 1:
if can_calculate and n > 1 and k > 1:
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
for _k in ([k] if isinstance(k, int) else k):
for threshold in thresholds:
@ -162,7 +159,7 @@ class BaseEvaluator:
eval_details.append(detail)
if can_calculate and n > 1:
if can_calculate and n > 1 and k > 1:
eval_results.update(self.reduce(eval_details))
eval_results['details'] = eval_details