# flake8: noqa
"""KOR-Bench Evaluator."""
import json
import os
import re
from collections import defaultdict
from .icl_base_evaluator import BaseEvaluator


class JudgeEvaluator(BaseEvaluator):
    """Accuracy evaluator for pairwise LLM-judge verdicts."""

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths'
            }
        correct = 0
        count = 0
        details = []
        for prediction, reference in zip(predictions, references):
            # The judge response is expected to contain a fragment such as
            # '"Choice": "Model A"'; keep the single letter after the marker.
            # Slicing with [:1] (rather than indexing [0]) avoids an
            # IndexError when the prediction is empty.
            choice = prediction.split('"Choice": "Model ')[-1][:1]
            gold_winner = reference.get('winner', '')
            detail = {
                'pred': prediction,
                'answer': gold_winner,
                'correct': False,
            }
            count += 1
            if choice and choice == gold_winner:
                correct += 1
                detail['correct'] = True
            details.append(detail)
        accuracy = 100 * correct / count if count else 0
        return {'accuracy': accuracy, 'details': details}
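
# A minimal I/O sketch, assuming a judge prompt that makes the model answer
# with a JSON-style fragment containing '"Choice": "Model X"' (the exact
# prompt format is not defined in this file):
#
#     evaluator = JudgeEvaluator()
#     preds = ['{"Analysis": "...", "Choice": "Model A"}',
#              '{"Analysis": "...", "Choice": "Model B"}']
#     refs = [{'winner': 'A'}, {'winner': 'A'}]
#     evaluator.score(preds, refs)  # -> {'accuracy': 50.0, 'details': [...]}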


class RMBEvaluator(BaseEvaluator):

    def calculate_pair_accuracy(self, data):
        correct = 0
        total = 0
        for item in data:
            choice = item['choice']
            gold_winner = item['gold_winner']
            if choice and gold_winner:
                total += 1
                if gold_winner == choice:
                    correct += 1
        return correct / total if total > 0 else 0
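
    # E.g. calculate_pair_accuracy([{'choice': 'A', 'gold_winner': 'A'},
    #                               {'choice': 'B', 'gold_winner': 'A'}])
    # -> 0.5; items with a missing choice or winner are skipped.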

    def calculate_bon_accuracy(self, data):
        # Group judgements by best-of-n uid; a group counts as correct only
        # if every judgement inside it matches the gold winner.
        bon_groups = defaultdict(list)
        for item in data:
            bon_uid = item['bon_uid']
            if bon_uid:
                choice = item['choice']
                gold_winner = item['gold_winner']
                if choice and gold_winner:
                    bon_groups[bon_uid].append(gold_winner == choice)
        correct_bons = 0
        for matches in bon_groups.values():
            if all(matches):
                correct_bons += 1
        return correct_bons / len(bon_groups) if bon_groups else 0
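
    # Sketch of the best-of-n semantics with hypothetical uids: both rows
    # below share bon_uid 'q1', so the group scores as correct only if every
    # judgement in it matches the gold winner.
    #
    #     data = [
    #         {'choice': 'A', 'gold_winner': 'A', 'bon_uid': 'q1'},
    #         {'choice': 'B', 'gold_winner': 'A', 'bon_uid': 'q1'},
    #     ]
    #     RMBEvaluator().calculate_bon_accuracy(data)  # -> 0.0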

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths'
            }
        bon_help_list = []
        bon_harm_list = []
        pair_help_list = []
        pair_harm_list = []
        for prediction, reference in zip(predictions, references):
            # Same convention as JudgeEvaluator: keep the letter after the
            # '"Choice": "Model ' marker in the judge response.
            choice = prediction.split('"Choice": "Model ')[-1][:1]
            gold_winner = reference.get('winner', '')
            subset = reference.get('subset', '')
            goal = reference.get('goal', '')
            data_item = {
                'choice': choice,
                'gold_winner': gold_winner,
                'bon_uid': reference.get('bon_uid', ''),
                'pair_uid': reference.get('pair_uid', ''),
            }
            # Bucket each judgement by subset (best-of-n vs. pairwise)
            # and goal.
            if subset == 'bon':
                if goal == 'Helpfulness':
                    bon_help_list.append(data_item)
                elif goal == 'Harmlessness':
                    bon_harm_list.append(data_item)
            elif subset == 'pair':
                if goal == 'Helpfulness':
                    pair_help_list.append(data_item)
                elif goal == 'Harmlessness':
                    pair_harm_list.append(data_item)

        # The helpers already return 0 for empty inputs, so no extra
        # guards are needed here.
        bon_help_acc = self.calculate_bon_accuracy(bon_help_list)
        bon_harm_acc = self.calculate_bon_accuracy(bon_harm_list)
        pair_help_acc = self.calculate_pair_accuracy(pair_help_list)
        pair_harm_acc = self.calculate_pair_accuracy(pair_harm_list)

        return {
            'bon_helpfulness_accuracy': bon_help_acc * 100,
            'bon_harmlessness_accuracy': bon_harm_acc * 100,
            'pair_helpfulness_accuracy': pair_help_acc * 100,
            'pair_harmlessness_accuracy': pair_harm_acc * 100,
            'bon_average': (bon_help_acc + bon_harm_acc) / 2 * 100,
            'pair_average': (pair_help_acc + pair_harm_acc) / 2 * 100,
            'total_accuracy':
            (bon_help_acc + bon_harm_acc + pair_help_acc + pair_harm_acc) / 4
            * 100,
        }
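

if __name__ == '__main__':
    # Quick self-check on synthetic data (an assumed format, not part of the
    # OpenCompass test suite): two pairwise judgements and one best-of-n
    # group, all on the Helpfulness goal. Because of the relative import
    # above, run this in package context (python -m ...), not as a script.
    preds = [
        '{"Choice": "Model A"}',
        '{"Choice": "Model B"}',
        '{"Choice": "Model A"}',
    ]
    refs = [
        {'winner': 'A', 'subset': 'pair', 'goal': 'Helpfulness',
         'pair_uid': 'p1'},
        {'winner': 'A', 'subset': 'pair', 'goal': 'Helpfulness',
         'pair_uid': 'p2'},
        {'winner': 'A', 'subset': 'bon', 'goal': 'Helpfulness',
         'bon_uid': 'q1'},
    ]
    print(RMBEvaluator().score(preds, refs))
    # Expected: pair_helpfulness_accuracy == 50.0 and
    # bon_helpfulness_accuracy == 100.0.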