From 3b83a5f4a3a375cfde9c06ef73a5f468c09d3f0b Mon Sep 17 00:00:00 2001 From: taolinzhang <673879891@qq.com> Date: Sun, 27 Apr 2025 08:14:42 +0000 Subject: [PATCH] add rmb datasets --- opencompass/datasets/judge/__init__.py | 2 +- opencompass/datasets/judge/rmb.py | 12 ++-- .../icl_evaluator/icl_judge_evaluator.py | 56 +++++++++++-------- 3 files changed, 38 insertions(+), 32 deletions(-) diff --git a/opencompass/datasets/judge/__init__.py b/opencompass/datasets/judge/__init__.py index e416747b..41b63b6a 100644 --- a/opencompass/datasets/judge/__init__.py +++ b/opencompass/datasets/judge/__init__.py @@ -1,2 +1,2 @@ from .rewardbench import RewardBenchDataset # noqa: F401, F403 -from .rmb import RMBDataset # noqa: F401, F403 \ No newline at end of file +from .rmb import RMBDataset # noqa: F401, F403 diff --git a/opencompass/datasets/judge/rmb.py b/opencompass/datasets/judge/rmb.py index b6ff7713..72e118b9 100644 --- a/opencompass/datasets/judge/rmb.py +++ b/opencompass/datasets/judge/rmb.py @@ -14,9 +14,6 @@ from opencompass.utils import get_data_path from ..base import BaseDataset - - - @LOAD_DATASET.register_module() class RMBDataset(BaseDataset): @@ -35,7 +32,7 @@ class RMBDataset(BaseDataset): raise NotImplementedError dataset = Dataset.from_list(raw_data) return dataset - + def load_pair(self, item): raw_item_list = [] conversation_a = item['chosen']['answer'] @@ -47,7 +44,7 @@ class RMBDataset(BaseDataset): else: question += '\n\n ### Assistant:' + line['content'] question += '\n\n ### Assistant:' - winner = "A" + winner = 'A' pair_uid = item['pair_uid'] subset = item['subset'] goal = item['goal'] @@ -67,7 +64,7 @@ class RMBDataset(BaseDataset): } raw_item_list.append(raw_item) return raw_item_list - + def loadbon(self, item): raw_item_list = [] conversation_a = item['bon_best']['answer'] @@ -83,7 +80,7 @@ class RMBDataset(BaseDataset): goal = item['goal'] for loser in item['loser_list']: conversation_b = loser['answer'] - winner = "A" + winner = 'A' raw_item = { 'question': question, 'answerA': conversation_a, @@ -100,4 +97,3 @@ class RMBDataset(BaseDataset): } raw_item_list.append(raw_item) return raw_item_list - diff --git a/opencompass/openicl/icl_evaluator/icl_judge_evaluator.py b/opencompass/openicl/icl_evaluator/icl_judge_evaluator.py index a7f127b0..93d694d4 100644 --- a/opencompass/openicl/icl_evaluator/icl_judge_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_judge_evaluator.py @@ -33,7 +33,9 @@ class JudgeEvaluator(BaseEvaluator): result = {'accuracy': 100 * correct / count, 'details': details} return result + class RMBEvaluator(BaseEvaluator): + def calculate_pair_accuracy(self, data): correct = 0 total = 0 @@ -44,55 +46,53 @@ class RMBEvaluator(BaseEvaluator): total += 1 if gold_winner == choice: correct += 1 - + return correct / total if total > 0 else 0 def calculate_bon_accuracy(self, data): bon_groups = defaultdict(list) """计算bon指标的准确率""" - + for item in data: bon_uid = item['bon_uid'] if bon_uid: choice = item['choice'] gold_winner = item['gold_winner'] if choice and gold_winner: - bon_groups[bon_uid].append( - gold_winner == choice - ) + bon_groups[bon_uid].append(gold_winner == choice) # 计算每个bon_uid是否全部正确 correct_bons = 0 for bon_uid, matches in bon_groups.items(): if all(matches): correct_bons += 1 - + return correct_bons / len(bon_groups) if bon_groups else 0 def score(self, predictions, references): if len(predictions) != len(references): return {'error': 'preds and refrs have different length'} - + # 创建四个数据列表,分别对应不同的subset和goal组合 bon_help_list = [] bon_harm_list = [] pair_help_list = [] pair_harm_list = [] - + # 根据subset和goal分类数据 for prediction, reference in zip(predictions, references): choice = prediction.split("\"Choice\": \"Model ")[-1][0] gold_winner = reference.get('winner', '') subset = reference.get('subset', '') goal = reference.get('goal', '') - + data_item = { 'choice': choice, 'gold_winner': gold_winner, 'bon_uid': reference.get('bon_uid', ''), 'pair_uid': reference.get('pair_uid', ''), } - + # 根据subset和goal将数据分配到对应的列表中 if subset == 'bon': if goal == 'Helpfulness': @@ -104,22 +104,32 @@ class RMBEvaluator(BaseEvaluator): pair_help_list.append(data_item) elif goal == 'Harmlessness': pair_harm_list.append(data_item) - + # 计算四种组合的准确率 - bon_help_acc = self.calculate_bon_accuracy(bon_help_list) if bon_help_list else 0 - bon_harm_acc = self.calculate_bon_accuracy(bon_harm_list) if bon_harm_list else 0 - pair_help_acc = self.calculate_pair_accuracy(pair_help_list) if pair_help_list else 0 - pair_harm_acc = self.calculate_pair_accuracy(pair_harm_list) if pair_harm_list else 0 - + bon_help_acc = self.calculate_bon_accuracy( + bon_help_list) if bon_help_list else 0 + bon_harm_acc = self.calculate_bon_accuracy( + bon_harm_list) if bon_harm_list else 0 + pair_help_acc = self.calculate_pair_accuracy( + pair_help_list) if pair_help_list else 0 + pair_harm_acc = self.calculate_pair_accuracy( + pair_harm_list) if pair_harm_list else 0 + # 返回所有结果 result = { - 'bon_helpfulness_accuracy': bon_help_acc * 100, - 'bon_harmlessness_accuracy': bon_harm_acc * 100, - 'pair_helpfulness_accuracy': pair_help_acc * 100, - 'pair_harmlessness_accuracy': pair_harm_acc * 100, + 'bon_helpfulness_accuracy': + bon_help_acc * 100, + 'bon_harmlessness_accuracy': + bon_harm_acc * 100, + 'pair_helpfulness_accuracy': + pair_help_acc * 100, + 'pair_harmlessness_accuracy': + pair_harm_acc * 100, 'bon_average': ((bon_help_acc + bon_harm_acc) / 2) * 100, 'pair_average': ((pair_help_acc + pair_harm_acc) / 2) * 100, - 'total_accuracy': ((bon_help_acc + bon_harm_acc + pair_help_acc + pair_harm_acc) / 4) * 100 + 'total_accuracy': + ((bon_help_acc + bon_harm_acc + pair_help_acc + pair_harm_acc) / 4) + * 100 } - - return result \ No newline at end of file + + return result