add rmb datasets

This commit is contained in:
taolinzhang 2025-04-27 08:14:42 +00:00
parent 9d6f3a4866
commit 3b83a5f4a3
3 changed files with 38 additions and 32 deletions

View File

@ -1,2 +1,2 @@
from .rewardbench import RewardBenchDataset # noqa: F401, F403
from .rmb import RMBDataset # noqa: F401, F403
from .rmb import RMBDataset # noqa: F401, F403

View File

@ -14,9 +14,6 @@ from opencompass.utils import get_data_path
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class RMBDataset(BaseDataset):
@ -47,7 +44,7 @@ class RMBDataset(BaseDataset):
else:
question += '\n\n ### Assistant:' + line['content']
question += '\n\n ### Assistant:'
winner = "A"
winner = 'A'
pair_uid = item['pair_uid']
subset = item['subset']
goal = item['goal']
@ -83,7 +80,7 @@ class RMBDataset(BaseDataset):
goal = item['goal']
for loser in item['loser_list']:
conversation_b = loser['answer']
winner = "A"
winner = 'A'
raw_item = {
'question': question,
'answerA': conversation_a,
@ -100,4 +97,3 @@ class RMBDataset(BaseDataset):
}
raw_item_list.append(raw_item)
return raw_item_list

View File

@ -33,7 +33,9 @@ class JudgeEvaluator(BaseEvaluator):
result = {'accuracy': 100 * correct / count, 'details': details}
return result
class RMBEvaluator(BaseEvaluator):
def calculate_pair_accuracy(self, data):
correct = 0
total = 0
@ -57,9 +59,7 @@ class RMBEvaluator(BaseEvaluator):
choice = item['choice']
gold_winner = item['gold_winner']
if choice and gold_winner:
bon_groups[bon_uid].append(
gold_winner == choice
)
bon_groups[bon_uid].append(gold_winner == choice)
# Determine whether each bon_uid is entirely correct
correct_bons = 0
@ -106,20 +106,30 @@ class RMBEvaluator(BaseEvaluator):
pair_harm_list.append(data_item)
# Compute the accuracy for each of the four combinations
bon_help_acc = self.calculate_bon_accuracy(bon_help_list) if bon_help_list else 0
bon_harm_acc = self.calculate_bon_accuracy(bon_harm_list) if bon_harm_list else 0
pair_help_acc = self.calculate_pair_accuracy(pair_help_list) if pair_help_list else 0
pair_harm_acc = self.calculate_pair_accuracy(pair_harm_list) if pair_harm_list else 0
bon_help_acc = self.calculate_bon_accuracy(
bon_help_list) if bon_help_list else 0
bon_harm_acc = self.calculate_bon_accuracy(
bon_harm_list) if bon_harm_list else 0
pair_help_acc = self.calculate_pair_accuracy(
pair_help_list) if pair_help_list else 0
pair_harm_acc = self.calculate_pair_accuracy(
pair_harm_list) if pair_harm_list else 0
# Return all results
result = {
'bon_helpfulness_accuracy': bon_help_acc * 100,
'bon_harmlessness_accuracy': bon_harm_acc * 100,
'pair_helpfulness_accuracy': pair_help_acc * 100,
'pair_harmlessness_accuracy': pair_harm_acc * 100,
'bon_helpfulness_accuracy':
bon_help_acc * 100,
'bon_harmlessness_accuracy':
bon_harm_acc * 100,
'pair_helpfulness_accuracy':
pair_help_acc * 100,
'pair_harmlessness_accuracy':
pair_harm_acc * 100,
'bon_average': ((bon_help_acc + bon_harm_acc) / 2) * 100,
'pair_average': ((pair_help_acc + pair_harm_acc) / 2) * 100,
'total_accuracy': ((bon_help_acc + bon_harm_acc + pair_help_acc + pair_harm_acc) / 4) * 100
'total_accuracy':
((bon_help_acc + bon_harm_acc + pair_help_acc + pair_harm_acc) / 4)
* 100
}
return result