fix repeat to n in openicl_eval

This commit is contained in:
jnanliu 2025-02-24 08:14:08 +00:00
parent b0330ef1c6
commit 4e63ebbf0c

View File

@ -216,8 +216,8 @@ class OpenICLEvalTask(BaseTask):
for k in signature(icl_evaluator.score).parameters
}
k = self.dataset_cfg.get('k', 1)
repeat = self.dataset_cfg.get('repeat', 1)
result = icl_evaluator.evaluate(k, repeat, copy.deepcopy(test_set),
n = self.dataset_cfg.get('n', 1)
result = icl_evaluator.evaluate(k, n, copy.deepcopy(test_set),
**preds)
# Get model postprocess result
@ -226,7 +226,7 @@ class OpenICLEvalTask(BaseTask):
if 'model_postprocessor' in self.eval_cfg:
model_preds = copy.deepcopy(preds)
model_preds['predictions'] = model_pred_strs
model_result = icl_evaluator.evaluate(k, repeat,
model_result = icl_evaluator.evaluate(k, n,
copy.deepcopy(test_set),
**model_preds)
for key in model_result: