OpenCompass/opencompass/openicl/icl_evaluator/icl_em_evaluator.py

42 lines
1.3 KiB
Python
Raw Normal View History

2023-07-04 21:34:55 +08:00
from opencompass.registry import ICL_EVALUATORS
from opencompass.utils.text_postprocessors import general_postprocess
from .icl_base_evaluator import BaseEvaluator
@ICL_EVALUATORS.register_module()
class EMEvaluator(BaseEvaluator):
    """Exact match evaluator.

    A prediction counts as correct when, after being passed through
    ``general_postprocess``, it is equal to one of the candidate answers
    of its reference, either in post-processed or in original form.
    """

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        """Compute the exact-match score of ``predictions``.

        Args:
            predictions: Sequence of predicted strings.
            references: Sequence with one entry per prediction; each entry
                is the list of acceptable answers for that prediction.

        Returns:
            dict: ``{'error': ...}`` when the two inputs differ in length.
            Otherwise ``{'score': ..., 'details': ...}`` where ``score`` is
            the percentage (0-100) of correct predictions and ``details``
            holds one ``{'pred', 'answer', 'correct'}`` dict per sample.
            Empty input yields a score of ``0.0`` and no details.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        # With nothing to evaluate the percentage below would divide by
        # zero; report an empty result instead of raising.
        if len(predictions) == 0:
            return {'score': 0.0, 'details': []}

        predictions = [
            general_postprocess(prediction) for prediction in predictions
        ]
        processed_answers = [[general_postprocess(j) for j in i]
                             for i in references]

        cnt = 0
        details = []
        for pred, ans, origin_ans in zip(predictions, processed_answers,
                                         references):
            # De-duplicate while keeping first-seen order so the reported
            # answers are reproducible across runs (a plain ``set`` of
            # strings has no stable ordering).
            answers = list(dict.fromkeys(ans + origin_ans))
            # The post-processed prediction may match either the
            # post-processed or the untouched reference answers.
            correct = pred in ans or pred in origin_ans
            if correct:
                cnt += 1
            details.append({
                'pred': pred,
                'answer': answers,
                'correct': correct,
            })

        score = cnt / len(predictions) * 100

        return {'score': score, 'details': details}