MedCalc_Bench

huihui 2025-05-02 16:33:25 +00:00
parent 8c74e6a39e
commit c6e1955cae
3 changed files with 455 additions and 0 deletions


@@ -0,0 +1,58 @@
from opencompass.datasets import MedCalc_BenchDataset, MedCalcOfficial_Evaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
ZERO_SHOT_PROMPT = 'You are a helpful assistant for calculating a score for a given patient note. Please think step-by-step to solve the question and then generate the required score. Your output should only contain a JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_process_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}. \n Here is the patient note:\n{patient_note}\n\nHere is the task:\n{question}\n\nPlease directly output the JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_process_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}:'
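# Note: PromptTemplate substitutes only the reader input columns (here
# {patient_note} and {question}); the literal JSON braces in the
# instruction text are left untouched during substitution.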
# Reader configuration
reader_cfg = dict(
input_columns=[
'row_number',
'calculator_id',
'calculator_name',
'category',
'note_id',
'output_type',
'note_type',
'patient_note',
'question',
'relevant_entities',
'ground_truth_answer',
'lower_limit',
'upper_limit',
'ground_truth_explanation'
],
output_column='ground_truth_answer',
)
# Inference configuration
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
            dict(role='HUMAN', prompt=ZERO_SHOT_PROMPT),
])
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
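# ZeroRetriever supplies no in-context examples (zero-shot), and
# GenInferencer runs free-form generation, which matches the JSON-dict
# output requested in ZERO_SHOT_PROMPT.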
# Evaluation configuration
eval_cfg = dict(
evaluator=dict(type=MedCalcOfficial_Evaluator),
pred_role='BOT',
)
medcalc_bench_dataset = dict(
    type=MedCalc_BenchDataset,
    abbr='medcalc_bench_official_zero_shot_eval',
    path='ncbi/MedCalc-Bench-v1.0',
    prompt_mode='zero-shot',
    reader_cfg=reader_cfg,
    infer_cfg=infer_cfg,
    eval_cfg=eval_cfg,
)
medcalc_bench_datasets = [medcalc_bench_dataset]
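A minimal sketch of how a top-level run config might consume this file (the import path below is illustrative, not part of this commit):

from mmengine.config import read_base

with read_base():
    # hypothetical path; adjust to where this config file actually lives
    from opencompass.configs.datasets.MedCalc_Bench.medcalc_bench_gen import \
        medcalc_bench_datasets

datasets = medcalc_bench_datasets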


@@ -0,0 +1,396 @@
import math
import re
from datetime import datetime

import numpy as np
from datasets import Dataset, load_dataset

from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_logger

from .base import BaseDataset
# https://github.com/ncbi-nlp/MedCalc-Bench/blob/main/evaluation/evaluate.py
# https://github.com/ncbi-nlp/MedCalc-Bench/blob/main/evaluation/run.py
def check_correctness(answer: str, ground_truth, calid, upper_limit,
                      lower_limit):
    """Score one parsed answer against the ground truth, following the
    official MedCalc-Bench rules: exact match for dates and (weeks, days)
    tuples, rounded equality for integer outputs, and a
    [lower_limit, upper_limit] range check for decimal outputs."""
calid = int(calid)
if calid in [13, 68]:
# Output Type: date
if datetime.strptime(
answer,
'%m/%d/%Y').strftime('%-m/%-d/%Y') == datetime.strptime(
ground_truth, '%m/%d/%Y').strftime('%-m/%-d/%Y'):
correctness = 1
else:
correctness = 0
elif calid in [69]:
# Output Type: integer (A, B)
match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
ground_truth)
ground_truth = f'({match.group(1)}, {match.group(3)})'
match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
answer)
if match:
weeks = match.group(1)
days = match.group(3)
answer = f'({weeks}, {days})'
if eval(answer) == eval(ground_truth):
correctness = 1
else:
correctness = 0
else:
correctness = 0
    elif calid in [
            4, 15, 16, 17, 18, 20, 21, 25, 27, 28, 29, 32, 33, 36, 43, 45, 48,
            51, 69
    ]:
        # Output Type: integer A (69 never reaches this branch; the elif
        # above already handles it, and it is kept here only to mirror the
        # official list)
answer = round(eval(answer))
if answer == eval(ground_truth):
correctness = 1
else:
correctness = 0
elif calid in [
2, 3, 5, 6, 7, 8, 9, 10, 11, 19, 22, 23, 24, 26, 30, 31, 38, 39,
40, 44, 46, 49, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67
]:
# Output Type: decimal
answer = eval(answer)
if answer >= eval(lower_limit) and answer <= eval(upper_limit):
correctness = 1
else:
correctness = 0
else:
raise ValueError(f'Unknown calculator ID: {calid}')
return correctness
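# Illustrative examples (not from the official script), following the rules
# above; argument order is (answer, ground_truth, calid, upper, lower):
#   check_correctness('02/01/2024', '2/1/2024', 13, None, None)  -> 1
#   check_correctness('7.2', '7.0', 2, '7.4', '6.8')             -> 1 (in range)
#   check_correctness('12', '11', 4, None, None)                 -> 0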
def extract_answer(answer, calid):
calid = int(calid)
extracted_answer = re.findall(r'[Aa]nswer":\s*(.*?)\}', answer)
matches = re.findall(
r'"step_by_step_thinking":\s*"([^"]+)"\s*,\s*"[Aa]nswer"', answer)
if matches:
# Select the last match
last_match = matches[-1]
explanation = last_match
else:
explanation = 'No Explanation'
    if len(extracted_answer) == 0:
        extracted_answer = 'Not Found'
    else:
        extracted_answer = extracted_answer[-1].strip().strip('"')
        if extracted_answer in (
                'str(short_and_direct_answer_of_the_question)',
                'str(value which is the answer to the question)', 'X.XX'):
            extracted_answer = 'Not Found'
if calid in [13, 68]:
# Output Type: date
match = re.search(
r'^(0?[1-9]|1[0-2])\/(0?[1-9]|[12][0-9]|3[01])\/(\d{4})',
extracted_answer)
if match:
month = int(match.group(1))
day = int(match.group(2))
year = match.group(3)
answer = f'{month:02}/{day:02}/{year}'
else:
answer = 'N/A'
    elif calid in [69]:
        # Output Type: integer (A, B)
        extracted_answer = extracted_answer.replace('[', '(').replace(
            ']', ')').replace("'", '').replace('"', '')
        match = re.search(
            r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
            extracted_answer)
if match:
weeks = match.group(1)
days = match.group(3)
answer = f'({weeks}, {days})'
else:
answer = 'N/A'
elif calid in [
4, 15, 16, 17, 18, 20, 21, 25, 27, 28, 29, 32, 33, 36, 43, 45, 48,
51, 69
]:
# Output Type: integer A
match = re.search(r'(\d+) out of', extracted_answer)
if match: # cases like "3 out of 5"
answer = match.group(1)
else:
match = re.search(r'-?\d+(, ?-?\d+)+', extracted_answer)
if match: # cases like "3, 4, 5"
answer = str(len(match.group(0).split(',')))
            else:
                # fall back to the last number (integer or decimal) in the
                # string
                match = re.findall(r'(-?\d+(\.\d+)?)', extracted_answer)
                if len(match) > 0:
                    answer = match[-1][0]
                else:
                    answer = 'N/A'
elif calid in [
2, 3, 5, 6, 7, 8, 9, 10, 11, 19, 22, 23, 24, 26, 30, 31, 38, 39,
40, 44, 46, 49, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67
]:
# Output Type: decimal
match = re.search(r'str\((.*)\)', extracted_answer)
if match: # cases like "str(round((140 * (3.15 - 136) / 1400) * 72.36)"
expression = match.group(1).replace('^', '**').replace(
'is odd', '% 2 == 1').replace('is even', '% 2 == 0').replace(
'sqrt', 'math.sqrt').replace('.math', '').replace(
'weight',
'').replace('height', '').replace('mg/dl', '').replace(
'g/dl', '').replace('mmol/L', '').replace(
'kg', '').replace('g',
'').replace('mEq/L', '')
expression = expression.split(
'#'
)[0] # cases like round(45.5 * 166 - 45.3 + 0.4 * (75 - (45.5 * 166 - 45.3))))) # Calculation: ...
if expression.count('(') > expression.count(')'): # add missing ')
expression += ')' * (expression.count('(') -
expression.count(')'))
elif expression.count(')') > expression.count(
'('): # add missing (
expression = '(' * (expression.count(')') -
expression.count('(')) + expression
try:
answer = eval(expression, {'__builtins__': None}, {
'min': min,
'pow': pow,
'round': round,
'abs': abs,
'int': int,
'float': float,
'math': math,
'np': np,
'numpy': np
})
            except Exception:
                print(f'Error in evaluating expression: {expression}')
                answer = 'N/A'
else:
            match = re.search(r'(-?\d+(\.\d+)?)\s*mL/min/1\.73',
                              extracted_answer)
if match: # cases like "8.1 mL/min/1.73 m\u00b2"
answer = eval(match.group(1))
else:
match = re.findall(r'(-?\d+(\.\d+)?)\%', extracted_answer)
if len(match) > 0: # cases like "53.1%"
answer = eval(match[-1][0]) / 100
else:
match = re.findall(r'(-?\d+(\.\d+)?)', extracted_answer)
if len(
match
) > 0: # cases like "8.1 mL/min/1.73 m\u00b2" or "11.1"
answer = eval(match[-1][0])
else:
answer = 'N/A'
if answer != 'N/A':
answer = str(answer)
return answer, explanation
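# Illustrative parse (not from the official script): for a response like
# '{"step_by_step_thinking": "CrCl = ...", "answer": "53.1%"}' with a
# decimal-output calculator id, extract_answer returns
# ('0.531', 'CrCl = ...').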
def _parse(item):
item['row_number'] = item['Row Number']
item['calculator_id'] = item['Calculator ID']
item['calculator_name'] = item['Calculator Name']
item['category'] = item['Category']
item['output_type'] = item['Output Type']
item['note_id'] = item['Note ID']
item['note_type'] = item['Note Type']
item['patient_note'] = item['Patient Note']
item['question'] = item['Question']
item['relevant_entities'] = item['Relevant Entities']
item['ground_truth_answer'] = item['Ground Truth Answer']
item['lower_limit'] = item['Lower Limit']
item['upper_limit'] = item['Upper Limit']
item['ground_truth_explanation'] = item['Ground Truth Explanation']
return item
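# After _parse, each record carries snake_case copies of the original
# columns (e.g. item['patient_note']), matching the input_columns listed in
# the dataset config's reader_cfg.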
@LOAD_DATASET.register_module()
class MedCalc_BenchDataset(BaseDataset):
@staticmethod
def load(path: str, prompt_mode: str, **kwargs):
data_files = {
'test': 'data/test-00000-of-00001.parquet',
'train': 'data/train-00000-of-00001.parquet'
}
        dataset = load_dataset(path, data_files=data_files, split='test')
        if prompt_mode == 'zero-shot':
            dataset = dataset.map(_parse, load_from_cache_file=False)
        elif prompt_mode == 'few-shot':
            raise NotImplementedError(
                'few-shot prompting for MedCalc-Bench is not implemented yet')
        return dataset
class MedCalcOfficial_Evaluator(BaseEvaluator):

    def score(self, predictions, references, test_set):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths'
            }
        correct = 0
        count = 0
        details = []
        for idx, prediction in enumerate(predictions):
calculator_id = test_set['calculator_id'][idx]
lower_limit = test_set['lower_limit'][idx]
upper_limit = test_set['upper_limit'][idx]
row_number = test_set['row_number'][idx]
note_id = test_set['note_id'][idx]
category = test_set['category'][idx]
question = test_set['question'][idx]
calculator_name = test_set['calculator_name'][idx]
patient_note = test_set['patient_note'][idx]
ground_truth_explanation = test_set['ground_truth_explanation'][
idx]
ground_truth_answer = test_set['ground_truth_answer'][idx]
            try:
                answer_value, explanation = extract_answer(
                    prediction, int(calculator_id))
correctness = check_correctness(answer_value,
ground_truth_answer,
calculator_id, upper_limit,
lower_limit)
status = 'Correct' if correctness else 'Incorrect'
outputs = {
'Row Number': int(row_number),
'Calculator Name': calculator_name,
'Calculator ID': calculator_id,
'Category': category,
'Note ID': note_id,
'Patient Note': patient_note,
'Question': question,
'LLM Answer': answer_value,
'LLM Explanation': explanation,
'Ground Truth Answer': ground_truth_answer,
'Ground Truth Explanation': ground_truth_explanation,
'Result': status
}
except Exception as e:
outputs = {
'Row Number': int(row_number),
'Calculator Name': calculator_name,
'Calculator ID': calculator_id,
'Category': category,
'Note ID': note_id,
'Patient Note': patient_note,
'Question': question,
'LLM Answer': str(e),
'LLM Explanation': str(e),
'Ground Truth Answer': ground_truth_answer,
'Ground Truth Explanation': ground_truth_explanation,
'Result': 'Incorrect'
}
status = 'Incorrect'
count += 1
if status == 'Correct':
correct += 1
details.append(outputs)
        result = {
            'accuracy': 100 * correct / count if count else 0.0,
            'details': details,
        }
        return result
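# MedCalcOfficial_Evaluator.score returns the usual OpenCompass evaluator
# shape, e.g. {'accuracy': 87.5, 'details': [{'Row Number': 1,
# 'Result': 'Correct', ...}, ...]}.
# The commented-out helpers below sketch a possible LLM-as-judge
# post-processing mode that this commit does not wire up.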
# def _generic_llmjudge_postprocess(judgement: str):
# match = re.search(r'(A|B)', judgement)
# grade_letter = (match.group(0) if match else 'B'
# ) # Default to "INCORRECT" if no match
# return grade_letter
# def MedCalc_Bench_llmjudge_postprocess(
# output: dict,
# output_path: str,
# dataset: Dataset,
# ) -> dict:
# # Get the original dataset
# original_dataset = dataset.reader.dataset['test']
# judged_answers = []
# original_responses = []
# references = []
# details = []
# total_correct = 0
# total_count = 0
# for k, v in output.items():
# idx = int(k) # Convert key to integer for indexing
# original_responses.append(v['prediction'])
# processed_judge = _generic_llmjudge_postprocess(v['prediction'])
# sample = original_dataset[idx]
# # Record the judgment
# if processed_judge is not None:
# judged_answers.append(processed_judge)
# try:
# gold = v['gold']
# references.append(gold)
# except KeyError:
# get_logger().warning(
# f'No gold answer for {k}, use empty string as reference!')
# gold = ''
# references.append('')
# # Check if the answer is correct (A means correct)
# is_correct = processed_judge == 'A'
# total_count += 1
# if is_correct:
# total_correct += 1
# # Add to details
# details.append({
# 'id': k,
# 'question': sample['question'],
# 'prediction': sample['prediction'],
# 'origin_prompt': v['origin_prompt'],
# 'llm_judge': processed_judge,
# 'gold': gold,
# 'is_correct': is_correct,
# })
# # Calculate overall accuracy with two decimal places
# overall_accuracy = (round(
# (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
# # Initialize results dictionary
# results = {
# 'accuracy': overall_accuracy,
# 'total_correct': total_correct,
# 'total_count': total_count,
# 'details': details,
# }
# return results


@@ -95,6 +95,7 @@ from .math_intern import * # noqa: F401, F403
from .mathbench import * # noqa: F401, F403
from .mbpp import * # noqa: F401, F403
from .medbench import * # noqa: F401, F403
from .MedCalc_Bench import * # noqa: F401, F403
from .MedXpertQA import * # noqa: F401, F403
from .mgsm import * # noqa: F401, F403
from .mmlu import * # noqa: F401, F403