diff --git a/opencompass/configs/datasets/MedCalc_Bench/MedCalc_Bench_official_gen.py b/opencompass/configs/datasets/MedCalc_Bench/MedCalc_Bench_official_gen.py
new file mode 100644
index 00000000..d863bc48
--- /dev/null
+++ b/opencompass/configs/datasets/MedCalc_Bench/MedCalc_Bench_official_gen.py
@@ -0,0 +1,58 @@
+from opencompass.datasets import MedCalc_BenchDataset, MedCalcOfficial_Evaluator
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+
+ZERO_SHOT_PROMPT = 'You are a helpful assistant for calculating a score for a given patient note. Please think step-by-step to solve the question and then generate the required score. Your output should only contain a JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_procress_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}. \n Here is the patient note:\n{patient_note}\n\nHere is the task:\n{question}\n\nPlease directly output the JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_procress_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}:'
+
+# Reader configuration
+reader_cfg = dict(
+    input_columns=[
+        'row_number',
+        'calculator_id',
+        'calculator_name',
+        'category',
+        'note_id',
+        'output_type',
+        'note_type',
+        'patient_note',
+        'question',
+        'relevant_entities',
+        'ground_truth_answer',
+        'lower_limit',
+        'upper_limit',
+        'ground_truth_explanation'
+    ],
+    output_column='ground_truth_answer',
+)
+
+
+# Inference configuration
+infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(role='HUMAN', prompt=ZERO_SHOT_PROMPT),
+            ])
+    ),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer),
+)
+
+# Evaluation configuration
+eval_cfg = dict(
+    evaluator=dict(type=MedCalcOfficial_Evaluator),
+    pred_role='BOT',
+)
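+
+# Dataset configuration: loads the official MedCalc-Bench v1.0 test split
+# from the Hugging Face Hub and scores predictions with the rule-based
+# MedCalcOfficial_Evaluator defined in opencompass/datasets/MedCalc_Bench.py.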
+medcal_bench_dataset = dict(
+    type=MedCalc_BenchDataset,
+    abbr='medcal_bench_official_zero_shot_eval',
+    path='ncbi/MedCalc-Bench-v1.0',
+    prompt_mode='zero-shot',
+    reader_cfg=reader_cfg,
+    infer_cfg=infer_cfg,
+    eval_cfg=eval_cfg,
+)
+
+medcal_bench_datasets = [medcal_bench_dataset]
diff --git a/opencompass/datasets/MedCalc_Bench.py b/opencompass/datasets/MedCalc_Bench.py
new file mode 100644
index 00000000..34a361b8
--- /dev/null
+++ b/opencompass/datasets/MedCalc_Bench.py
@@ -0,0 +1,396 @@
+import math
+import os
+import re
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+from datasets import Dataset, load_dataset
+
+from opencompass.openicl import BaseEvaluator
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
+from opencompass.utils import get_logger
+
+from .base import BaseDataset
+
+
+# https://github.com/ncbi-nlp/MedCalc-Bench/blob/main/evaluation/evaluate.py
+# https://github.com/ncbi-nlp/MedCalc-Bench/blob/main/evaluation/run.py
+def check_correctness(answer: str, ground_truth, calid, upper_limit,
+                      lower_limit):
+    """Check an answer against the ground truth (official MedCalc rules)."""
+    calid = int(calid)
+
+    if calid in [13, 68]:
+        # Output Type: date
+
+        if datetime.strptime(
+                answer,
+                '%m/%d/%Y').strftime('%-m/%-d/%Y') == datetime.strptime(
+                    ground_truth, '%m/%d/%Y').strftime('%-m/%-d/%Y'):
+            correctness = 1
+        else:
+            correctness = 0
+    elif calid in [69]:
+        # Output Type: integer (A, B)
+        match = re.search(
+            r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
+            ground_truth)
+        ground_truth = f'({match.group(1)}, {match.group(3)})'
+        match = re.search(
+            r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
+            answer)
+        if match:
+            weeks = match.group(1)
+            days = match.group(3)
+            answer = f'({weeks}, {days})'
+            if eval(answer) == eval(ground_truth):
+                correctness = 1
+            else:
+                correctness = 0
+        else:
+            correctness = 0
+    elif calid in [
+            4, 15, 16, 17, 18, 20, 21, 25, 27, 28, 29, 32, 33, 36, 43, 45, 48,
+            51, 69
+    ]:
+        # Output Type: integer A
+        answer = round(eval(answer))
+        if answer == eval(ground_truth):
+            correctness = 1
+        else:
+            correctness = 0
+    elif calid in [
+            2, 3, 5, 6, 7, 8, 9, 10, 11, 19, 22, 23, 24, 26, 30, 31, 38, 39,
+            40, 44, 46, 49, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67
+    ]:
+        # Output Type: decimal
+        answer = eval(answer)
+        if answer >= eval(lower_limit) and answer <= eval(upper_limit):
+            correctness = 1
+        else:
+            correctness = 0
+    else:
+        raise ValueError(f'Unknown calculator ID: {calid}')
+    return correctness
+
+
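+# Parse the model's JSON-style output: grab the "answer" field and the
+# step_by_step_thinking explanation, then normalise the answer into the
+# format expected for the given calculator (date, (weeks, days), int, float).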
+def extract_answer(answer, calid):
+
+    calid = int(calid)
+    extracted_answer = re.findall(r'[Aa]nswer":\s*(.*?)\}', answer)
+    matches = re.findall(
+        r'"step_by_step_thinking":\s*"([^"]+)"\s*,\s*"[Aa]nswer"', answer)
+
+    if matches:
+        # Select the last match
+        last_match = matches[-1]
+        explanation = last_match
+    else:
+        explanation = 'No Explanation'
+
+    if len(extracted_answer) == 0:
+        extracted_answer = 'Not Found'
+    else:
+        extracted_answer = extracted_answer[-1].strip().strip('"')
+        if extracted_answer == 'str(short_and_direct_answer_of_the_question)' or extracted_answer == 'str(value which is the answer to the question)' or extracted_answer == 'X.XX':
+            extracted_answer = 'Not Found'
+
+    if calid in [13, 68]:
+        # Output Type: date
+        match = re.search(
+            r'^(0?[1-9]|1[0-2])\/(0?[1-9]|[12][0-9]|3[01])\/(\d{4})',
+            extracted_answer)
+        if match:
+            month = int(match.group(1))
+            day = int(match.group(2))
+            year = match.group(3)
+            answer = f'{month:02}/{day:02}/{year}'
+        else:
+            answer = 'N/A'
+
+    elif calid in [69]:
+        # Output Type: integer (A, B)
+        match = re.search(
+            r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
+            extracted_answer)
+        ground_truth = f'({match.group(1)}, {match.group(3)})'
+        extracted_answer = extracted_answer.replace('[', '(').replace(
+            ']', ')').replace("'", '').replace('"', '')
+        match = re.search(
+            r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
+            extracted_answer)
+        if match:
+            weeks = match.group(1)
+            days = match.group(3)
+            answer = f'({weeks}, {days})'
+        else:
+            answer = 'N/A'
+    elif calid in [
+            4, 15, 16, 17, 18, 20, 21, 25, 27, 28, 29, 32, 33, 36, 43, 45, 48,
+            51, 69
+    ]:
+        # Output Type: integer A
+        match = re.search(r'(\d+) out of', extracted_answer)
+        if match:  # cases like "3 out of 5"
+            answer = match.group(1)
+        else:
+            match = re.search(r'-?\d+(, ?-?\d+)+', extracted_answer)
+            if match:  # cases like "3, 4, 5"
+                answer = str(len(match.group(0).split(',')))
+            else:
+                # match = re.findall(r"(?<!-)\d+", extracted_answer)
+                match = re.findall(r'(-?\d+(\.\d+)?)', extracted_answer)
+                if len(match) > 0:  # find the last integer
+                    answer = match[-1][0]
+                    # answer = match[-1].lstrip("0")
+                else:
+                    answer = 'N/A'
+    elif calid in [
+            2, 3, 5, 6, 7, 8, 9, 10, 11, 19, 22, 23, 24, 26, 30, 31, 38, 39,
+            40, 44, 46, 49, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67
+    ]:
+        # Output Type: decimal
+        match = re.search(r'str\((.*)\)', extracted_answer)
+        if match:  # cases like "str(round((140 * (3.15 - 136) / 1400) * 72.36)"
+            expression = match.group(1).replace('^', '**').replace(
+                'is odd', '% 2 == 1').replace('is even', '% 2 == 0').replace(
+                    'sqrt', 'math.sqrt').replace('.math', '').replace(
+                        'weight', '').replace('height', '').replace(
+                            'mg/dl', '').replace('g/dl', '').replace(
+                                'mmol/L', '').replace('kg', '').replace(
+                                    'g', '').replace('mEq/L', '')
+            expression = expression.split(
+                '#'
+            )[0]  # cases like round(45.5 * 166 - 45.3 + 0.4 * (75 - (45.5 * 166 - 45.3))))) # Calculation: ...
+            if expression.count('(') > expression.count(')'):  # add missing ')
+                expression += ')' * (expression.count('(') -
+                                     expression.count(')'))
+            elif expression.count(')') > expression.count(
+                    '('):  # add missing (
+                expression = '(' * (expression.count(')') -
+                                    expression.count('(')) + expression
+            try:
+                answer = eval(expression, {'__builtins__': None}, {
+                    'min': min,
+                    'pow': pow,
+                    'round': round,
+                    'abs': abs,
+                    'int': int,
+                    'float': float,
+                    'math': math,
+                    'np': np,
+                    'numpy': np
+                })
+            except Exception:
+                print(f'Error in evaluating expression: {expression}')
+                answer = 'N/A'
+        else:
+            match = re.search(r'(-?\d+(\.\d+)?)\s*mL/min/1.73',
+                              extracted_answer)
+            if match:  # cases like "8.1 mL/min/1.73 m\u00b2"
+                answer = eval(match.group(1))
+            else:
+                match = re.findall(r'(-?\d+(\.\d+)?)\%', extracted_answer)
+                if len(match) > 0:  # cases like "53.1%"
+                    answer = eval(match[-1][0]) / 100
+                else:
+                    match = re.findall(r'(-?\d+(\.\d+)?)', extracted_answer)
+                    if len(
+                            match
+                    ) > 0:  # cases like "8.1 mL/min/1.73 m\u00b2" or "11.1"
+                        answer = eval(match[-1][0])
+                    else:
+                        answer = 'N/A'
+    if answer != 'N/A':
+        answer = str(answer)
+
+    return answer, explanation
+
+
+def _parse(item, prompt_mode):
+    item['row_number'] = item['Row Number']
+    item['calculator_id'] = item['Calculator ID']
+    item['calculator_name'] = item['Calculator Name']
+    item['category'] = item['Category']
+    item['output_type'] = item['Output Type']
+    item['note_id'] = item['Note ID']
+    item['note_type'] = item['Note Type']
+    item['patient_note'] = item['Patient Note']
+    item['question'] = item['Question']
+    item['relevant_entities'] = item['Relevant Entities']
+    item['ground_truth_answer'] = item['Ground Truth Answer']
+    item['lower_limit'] = item['Lower Limit']
+    item['upper_limit'] = item['Upper Limit']
+    item['ground_truth_explanation'] = item['Ground Truth Explanation']
+    return item
+
+
+@LOAD_DATASET.register_module()
+class MedCalc_BenchDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str, prompt_mode: str, **kwargs):
+        data_files = {
+            'test': 'data/test-00000-of-00001.parquet',
+            'train': 'data/train-00000-of-00001.parquet'
+        }
+        dataset = load_dataset(path, data_files=data_files, split='test')
+        # dataset = dataset.select(range(2))
+        if prompt_mode == 'zero-shot':
+            dataset = dataset.map(lambda item: _parse(item, prompt_mode),
+                                  load_from_cache_file=False)
+        elif prompt_mode == 'few-shot':
+            pass  # TODO: Implement few-shot prompt
+        return dataset
+
+
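+# Rule-based evaluator that reproduces the official MedCalc-Bench scoring:
+# extract each model answer, check it against the ground truth (respecting
+# the per-calculator lower/upper limits) and report accuracy as a percentage.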
+class MedCalcOfficial_Evaluator(BaseEvaluator):
+
+    def score(self, predictions, references, test_set):
+
+        if len(predictions) != len(references):
+            return {'error': 'preds and refs have different length'}
+
+        correct = 0
+        count = 0
+        details = []
+        for idx, (i, j) in enumerate(zip(predictions, references)):
+            calculator_id = test_set['calculator_id'][idx]
+            lower_limit = test_set['lower_limit'][idx]
+            upper_limit = test_set['upper_limit'][idx]
+            row_number = test_set['row_number'][idx]
+            note_id = test_set['note_id'][idx]
+            category = test_set['category'][idx]
+            question = test_set['question'][idx]
+            calculator_name = test_set['calculator_name'][idx]
+            patient_note = test_set['patient_note'][idx]
+            ground_truth_explanation = test_set['ground_truth_explanation'][
+                idx]
+            ground_truth_answer = test_set['ground_truth_answer'][idx]
+            try:
+                answer_value, explanation = extract_answer(
+                    i, int(calculator_id))
+
+                print(answer_value)
+                print(explanation)
+
+                correctness = check_correctness(answer_value,
+                                                ground_truth_answer,
+                                                calculator_id, upper_limit,
+                                                lower_limit)
+
+                status = 'Correct' if correctness else 'Incorrect'
+
+                outputs = {
+                    'Row Number': int(row_number),
+                    'Calculator Name': calculator_name,
+                    'Calculator ID': calculator_id,
+                    'Category': category,
+                    'Note ID': note_id,
+                    'Patient Note': patient_note,
+                    'Question': question,
+                    'LLM Answer': answer_value,
+                    'LLM Explanation': explanation,
+                    'Ground Truth Answer': ground_truth_answer,
+                    'Ground Truth Explanation': ground_truth_explanation,
+                    'Result': status
+                }
+
+            except Exception as e:
+                outputs = {
+                    'Row Number': int(row_number),
+                    'Calculator Name': calculator_name,
+                    'Calculator ID': calculator_id,
+                    'Category': category,
+                    'Note ID': note_id,
+                    'Patient Note': patient_note,
+                    'Question': question,
+                    'LLM Answer': str(e),
+                    'LLM Explanation': str(e),
+                    'Ground Truth Answer': ground_truth_answer,
+                    'Ground Truth Explanation': ground_truth_explanation,
+                    'Result': 'Incorrect'
+                }
+                status = 'Incorrect'
+            count += 1
+            if status == 'Correct':
+                correct += 1
+            details.append(outputs)
+
+        result = {'accuracy': 100 * correct / count, 'details': details}
+        return result
+
+
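+# NOTE: the commented-out helpers below sketch an LLM-as-judge postprocessing
+# path (mapping 'A'/'B' verdicts to accuracy); they are not used by the
+# rule-based evaluator above.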
+# def _generic_llmjudge_postprocess(judgement: str):
+#     match = re.search(r'(A|B)', judgement)
+#     grade_letter = (match.group(0) if match else 'B'
+#                     )  # Default to "INCORRECT" if no match
+#     return grade_letter
+
+# def MedCalc_Bench_llmjudge_postprocess(
+#     output: dict,
+#     output_path: str,
+#     dataset: Dataset,
+# ) -> dict:
+#     # Get the original dataset
+#     original_dataset = dataset.reader.dataset['test']
+
+#     judged_answers = []
+#     original_responses = []
+#     references = []
+#     details = []
+
+#     total_correct = 0
+#     total_count = 0
+
+#     for k, v in output.items():
+#         idx = int(k)  # Convert key to integer for indexing
+#         original_responses.append(v['prediction'])
+#         processed_judge = _generic_llmjudge_postprocess(v['prediction'])
+
+#         sample = original_dataset[idx]
+#         # Record the judgment
+#         if processed_judge is not None:
+#             judged_answers.append(processed_judge)
+#             try:
+#                 gold = v['gold']
+#                 references.append(gold)
+#             except KeyError:
+#                 get_logger().warning(
+#                     f'No gold answer for {k}, use empty string as reference!')
+#                 gold = ''
+#                 references.append('')
+
+#         # Check if the answer is correct (A means correct)
+#         is_correct = processed_judge == 'A'
+#         total_count += 1
+
+#         if is_correct:
+#             total_correct += 1
+
+#         # Add to details
+#         details.append({
+#             'id': k,
+#             'question': sample['question'],
+#             'prediction': sample['prediction'],
+#             'origin_prompt': v['origin_prompt'],
+#             'llm_judge': processed_judge,
+#             'gold': gold,
+#             'is_correct': is_correct,
+#         })
+
+#     # Calculate overall accuracy with two decimal places
+#     overall_accuracy = (round(
+#         (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
+
+#     # Initialize results dictionary
+#     results = {
+#         'accuracy': overall_accuracy,
+#         'total_correct': total_correct,
+#         'total_count': total_count,
+#         'details': details,
+#     }
+#     return results
diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py
index b00162d1..ccf0ddd4 100644
--- a/opencompass/datasets/__init__.py
+++ b/opencompass/datasets/__init__.py
@@ -95,6 +95,7 @@ from .math_intern import *  # noqa: F401, F403
 from .mathbench import *  # noqa: F401, F403
 from .mbpp import *  # noqa: F401, F403
 from .medbench import *  # noqa: F401, F403
+from .MedCalc_Bench import *  # noqa: F401, F403
 from .MedXpertQA import *  # noqa: F401, F403
 from .mgsm import *  # noqa: F401, F403
 from .mmlu import *  # noqa: F401, F403
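
Usage note: a run config would typically pull the dataset list in with read_base() and pair it with any model config. A minimal sketch (the models list is a placeholder to be filled from an imported model config):

    from mmengine.config import read_base

    with read_base():
        from opencompass.configs.datasets.MedCalc_Bench.MedCalc_Bench_official_gen import \
            medcal_bench_datasets

    datasets = medcal_bench_datasets
    models = [...]  # e.g. models imported from an existing HuggingFace model config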