[Feature] Update MedBench (#779)

* update medbench

* medbench update

* format medbench

* format

* Update

* update

* update

* update suffix

---------

Co-authored-by: 施晓明 <PJLAB\shixiaoming@pjnl104220118l.pjlab.org>
Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
Xiaoming Shi 2024-01-09 11:42:44 +08:00 committed by GitHub
parent a74e4c1a8d
commit ad872a5dc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 139 additions and 142 deletions

View File

@ -6,7 +6,7 @@ exclude: |
opencompass/openicl/icl_evaluator/hf_metrics/|
opencompass/datasets/lawbench/utils|
opencompass/datasets/lawbench/evaluation_functions/|
opencompass/datasets/medbench|
opencompass/datasets/medbench/|
docs/zh_cn/advanced_guides/compassbench_intro.md
)
repos:

View File

@ -2,41 +2,24 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import (
MedBenchDataset,
MedBenchEvaluator,
MedBenchEvaluator_Cloze,
MedBenchEvaluator_IE,
MedBenchEvaluator_mcq,
MedBenchEvaluator_CMeEE,
MedBenchEvaluator_CMeIE,
MedBenchEvaluator_CHIP_CDEE,
MedBenchEvaluator_CHIP_CDN,
MedBenchEvaluator_CHIP_CTC,
MedBenchEvaluator_NLG,
MedBenchEvaluator_TF,
MedBenchEvaluator_EMR,
)
from opencompass.datasets import MedBenchDataset, MedBenchEvaluator, MedBenchEvaluator_Cloze, MedBenchEvaluator_IE, MedBenchEvaluator_mcq, MedBenchEvaluator_CMeEE, MedBenchEvaluator_CMeIE, MedBenchEvaluator_CHIP_CDEE, MedBenchEvaluator_CHIP_CDN, MedBenchEvaluator_CHIP_CTC, MedBenchEvaluator_NLG, MedBenchEvaluator_TF, MedBenchEvaluator_DBMHG, MedBenchEvaluator_SMDoc, MedBenchEvaluator_IMCS_V2_MRG
from opencompass.utils.text_postprocessors import first_capital_postprocess
medbench_reader_cfg = dict(
input_columns=['problem_input'], output_column='label')
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题用acc判断
medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'SafetyBench'] # 选择题用acc判断
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA有标答
medbench_qa_sets = ['MedHC', 'MedMC', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA有标答
medbench_cloze_sets = ['Triage'] # 限定域QA有标答
medbench_cloze_sets = ['MedHG'] # 限定域QA有标答
medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答
medbench_single_choice_sets = ['DrugCA'] # 正确与否判断,有标答
medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致用F1评价
#, 'CMeIE', 'CHIP_CDEE', 'CHIP_CDN', 'CHIP_CTC', 'Doc_parsing', 'MRG'
medbench_ie_sets = ['DBMHG', 'CMeEE', 'CMeIE', 'CHIP-CDEE', 'CHIP-CDN', 'CHIP-CTC', 'SMDoc', 'IMCS-V2-MRG'] # 判断识别的实体是否一致用F1评价
medbench_datasets = []
for name in medbench_single_choice_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
@ -144,7 +127,7 @@ for name in medbench_ie_sets:
inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict(
evaluator=dict(type=eval('MedBenchEvaluator_'+name)), pred_role="BOT")
evaluator=dict(type=eval('MedBenchEvaluator_'+name.replace('-', '_'))), pred_role="BOT")
medbench_datasets.append(
dict(
@ -157,4 +140,4 @@ for name in medbench_ie_sets:
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))
del name, medbench_infer_cfg, medbench_eval_cfg
del name, medbench_infer_cfg, medbench_eval_cfg

View File

