MaiziXiao 2025-04-02 03:35:25 +00:00
parent 55aa1717ce
commit e64ca780fa
2 changed files with 14 additions and 23 deletions


@@ -91,7 +91,8 @@ class BaseEvaluator:
     ):
         # Check if predictions and references have the
         # same length if both are provided
-        if 'predictions' in score_kwargs and 'references' in score_kwargs:
+        if ('predictions' in score_kwargs and 'references' in score_kwargs
+                and score_kwargs['references'] is not None):
             if len(score_kwargs['predictions']) != len(
                     score_kwargs['references']):
                 raise ValueError(

@@ -256,11 +256,6 @@ class OpenICLEvalTask(BaseTask):
         )
         evaluator._out_dir = osp.splitext(out_path)[0]  # strip extension
-        # Run evaluation - ensure we pass the correct parameters
-        # to score method
-        score_sig = signature(evaluator.score)
-        score_params = {}
-        # If preds contains keys that match the score method
-        # parameters, include them
         if pred_dicts:
@@ -268,27 +263,22 @@ class OpenICLEvalTask(BaseTask):
             preds = {
                 k: [pred.get(k) for pred in pred_dicts]
                 for k in pred_dicts[0]
             }
-            score_params = {
-                k: preds[k]
-                for k in score_sig.parameters if k in preds
-            }
-        # Add predictions and references if they're expected
-        # by the score method
-        if 'predictions' in score_sig.parameters:
-            score_params['predictions'] = pred_strs
-        if 'references' in score_sig.parameters:
-            score_params['references'] = references
-        if 'test_set' in score_sig.parameters:
-            score_params['test_set'] = test_set
-        if 'origin_prompt' in score_sig.parameters:
+        preds['predictions'] = pred_strs
+        preds['references'] = (test_set[self.output_column]
+                               if self.output_column else None)
+        preds['test_set'] = test_set
+        if 'origin_prompt' not in preds:
             try:
-                score_params['origin_prompt'] = [
-                    None for _ in range(len(pred_strs))
-                ]
+                preds['origin_prompt'] = [None for _ in range(len(pred_strs))]
             except TypeError:
-                score_params['origin_prompt'] = None
-        # Call score with the appropriate parameters
-        result = evaluator.score(**score_params)
+                preds['origin_prompt'] = None
+        preds = {k: preds[k] for k in signature(evaluator.score).parameters}
+        # Call evaluate with the appropriate parameters
+        k = self.dataset_cfg.get('k', 1)
+        n = self.dataset_cfg.get('n', 1)
+        result = evaluator.evaluate(k, n, copy.deepcopy(test_set), **preds)
         # Format details if needed
         if self.dump_details:
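Taken together, the task no longer builds score_params by hand: it fills a single preds dict, trims it to the keyword arguments that evaluator.score actually declares, and hands everything to evaluator.evaluate together with k and n read from the dataset config. A minimal sketch of that call pattern with a toy evaluator (ToyEvaluator, the accuracy metric, and the sample data are illustrative, not from the repository, and the evaluate signature is assumed to be (k, n, original_dataset, **score_kwargs), matching the call in the hunk above):

import copy
from inspect import signature


class ToyEvaluator:

    def score(self, predictions, references):
        correct = sum(p == r for p, r in zip(predictions, references))
        return {'accuracy': 100 * correct / len(predictions)}

    def evaluate(self, k, n, original_dataset, **score_kwargs):
        # A real evaluator would use k and n for repeated-sampling
        # metrics; this toy version simply delegates to score().
        return self.score(**score_kwargs)


evaluator = ToyEvaluator()
dataset_cfg = {'k': 4, 'n': 8}  # illustrative config values
test_set = [{'question': 'q1'}, {'question': 'q2'}]
preds = {
    'predictions': ['a', 'b'],
    'references': ['a', 'c'],
    'test_set': test_set,  # dropped below: score() does not declare it
}

# Keep only the keyword arguments score() actually declares, as in the hunk.
preds = {k: preds[k] for k in signature(evaluator.score).parameters}
k = dataset_cfg.get('k', 1)
n = dataset_cfg.get('n', 1)
result = evaluator.evaluate(k, n, copy.deepcopy(test_set), **preds)
print(result)  # {'accuracy': 50.0}

Filtering preds through signature(evaluator.score).parameters keeps the call compatible with evaluators whose score methods declare different keyword sets, since keys the method does not accept are dropped before the call.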