This commit is contained in:
huihui 2025-05-09 08:00:44 +00:00
parent 47fd267d4d
commit 075f9c53d4

View File

@ -9,11 +9,6 @@ from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET from opencompass.registry import LOAD_DATASET
from .base import BaseDataset from .base import BaseDataset
"""
The original evaluation code is from
https://github.com/ncbi-nlp/MedCalc-Bench/blob/main/evaluation/evaluate.py
https://github.com/ncbi-nlp/MedCalc-Bench/blob/main/evaluation/run.py
"""
def check_correctness(answer: str, ground_truth, calid, upper_limit, def check_correctness(answer: str, ground_truth, calid, upper_limit,
@ -34,12 +29,12 @@ def check_correctness(answer: str, ground_truth, calid, upper_limit,
elif calid in [69]: elif calid in [69]:
# Output Type: integer (A, B) # Output Type: integer (A, B)
match = re.search( match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?"
ground_truth) r"\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", ground_truth)
ground_truth = f'({match.group(1)}, {match.group(3)})' ground_truth = f'({match.group(1)}, {match.group(3)})'
match = re.search( match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?"
answer) r"\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", answer)
if match: if match:
weeks = match.group(1) weeks = match.group(1)
days = match.group(3) days = match.group(3)
@ -80,7 +75,8 @@ def extract_answer(answer, calid):
calid = int(calid) calid = int(calid)
extracted_answer = re.findall(r'[Aa]nswer":\s*(.*?)\}', answer) extracted_answer = re.findall(r'[Aa]nswer":\s*(.*?)\}', answer)
matches = re.findall( matches = re.findall(
r'"step_by_step_thinking":\s*"([^"]+)"\s*,\s*"[Aa]nswer"', answer) r'"step_by_step_thinking":\s*"'
r'([^"]+)"\s*,\s*"[Aa]nswer"', answer)
if matches: if matches:
# Select the last match # Select the last match
@ -105,8 +101,8 @@ def extract_answer(answer, calid):
if calid in [13, 68]: if calid in [13, 68]:
# Output Type: date # Output Type: date
match = re.search( match = re.search(
r'^(0?[1-9]|1[0-2])\/(0?[1-9]|[12][0-9]|3[01])\/(\d{4})', r'^(0?[1-9]|1[0-2])\/(0?[1-9]'
extracted_answer) r'|[12][0-9]|3[01])\/(\d{4})', extracted_answer)
if match: if match:
month = int(match.group(1)) month = int(match.group(1))
day = int(match.group(2)) day = int(match.group(2))
@ -118,13 +114,13 @@ def extract_answer(answer, calid):
elif calid in [69]: elif calid in [69]:
# Output Type: integer (A, B) # Output Type: integer (A, B)
match = re.search( match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,\?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,"
extracted_answer) r"\?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", extracted_answer)
extracted_answer = extracted_answer.replace('[', '(').replace( extracted_answer = extracted_answer.replace('[', '(').replace(
']', ')').replace("'", '').replace('"', '') ']', ')').replace("'", '').replace('"', '')
match = re.search( match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,"
extracted_answer) r"?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", extracted_answer)
if match: if match:
weeks = match.group(1) weeks = match.group(1)
days = match.group(3) days = match.group(3)