@ -11,31 +11,31 @@ from .constructions import ChatGPTSchema, ResultsForHumanSchema
from .utils import extract_answer, read_jsonl, save_jsonl
# define the datasets
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题用acc判断
medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'DDx-advanced', 'SafetyBench'] # 选择题用acc判断
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA有标答
medbench_qa_sets = ['MedHC', 'MedMC', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA有标答
medbench_cloze_sets = ['Triage'] # 限定域QA有标答
medbench_cloze_sets = ['MedHG'] # 限定域QA有标答
medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答
medbench_single_choice_sets = ['DrugCA'] # 正确与否判断,有标答
medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致用F1评价
medbench_ie_sets = ['DBMHG', 'CMeEE', 'CMeIE', 'CHIP-CDEE', 'CHIP-CDN', 'CHIP-CTC', 'SMDoc', 'IMCS-V2-MRG'] # 判断识别的实体是否一致用F1评价
def convert_zero_shot(line, dataset_name):
# passage = line['passage'] if line['passage'] is not None else ''
if dataset_name in medbench_qa_sets:
return line['question']
elif dataset_name in medbench_cloze_sets:
return '问题:' + line['question'] + '\n答案:'
elif dataset_name in medbench_multiple_choices_sets:
return '问题:' + line['question'] + ' ' \
+ '选项:' + ' '.join(line['options']) + '\n从A到G我们应该选择'
else:
return line['question']
# if dataset_name in medbench_qa_sets:
# return line['question']
# elif dataset_name in medbench_cloze_sets:
# return '问题:' + line['question'] + '\n答案'
# elif dataset_name in medbench_multiple_choices_sets:
# return '问题:' + line['question'] + ' ' \
# + '选项:' + ' '.join(line['options']) + '\n从A到G我们应该选择'
# else:
# return line['question']
return line['question']
prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n'
# def convert_zero_shot_CoT_stage1(line, dataset_name):
# try:
# passage = line['passage'] if line['passage'] is not None else ''

File diff suppressed because one or more lines are too long

View File

@ -82,10 +82,8 @@ class MedBenchEvaluator(BaseEvaluator):
detail['correct'] = True
details.append(detail)
score = cnt / len(predictions) * 100
#输出字典类型 {'score':'', 'details'}
return {'Accuracy': score, 'details': details}
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_mcq(BaseEvaluator):
@ -109,16 +107,18 @@ class MedBenchEvaluator_mcq(BaseEvaluator):
return {'score': score, 'details': details}
def process_generated_results_CMeEE(pred_file):
# 实体每类占一行,每行格式为 "[类型名称]实体实体名称1实体名称2实体名称3\n"
# 多个实体,用 ,符号分割
structured_output = []
answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位', '医疗操作', '医学检验项目', '症状', '疾病']
for pred in pred_file:
list_entities = []
for choice in answer_choices:
for piece in re.split('[,|.|。||\n]', pred):
for piece in re.split('[。||\n]', pred):
if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").split("")
mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").replace(f'{choice}', '').replace(f'{choice}:', '').split("")
for ment in mentions:
list_entities.append({'entity':ment, 'type':choice})
list_entities.append({'type':choice, 'entity':ment})
structured_output.append(list_entities)
return structured_output
@ -128,12 +128,15 @@ def process_generated_results_EMR(pred_file):
for pred in pred_file:
list_entities = []
for choice in answer_choices:
for piece in re.split('[,|.|?|;||。||\n]', pred):
if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}", "").split("")
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
for ment in mentions:
list_entities.append({ment: choice})
for piece in re.split('\n', pred):
# if piece.startswith(f"{choice}"):
if f"{choice}" in piece and len(piece.split(f"{choice}"))>1:
# mentions = piece.replace(f"{choice}", "").split("")
mentions = piece.split(f"{choice}")[1].strip()
# mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
list_entities.append({choice: mentions})
# for ment in mentions:
# list_entities.append({choice: ment})
structured_output.append(list_entities)
return structured_output
@ -141,7 +144,7 @@ def process_generated_results_CMeIE(pred_file):
structured_output = []
for line in pred_file:
gen_output = line
# 答案格式:
# 每个关系类型占一行,格式为
# "具有{lab}关系的头尾实体对如下头实体为str尾实体为str头实体为str尾实体为str"
@ -156,14 +159,17 @@ def process_generated_results_CMeIE(pred_file):
# 首先是解析出label:
predicate = line.split("关系的头尾实体对")[0][2: ].strip()
line = line.replace(f"具有{predicate}关系的头尾实体对如下:", "")
for spo_str in line.split(""):
if len(spo_str.split(",尾实体为")) < 2:
# for spo_str in line.split("。"):
for spo_str in re.split('|。', line):
if len(spo_str.split(",尾实体:")) < 2:
continue
head_mention_str, tail_mention_str = spo_str.split(",尾实体为")[:2]
head_mention_str = head_mention_str.replace("头实体为", "").strip()
tail_mention_str = tail_mention_str.replace("尾实体为", "").strip()
head_mention_str, tail_mention_str = spo_str.split(",尾实体:")[:2]
head_mention_str = head_mention_str.replace("头实体:", "").strip()
tail_mention_str = tail_mention_str.replace("尾实体:", "").strip()
list_spos.append(
{
"predicate": predicate,
@ -176,10 +182,10 @@ def process_generated_results_CMeIE(pred_file):
def process_generated_results_CDN(pred_file):
structured_output = []
answer_choices = json.load(open('./data/MedBench/CHIP_CDN/CHIP-CDN_entity.json', 'r'))
answer_choices = json.load(open('./opencompass/datasets/medbench/entity_list.jsonl', 'r'))
for line in pred_file:
gen_output = line
# 答案格式:
# 多个选中的标准化实体,用 符号分割
@ -211,15 +217,17 @@ def process_generated_results_CDEE(pred_file):
keys = ["主体词", "发生状态", "描述词", "解剖部位"]
list_answer_strs = gen_output.split("\n")
# list_answer_strs: ['主题词:饮食,描述词:差;', '主题词:消瘦']
list_events = []
for ans_str in list_answer_strs:
if '主体词' in ans_str:
event_info = {}
ans_attrs = ans_str.split("")
ans_attrs = ans_str.split("")
for a_attr in ans_attrs:
for key in keys:
if a_attr.startswith(f"{key}"):
a_attr = a_attr.replace(f"{key}", "").strip()
a_attr = a_attr.replace(f"{key}", "").strip().strip('')
if key in ["描述词", "解剖部位"]:
a_attr_split = a_attr.split("")
a_attr_split = [w.strip() for w in a_attr_split if len(w.strip()) > 0]
@ -239,7 +247,7 @@ def process_generated_results_CDEE(pred_file):
structured_output.append(list_events)
return structured_output
def process_generated_results_CTC(pred_file, task_dataset):
def process_generated_results_CTC(pred_file):
structured_output = []
for line in pred_file:
@ -252,60 +260,60 @@ def process_generated_results_CTC(pred_file, task_dataset):
def process_generated_results_doc_parsing(pred_file):
output = []
for line in pred_file:
structured_output = {'体温':'', '脉搏':'', '心率':'', '收缩压':'', '舒张压':'', '呼吸':'', '上腹部深压痛':'', '腹部反跳痛':'', '上腹部肿块':''}
sentence_list = line.strip().split('|。|\n')
structured_output = []
sentence_list = line.strip().split('\n')
for sentence in sentence_list:
if '体温' in sentence:
temp_value = re.search('[0-9]+', sentence)
temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value:
structured_output['体温'] = temp_value.group(0)
structured_output.append({'type':'体温', 'entity':temp_value.group(0)})
else:
structured_output['体温'] = '未扪及'
structured_output.append({'type':'体温', 'entity':'未扪及'})
elif '脉搏' in sentence:
temp_value = re.search('[0-9]+', sentence)
temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value:
structured_output['脉搏'] = temp_value.group(0)
structured_output.append({'type':'脉搏', 'entity':temp_value.group(0)})
else:
structured_output['脉搏'] = '未扪及'
structured_output.append({'type':'脉搏', 'entity':'未扪及'})
elif '心率' in sentence:
temp_value = re.search('[0-9]+', sentence)
temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value:
structured_output['心率'] = temp_value.group(0)
structured_output.append({'type':'心率', 'entity':temp_value.group(0)})
else:
structured_output['心率'] = '未扪及'
structured_output.append({'type':'心率', 'entity':'未扪及'})
elif '收缩压' in sentence:
temp_value = re.search('[0-9]+', sentence)
temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value:
structured_output['收缩压'] = temp_value.group(0)
structured_output.append({'type':'收缩压', 'entity':temp_value.group(0)})
else:
structured_output['收缩压'] = '未扪及'
structured_output.append({'type':'收缩压', 'entity':'未扪及'})
elif '舒张压' in sentence:
temp_value = re.search('[0-9]+', sentence)
temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value:
structured_output['舒张压'] = temp_value.group(0)
structured_output.append({'type':'舒张压', 'entity':temp_value.group(0)})
else:
structured_output['舒张压'] = '未扪及'
structured_output.append({'type':'舒张压', 'entity':'未扪及'})
elif '呼吸' in sentence:
temp_value = re.search('[0-9]+', sentence)
temp_value = re.search('[0-9]+.[0-9]', sentence)
if temp_value:
structured_output['呼吸'] = temp_value.group(0)
structured_output.append({'type':'呼吸', 'entity':temp_value.group(0)})
else:
structured_output['呼吸'] = '未扪及'
structured_output.append({'type':'呼吸', 'entity':'未扪及'})
elif '上腹部深压痛' in sentence:
if re.search('是|存在|有', sentence):
structured_output['是否上腹部深压痛'] = ''
if re.search('未|不|没|无', sentence):
structured_output.append({'type':'上腹部深压痛', 'entity':'否是'})
else:
structured_output['是否上腹部深压痛'] = ''
structured_output.append({'type':'上腹部深压痛', 'entity':''})
elif '腹部反跳痛' in sentence:
if re.search('是|存在|有', sentence):
structured_output['是否腹部反跳痛'] = ''
if re.search('未|不|没|无', sentence):
structured_output.append({'type':'腹部反跳痛', 'entity':''})
else:
structured_output['是否腹部反跳痛'] = ''
structured_output.append({'type':'腹部反跳痛', 'entity':''})
elif '上腹部肿块' in sentence:
if re.search('是|存在|有', sentence):
structured_output['上腹部肿块'] = '扪及'
if re.search('未|不|没|无', sentence):
structured_output.append({'type':'上腹部肿块', 'entity':'未扪及'})
else:
structured_output['上腹部肿块'] = '未扪及'
structured_output.append({'type':'上腹部肿块', 'entity':'扪及'})
output.append(structured_output)
return output
@ -315,18 +323,22 @@ def process_generated_results_mrg(pred_file):
for pred in pred_file:
list_entities = []
for choice in answer_choices:
for piece in re.split('[,|.|?|;||。||\n]', pred):
if piece.startswith(f"{choice}实体"):
mentions = piece.replace(f"{choice}实体:", "").split("")
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
for ment in mentions:
list_entities.append({ment: choice})
if '\n\n' in pred['answer']:
for piece in re.split('\n\n', pred['answer']):
if f"{choice}" in piece and len(piece.split(f"{choice}"))>1:
mentions = piece.split(f"{choice}")[1].strip()
list_entities.append({choice:mentions})
else:
for piece in re.split('\n', pred):
if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}", "").split("")
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
for ment in mentions:
list_entities.append({choice:ment})
structured_output.append(list_entities)
return structured_output
def calc_info_extract_task_scores(list_structured_golden,
list_structured_predict):
def calc_info_extract_task_scores(list_structured_predict, list_structured_golden):
assert len(list_structured_golden) == len(list_structured_predict)
@ -334,12 +346,11 @@ def calc_info_extract_task_scores(list_structured_golden,
fp = 0
fn = 0
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
# samp_golden: [[{}]]
answer_golden = samp_golden
answer_predict = samp_predict
assert isinstance(answer_golden, list)
assert isinstance(answer_predict, list), "sample format is wrong!"
# assert isinstance(answer_golden, list)
# assert isinstance(answer_predict, list), "sample format is wrong!"
set_golden = set()
for inst in answer_golden:
@ -356,18 +367,11 @@ def calc_info_extract_task_scores(list_structured_golden,
for inst in answer_predict:
assert isinstance(inst, dict)
keys = sorted(list(inst.keys()))
# inst = tuple([inst[w] for w in keys])
inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys])
# inst = list(inst.items())
# inst.sort()
# inst = tuple(inst)
set_predict.add(inst)
# print("set_predict: ", set_predict)
# print("set_golden: ", set_golden)
tp += len(set_golden.intersection(set_predict))
fp += len(set_predict.difference(set_golden))
fn += len(set_golden.difference(set_predict))
@ -402,7 +406,9 @@ def calc_cls_task_scores(list_structured_golden,
pred_label = pred_samp
gt_label = gt_samp
assert gt_label != ""
# assert gt_label != ""
if gt_label == "":
get_label = list_labels[0]
if pred_label == "":
pred_label = list_labels[0]
@ -434,16 +440,10 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
references = []
details = []
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
# print("samp_golden: ", samp_golden)
# print("samp_predict: ", samp_predict)
# assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!"
answer_golden = samp_golden
answer_predict = samp_predict
print('#')
print(answer_golden)
print(answer_predict)
if not (answer_predict and answer_golden):
continue
@ -456,8 +456,6 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
answer_golden = "无 。"
if answer_predict.strip() == "":
answer_predict = "无 。"
# print("answer_predict: ", answer_predict)
# print("answer_golden: ", answer_golden)
predictions.append(answer_predict)
references.append(answer_golden)
@ -487,7 +485,7 @@ def calc_scores_f1(dict_gt, dict_pred):
details = []
for gt, pred in zip(dict_gt, dict_pred):
details.append({'pred':pred, 'answer':gt, 'correct':None})
precision, recall, f1 = calc_info_extract_task_scores(dict_gt, dict_pred)
return {'F1':f1, 'details':details}
@ -498,7 +496,7 @@ def calc_scores_ctc(dict_gt, dict_pred):
gts = dict_gt
preds = dict_pred
precision, recall, f1 = calc_cls_task_scores(
gts,
preds,
@ -520,9 +518,9 @@ def calc_scores_ctc(dict_gt, dict_pred):
return_macro=True,
)
return {'Macro-F1':f1, 'details':details}
def calc_scores_nlg(dict_gt, dict_pred):
# scores = {}
scores = {'score':0, 'details':[]}
success_flag = 1
@ -532,7 +530,7 @@ def calc_scores_nlg(dict_gt, dict_pred):
# if not len(gts) == len(preds):
# success_flag = 0
# try:
return calc_nlg_task_scores(gts, preds)
return calc_nlg_task_scores(gts, preds)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeEE(BaseEvaluator):
@ -542,14 +540,14 @@ class MedBenchEvaluator_CMeEE(BaseEvaluator):
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_EMR(BaseEvaluator):
class MedBenchEvaluator_DBMHG(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_EMR(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_MRG(BaseEvaluator):
class MedBenchEvaluator_IMCS_V2_MRG(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_mrg(predictions)
@ -581,10 +579,10 @@ class MedBenchEvaluator_CHIP_CTC(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_CTC(predictions)
return calc_scores_ctc(predictions, references)[0]
return calc_scores_ctc(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Doc_parsing(BaseEvaluator):
class MedBenchEvaluator_SMDoc(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_doc_parsing(predictions)
@ -594,23 +592,36 @@ class MedBenchEvaluator_Doc_parsing(BaseEvaluator):
class MedBenchEvaluator_NLG(BaseEvaluator):
def score(self, predictions, references):
# predictions = process_generated_results_med(predictions)
return calc_scores_nlg(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Cloze(BaseEvaluator):
def score(self, predictions, references):
# predictions: [[]]
# references: [[]]
# predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "小儿风湿免疫科", "放射科", "小儿内分泌代谢科", "急诊科", "心血管内科", "小儿神经内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "小儿心脏中心", "产科", "血液内科", "小儿普外科", "小儿泌尿外科", "小儿感染科", "临床营养科", "小儿骨科", "发育行为儿童保健科", "小儿呼吸内科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "小儿神经外科", "康复医学科", "骨科", "风湿免疫科", "小儿内科", "眼科", "心胸外科", "小儿肾脏内科", "乳腺外科", "小儿血液肿瘤科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "小儿消化内科", "呼吸内科", "核医学科", "肾脏内科"]
no_erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "放射科", "急诊科", "心血管内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "产科", "血液内科", "临床营养科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "康复医学科", "骨科", "风湿免疫科", "眼科", "心胸外科", "乳腺外科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "呼吸内科", "核医学科", "肾脏内科"]
cross_erke_list = [item for item in erke_list if '小儿' in item and item.replace('小儿', '') in no_erke_list]
cross_list = [item[2:] for item in cross_erke_list]
details = []
cnt = 0
for pred, ref in zip(predictions, references):
detail = {'pred':pred, 'answer':ref, 'correct':False}
current_pred = []
for x in cross_list:
if '小儿' + x in predictions:
current_pred.append('小儿' + x)
elif x in predictions:
current_pred.append(x)
if sum([item in pred for item in ref]) == len(ref):
for x in (set(erke_list + no_erke_list) - set(cross_erke_list) - set(cross_list)):
if x in predictions:
current_pred.append(x)
# if set([x for x in erke_list + no_erke_list if x in pred]) == set(ref):
if set(current_pred) == set(ref):
cnt += 1
detail['correct'] = True
details.append(detail)
@ -628,7 +639,7 @@ class MedBenchEvaluator_TF(BaseEvaluator):
cnt = 0
for pred, ref in zip(predictions, references):
if '' in pred or '' in pred:
cur_pred = '不可以'
else:
@ -639,8 +650,8 @@ class MedBenchEvaluator_TF(BaseEvaluator):
if cur_pred == ref:
cnt += 1
detail['correct'] = True
details.append(detail)
score = cnt / len(predictions) * 100
return {'Accuracy': score, 'details': details}
return {'Accuracy': score, 'details': details}

View File

@ -148,8 +148,8 @@ def parse_math_answer(setting_name, raw_string):
last_match = None
if '=' in s:
last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
if '\\n' in last_match:
last_match = last_match.split('\\n')[0]
if '\n' in last_match:
last_match = last_match.split('\n')[0]
else:
pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
matches = re.findall(pattern, s)
@ -170,6 +170,8 @@ def parse_math_answer(setting_name, raw_string):
def parse_qa_multiple_answer(string):
# if setting_name == 'few-shot-CoT':
# string = extract_last_line(string)
for x in ['CC', 'CA', 'AC', 'POMES', 'AI', 'MIBG', 'CF', 'CTE', 'AD', 'CB', 'BG', 'BD', 'BE', 'BH', 'CTB', 'BI', 'CE', 'Pugh', 'Child', 'CTI', 'CTA', 'TACE', 'PPD', 'Castleman', 'BA', 'CH', 'AB', 'CTC', 'CT', 'CTH', 'CD', 'AH', 'AE', 'AA', 'AF', 'BC', 'CG', 'BB', 'CI', 'BF', 'CTF', 'CTG', 'AG', 'CTD', '分级C', '分级A', 'I131', '分级B', '分级D', '131IMIBG', 'NYHA', 'IPF', 'DIP', 'Lambert-Eaton', 'Graves', 'IIA期', 'CKD', 'FDA', 'A级', 'B级', 'C级', 'D级', '维生素D']:
string = string.replace(x, '')
pattern = '\(*([A-Z])\)*'
match = re.findall(pattern, string)
if match: