This commit is contained in:
huihui 2025-05-08 11:12:17 +00:00
parent b9aa1c17f7
commit 724472ee5d
2 changed files with 13 additions and 15 deletions

View File

@ -4,7 +4,6 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever from opencompass.openicl.icl_retriever import ZeroRetriever
ZERO_SHOT_PROMPT = 'You are a helpful assistant for calculating a score for a given patient note. Please think step-by-step to solve the question and then generate the required score. Your output should only contain a JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_procress_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}. \n Here is the patient note:\n{patient_note}\n\nHere is the task:\n{question}\n\nPlease directly output the JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_procress_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}:' ZERO_SHOT_PROMPT = 'You are a helpful assistant for calculating a score for a given patient note. Please think step-by-step to solve the question and then generate the required score. Your output should only contain a JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_procress_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}. \n Here is the patient note:\n{patient_note}\n\nHere is the task:\n{question}\n\nPlease directly output the JSON dict formatted as {"step_by_step_thinking": str(your_step_by_step_thinking_procress_to_solve_the_question), "answer": str(short_and_direct_answer_of_the_question)}:'
# Reader configuration # Reader configuration
reader_cfg = dict( reader_cfg = dict(
input_columns=[ input_columns=[

View File

@ -1,16 +1,12 @@
import argparse
import math import math
import os
import re import re
from datetime import datetime from datetime import datetime
import numpy as np import numpy as np
import pandas as pd from datasets import load_dataset
from datasets import Dataset, load_dataset
from opencompass.openicl import BaseEvaluator from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_logger
from .base import BaseDataset from .base import BaseDataset
""" """
@ -97,7 +93,13 @@ def extract_answer(answer, calid):
extracted_answer = 'Not Found' extracted_answer = 'Not Found'
else: else:
extracted_answer = extracted_answer[-1].strip().strip('"') extracted_answer = extracted_answer[-1].strip().strip('"')
if extracted_answer == 'str(short_and_direct_answer_of_the_question)' or extracted_answer == 'str(value which is the answer to the question)' or extracted_answer == 'X.XX': if extracted_answer == 'str(short_and_direct\
_answer_of_the_question)':
extracted_answer = 'Not Found'
if extracted_answer == 'str(value which is\
the answer to the question)':
extracted_answer = 'Not Found'
if extracted_answer == 'X.XX':
extracted_answer = 'Not Found' extracted_answer = 'Not Found'
if calid in [13, 68]: if calid in [13, 68]:
@ -116,9 +118,8 @@ def extract_answer(answer, calid):
elif calid in [69]: elif calid in [69]:
# Output Type: integer (A, B) # Output Type: integer (A, B)
match = re.search( match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?", r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,\?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
extracted_answer) extracted_answer)
ground_truth = f'({match.group(1)}, {match.group(3)})'
extracted_answer = extracted_answer.replace('[', '(').replace( extracted_answer = extracted_answer.replace('[', '(').replace(
']', ')').replace("'", '').replace('"', '') ']', ')').replace("'", '').replace('"', '')
match = re.search( match = re.search(
@ -157,7 +158,7 @@ def extract_answer(answer, calid):
]: ]:
# Output Type: decimal # Output Type: decimal
match = re.search(r'str\((.*)\)', extracted_answer) match = re.search(r'str\((.*)\)', extracted_answer)
if match: # cases like "str(round((140 * (3.15 - 136) / 1400) * 72.36)" if match:
expression = match.group(1).replace('^', '**').replace( expression = match.group(1).replace('^', '**').replace(
'is odd', '% 2 == 1').replace('is even', '% 2 == 0').replace( 'is odd', '% 2 == 1').replace('is even', '% 2 == 0').replace(
'sqrt', 'math.sqrt').replace('.math', '').replace( 'sqrt', 'math.sqrt').replace('.math', '').replace(
@ -166,9 +167,7 @@ def extract_answer(answer, calid):
'g/dl', '').replace('mmol/L', '').replace( 'g/dl', '').replace('mmol/L', '').replace(
'kg', '').replace('g', 'kg', '').replace('g',
'').replace('mEq/L', '') '').replace('mEq/L', '')
expression = expression.split( expression = expression.split('#')[0]
'#'
)[0] # cases like round(45.5 * 166 - 45.3 + 0.4 * (75 - (45.5 * 166 - 45.3))))) # Calculation: ...
if expression.count('(') > expression.count(')'): # add missing ') if expression.count('(') > expression.count(')'): # add missing ')
expression += ')' * (expression.count('(') - expression += ')' * (expression.count('(') -
expression.count(')')) expression.count(')'))
@ -188,7 +187,7 @@ def extract_answer(answer, calid):
'np': np, 'np': np,
'numpy': np 'numpy': np
}) })
except: except Exception:
print(f'Error in evaluating expression: {expression}') print(f'Error in evaluating expression: {expression}')
answer = 'N/A' answer = 'N/A'
else: else: