add rmb datasets

This commit is contained in:
taolinzhang 2025-04-27 08:14:42 +00:00
parent 9d6f3a4866
commit 3b83a5f4a3
3 changed files with 38 additions and 32 deletions

View File

@ -1,2 +1,2 @@
from .rewardbench import RewardBenchDataset # noqa: F401, F403
from .rmb import RMBDataset # noqa: F401, F403
from .rmb import RMBDataset # noqa: F401, F403

View File

@ -14,9 +14,6 @@ from opencompass.utils import get_data_path
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class RMBDataset(BaseDataset):
@ -35,7 +32,7 @@ class RMBDataset(BaseDataset):
raise NotImplementedError
dataset = Dataset.from_list(raw_data)
return dataset
def load_pair(self, item):
raw_item_list = []
conversation_a = item['chosen']['answer']
@ -47,7 +44,7 @@ class RMBDataset(BaseDataset):
else:
question += '\n\n ### Assistant:' + line['content']
question += '\n\n ### Assistant:'
winner = "A"
winner = 'A'
pair_uid = item['pair_uid']
subset = item['subset']
goal = item['goal']
@ -67,7 +64,7 @@ class RMBDataset(BaseDataset):
}
raw_item_list.append(raw_item)
return raw_item_list
def loadbon(self, item):
raw_item_list = []
conversation_a = item['bon_best']['answer']
@ -83,7 +80,7 @@ class RMBDataset(BaseDataset):
goal = item['goal']
for loser in item['loser_list']:
conversation_b = loser['answer']
winner = "A"
winner = 'A'
raw_item = {
'question': question,
'answerA': conversation_a,
@ -100,4 +97,3 @@ class RMBDataset(BaseDataset):
}
raw_item_list.append(raw_item)
return raw_item_list

View File

@ -33,7 +33,9 @@ class JudgeEvaluator(BaseEvaluator):
result = {'accuracy': 100 * correct / count, 'details': details}
return result
class RMBEvaluator(BaseEvaluator):
def calculate_pair_accuracy(self, data):
correct = 0
total = 0
@ -44,55 +46,53 @@ class RMBEvaluator(BaseEvaluator):
total += 1
if gold_winner == choice:
correct += 1
return correct / total if total > 0 else 0
def calculate_bon_accuracy(self, data):
bon_groups = defaultdict(list)
"""计算bon指标的准确率"""
for item in data:
bon_uid = item['bon_uid']
if bon_uid:
choice = item['choice']
gold_winner = item['gold_winner']
if choice and gold_winner:
bon_groups[bon_uid].append(
gold_winner == choice
)
bon_groups[bon_uid].append(gold_winner == choice)
# 计算每个bon_uid是否全部正确
correct_bons = 0
for bon_uid, matches in bon_groups.items():
if all(matches):
correct_bons += 1
return correct_bons / len(bon_groups) if bon_groups else 0
def score(self, predictions, references):
if len(predictions) != len(references):
return {'error': 'preds and refrs have different length'}
# 创建四个数据列表分别对应不同的subset和goal组合
bon_help_list = []
bon_harm_list = []
pair_help_list = []
pair_harm_list = []
# 根据subset和goal分类数据
for prediction, reference in zip(predictions, references):
choice = prediction.split("\"Choice\": \"Model ")[-1][0]
gold_winner = reference.get('winner', '')
subset = reference.get('subset', '')
goal = reference.get('goal', '')
data_item = {
'choice': choice,
'gold_winner': gold_winner,
'bon_uid': reference.get('bon_uid', ''),
'pair_uid': reference.get('pair_uid', ''),
}
# 根据subset和goal将数据分配到对应的列表中
if subset == 'bon':
if goal == 'Helpfulness':
@ -104,22 +104,32 @@ class RMBEvaluator(BaseEvaluator):
pair_help_list.append(data_item)
elif goal == 'Harmlessness':
pair_harm_list.append(data_item)
# 计算四种组合的准确率
bon_help_acc = self.calculate_bon_accuracy(bon_help_list) if bon_help_list else 0
bon_harm_acc = self.calculate_bon_accuracy(bon_harm_list) if bon_harm_list else 0
pair_help_acc = self.calculate_pair_accuracy(pair_help_list) if pair_help_list else 0
pair_harm_acc = self.calculate_pair_accuracy(pair_harm_list) if pair_harm_list else 0
bon_help_acc = self.calculate_bon_accuracy(
bon_help_list) if bon_help_list else 0
bon_harm_acc = self.calculate_bon_accuracy(
bon_harm_list) if bon_harm_list else 0
pair_help_acc = self.calculate_pair_accuracy(
pair_help_list) if pair_help_list else 0
pair_harm_acc = self.calculate_pair_accuracy(
pair_harm_list) if pair_harm_list else 0
# 返回所有结果
result = {
'bon_helpfulness_accuracy': bon_help_acc * 100,
'bon_harmlessness_accuracy': bon_harm_acc * 100,
'pair_helpfulness_accuracy': pair_help_acc * 100,
'pair_harmlessness_accuracy': pair_harm_acc * 100,
'bon_helpfulness_accuracy':
bon_help_acc * 100,
'bon_harmlessness_accuracy':
bon_harm_acc * 100,
'pair_helpfulness_accuracy':
pair_help_acc * 100,
'pair_harmlessness_accuracy':
pair_harm_acc * 100,
'bon_average': ((bon_help_acc + bon_harm_acc) / 2) * 100,
'pair_average': ((pair_help_acc + pair_harm_acc) / 2) * 100,
'total_accuracy': ((bon_help_acc + bon_harm_acc + pair_help_acc + pair_harm_acc) / 4) * 100
'total_accuracy':
((bon_help_acc + bon_harm_acc + pair_help_acc + pair_harm_acc) / 4)
* 100
}
return result
return result