From 1acb3c30c01f8e7741f5135d0cf7a6064a3168ab Mon Sep 17 00:00:00 2001
From: MaiziXiao
Date: Thu, 8 May 2025 07:26:18 +0000
Subject: [PATCH] update

---
 .../nejmaibench_llmjudge_gen_60c8f5.py |   5 +-
 opencompass/datasets/nejmaibench.py    | 117 +-----------------
 opencompass/utils/datasets_info.py     |   9 +-
 3 files changed, 12 insertions(+), 119 deletions(-)

diff --git a/opencompass/configs/datasets/nejm_ai_benchmark/nejmaibench_llmjudge_gen_60c8f5.py b/opencompass/configs/datasets/nejm_ai_benchmark/nejmaibench_llmjudge_gen_60c8f5.py
index 4ac26f66..31be8049 100644
--- a/opencompass/configs/datasets/nejm_ai_benchmark/nejmaibench_llmjudge_gen_60c8f5.py
+++ b/opencompass/configs/datasets/nejm_ai_benchmark/nejmaibench_llmjudge_gen_60c8f5.py
@@ -1,7 +1,8 @@
-from opencompass.datasets import NejmaibenchDataset, nejmaibench_llmjudge_postprocess
+from opencompass.datasets import NejmaibenchDataset
 from opencompass.openicl.icl_inferencer import GenInferencer
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.datasets import generic_llmjudge_postprocess
 from opencompass.evaluator import GenericLLMEvaluator
 import os
 
@@ -88,7 +89,7 @@ eval_cfg = dict(
             reader_cfg=reader_cfg,
         ),
         judge_cfg=dict(),
-        dict_postprocessor=dict(type=nejmaibench_llmjudge_postprocess),
+        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
     ),
 )
 
diff --git a/opencompass/datasets/nejmaibench.py b/opencompass/datasets/nejmaibench.py
index 2c32dfbe..7dd7c418 100644
--- a/opencompass/datasets/nejmaibench.py
+++ b/opencompass/datasets/nejmaibench.py
@@ -5,7 +5,7 @@ from datasets import Dataset
 
 from opencompass.openicl import BaseEvaluator
 from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
-from opencompass.utils import get_data_path, get_logger
+from opencompass.utils import get_data_path
 
 from .base import BaseDataset
 
@@ -136,117 +136,4 @@ def answer_cleansing(
         if prediction[0] and prediction[0].endswith('.'):
             prediction[0] = prediction[0][:-1]
 
-    return prediction[0]
-
-
-def _generic_llmjudge_postprocess(judgement: str):
-    match = re.search(r'(A|B)', judgement)
-    grade_letter = (match.group(0) if match else 'B'
-                    )  # Default to "INCORRECT" if no match
-    return grade_letter
-
-
-def nejmaibench_llmjudge_postprocess(
-    output: dict,
-    output_path: str,
-    dataset: Dataset,
-) -> dict:
-    # Get the original dataset
-    original_dataset = dataset.reader.dataset['test']
-
-    judged_answers = []
-    original_responses = []
-    references = []
-    details = []
-
-    # Initialize statistics dictionaries
-    stats = {'Subject': {}}
-
-    total_correct = 0
-    total_count = 0
-
-    # Process each sample
-    for k, v in output.items():
-        idx = int(k)  # Convert key to integer for indexing
-        original_responses.append(v['prediction'])
-
-        processed_judge = _generic_llmjudge_postprocess(v['prediction'])
-
-        # Get category information from the dataset
-        sample = original_dataset[idx]
-        subject = sample.get('Subject', 'unknown')
-
-        # Initialize category stats if not exists
-        for level, key in [
-            ('Subject', subject),
-        ]:
-            if key not in stats[level]:
-                stats[level][key] = {'correct': 0, 'total': 0}
-
-        # Record the judgment
-        if processed_judge is not None:
-            judged_answers.append(processed_judge)
-            try:
-                gold = v['gold']
-                references.append(gold)
-            except KeyError:
-                get_logger().warning(
-                    f'No gold answer for {k}, use empty string as reference!')
-                gold = ''
-                references.append('')
-
-            # Check if the answer is correct (A means correct)
-            is_correct = processed_judge == 'A'
-            total_count += 1
-
-            if is_correct:
-                total_correct += 1
-                # Update category stats
-                for level, key in [
-                    ('Subject', subject),
-                ]:
-                    stats[level][key]['correct'] += 1
-
-        # Update category totals
-        for level, key in [
-            ('Subject', subject),
-        ]:
-            stats[level][key]['total'] += 1
-        # Add to details
-        details.append({
-            'id': k,
-            'question': sample['question'],
-            'options': sample['options'],
-            'origin_prompt': v['origin_prompt'],
-            'llm_judge': processed_judge,
-            'gold': gold,
-            'is_correct': is_correct,
-            'Subject': subject,
-        })
-
-    # Calculate overall accuracy with two decimal places
-    overall_accuracy = (round(
-        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
-
-    # Initialize results dictionary
-    results = {
-        'accuracy': overall_accuracy,
-        'total_correct': total_correct,
-        'total_count': total_count,
-        'details': details,
-    }
-
-    # Calculate accuracy for each category and flatten into results
-    for level in stats:
-        for key, value in stats[level].items():
-            if value['total'] > 0:
-                # Calculate accuracy with two decimal places
-                accuracy = round((value['correct'] / value['total'] * 100), 2)
-
-                # Create a flattened key for the category
-                flat_key = f'nejmaibench-{key}'
-
-                # Add to results
-                results[flat_key] = accuracy
-
-    return results
+    return prediction[0]
\ No newline at end of file
diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py
index c4994fad..10ca4436 100644
--- a/opencompass/utils/datasets_info.py
+++ b/opencompass/utils/datasets_info.py
@@ -448,8 +448,8 @@ DATASETS_MAPPING = {
     },
     "opencompass/nejmaibench": {
         "ms_id": "",
-        "hf_id": "SeanWu25/NEJM-AI_Benchmarking_Medical_Language_Models",
-        "local": "./opencompass/configs/datasets/nejm_ai_benchmark/data/NEJM_All_Questions_And_Answers.csv",
+        "hf_id": "",
+        "local": "./data/nejmaibench/NEJM_All_Questions_And_Answers.csv",
     },
 }
 
@@ -803,6 +803,11 @@ DATASETS_URL = {
         "url":
         "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/ChemBench4K.zip",
         "md5": "fc23fd21b2566a5dbbebfa4601d7779c"
+    },
+    "nejmaibench": {
+        "url":
+        "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/nejmaibench.zip",
+        "md5": "e6082cae3596b3ebea73e23ba445b99e"
+    }
 }