mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Update MedBench (#779)
* update medbench * medbench update * format medbench * format * Update * update * update * update suffix --------- Co-authored-by: 施晓明 <PJLAB\shixiaoming@pjnl104220118l.pjlab.org> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
a74e4c1a8d
commit
ad872a5dc2
@ -6,7 +6,7 @@ exclude: |
|
||||
opencompass/openicl/icl_evaluator/hf_metrics/|
|
||||
opencompass/datasets/lawbench/utils|
|
||||
opencompass/datasets/lawbench/evaluation_functions/|
|
||||
opencompass/datasets/medbench|
|
||||
opencompass/datasets/medbench/|
|
||||
docs/zh_cn/advanced_guides/compassbench_intro.md
|
||||
)
|
||||
repos:
|
||||
|
@ -2,41 +2,24 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import (
|
||||
MedBenchDataset,
|
||||
MedBenchEvaluator,
|
||||
MedBenchEvaluator_Cloze,
|
||||
MedBenchEvaluator_IE,
|
||||
MedBenchEvaluator_mcq,
|
||||
MedBenchEvaluator_CMeEE,
|
||||
MedBenchEvaluator_CMeIE,
|
||||
MedBenchEvaluator_CHIP_CDEE,
|
||||
MedBenchEvaluator_CHIP_CDN,
|
||||
MedBenchEvaluator_CHIP_CTC,
|
||||
MedBenchEvaluator_NLG,
|
||||
MedBenchEvaluator_TF,
|
||||
MedBenchEvaluator_EMR,
|
||||
)
|
||||
from opencompass.datasets import MedBenchDataset, MedBenchEvaluator, MedBenchEvaluator_Cloze, MedBenchEvaluator_IE, MedBenchEvaluator_mcq, MedBenchEvaluator_CMeEE, MedBenchEvaluator_CMeIE, MedBenchEvaluator_CHIP_CDEE, MedBenchEvaluator_CHIP_CDN, MedBenchEvaluator_CHIP_CTC, MedBenchEvaluator_NLG, MedBenchEvaluator_TF, MedBenchEvaluator_DBMHG, MedBenchEvaluator_SMDoc, MedBenchEvaluator_IMCS_V2_MRG
|
||||
from opencompass.utils.text_postprocessors import first_capital_postprocess
|
||||
|
||||
medbench_reader_cfg = dict(
|
||||
input_columns=['problem_input'], output_column='label')
|
||||
|
||||
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断
|
||||
medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'SafetyBench'] # 选择题,用acc判断
|
||||
|
||||
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答
|
||||
medbench_qa_sets = ['MedHC', 'MedMC', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答
|
||||
|
||||
medbench_cloze_sets = ['Triage'] # 限定域QA,有标答
|
||||
medbench_cloze_sets = ['MedHG'] # 限定域QA,有标答
|
||||
|
||||
medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答
|
||||
medbench_single_choice_sets = ['DrugCA'] # 正确与否判断,有标答
|
||||
|
||||
medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价
|
||||
|
||||
#, 'CMeIE', 'CHIP_CDEE', 'CHIP_CDN', 'CHIP_CTC', 'Doc_parsing', 'MRG'
|
||||
medbench_ie_sets = ['DBMHG', 'CMeEE', 'CMeIE', 'CHIP-CDEE', 'CHIP-CDN', 'CHIP-CTC', 'SMDoc', 'IMCS-V2-MRG'] # 判断识别的实体是否一致,用F1评价
|
||||
|
||||
medbench_datasets = []
|
||||
|
||||
|
||||
for name in medbench_single_choice_sets:
|
||||
medbench_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
@ -144,7 +127,7 @@ for name in medbench_ie_sets:
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
medbench_eval_cfg = dict(
|
||||
evaluator=dict(type=eval('MedBenchEvaluator_'+name)), pred_role="BOT")
|
||||
evaluator=dict(type=eval('MedBenchEvaluator_'+name.replace('-', '_'))), pred_role="BOT")
|
||||
|
||||
medbench_datasets.append(
|
||||
dict(
|
||||
@ -157,4 +140,4 @@ for name in medbench_ie_sets:
|
||||
infer_cfg=medbench_infer_cfg.copy(),
|
||||
eval_cfg=medbench_eval_cfg.copy()))
|
||||
|
||||
del name, medbench_infer_cfg, medbench_eval_cfg
|
||||
del name, medbench_infer_cfg, medbench_eval_cfg
|
@ -11,31 +11,31 @@ from .constructions import ChatGPTSchema, ResultsForHumanSchema
|
||||
from .utils import extract_answer, read_jsonl, save_jsonl
|
||||
|
||||
# define the datasets
|
||||
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断
|
||||
medbench_multiple_choices_sets = ['Med-Exam', 'DDx-basic', 'DDx-advanced', 'DDx-advanced', 'SafetyBench'] # 选择题,用acc判断
|
||||
|
||||
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答
|
||||
medbench_qa_sets = ['MedHC', 'MedMC', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答
|
||||
|
||||
medbench_cloze_sets = ['Triage'] # 限定域QA,有标答
|
||||
medbench_cloze_sets = ['MedHG'] # 限定域QA,有标答
|
||||
|
||||
medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答
|
||||
medbench_single_choice_sets = ['DrugCA'] # 正确与否判断,有标答
|
||||
|
||||
medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价
|
||||
medbench_ie_sets = ['DBMHG', 'CMeEE', 'CMeIE', 'CHIP-CDEE', 'CHIP-CDN', 'CHIP-CTC', 'SMDoc', 'IMCS-V2-MRG'] # 判断识别的实体是否一致,用F1评价
|
||||
|
||||
def convert_zero_shot(line, dataset_name):
|
||||
# passage = line['passage'] if line['passage'] is not None else ''
|
||||
if dataset_name in medbench_qa_sets:
|
||||
return line['question']
|
||||
elif dataset_name in medbench_cloze_sets:
|
||||
return '问题:' + line['question'] + '\n答案:'
|
||||
elif dataset_name in medbench_multiple_choices_sets:
|
||||
return '问题:' + line['question'] + ' ' \
|
||||
+ '选项:' + ' '.join(line['options']) + '\n从A到G,我们应该选择'
|
||||
else:
|
||||
return line['question']
|
||||
# if dataset_name in medbench_qa_sets:
|
||||
# return line['question']
|
||||
# elif dataset_name in medbench_cloze_sets:
|
||||
# return '问题:' + line['question'] + '\n答案:'
|
||||
# elif dataset_name in medbench_multiple_choices_sets:
|
||||
# return '问题:' + line['question'] + ' ' \
|
||||
# + '选项:' + ' '.join(line['options']) + '\n从A到G,我们应该选择'
|
||||
# else:
|
||||
# return line['question']
|
||||
return line['question']
|
||||
|
||||
prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n'
|
||||
|
||||
|
||||
# def convert_zero_shot_CoT_stage1(line, dataset_name):
|
||||
# try:
|
||||
# passage = line['passage'] if line['passage'] is not None else ''
|
||||
|
1
opencompass/datasets/medbench/entity_list.jsonl
Normal file
1
opencompass/datasets/medbench/entity_list.jsonl
Normal file
File diff suppressed because one or more lines are too long
@ -82,10 +82,8 @@ class MedBenchEvaluator(BaseEvaluator):
|
||||
detail['correct'] = True
|
||||
details.append(detail)
|
||||
score = cnt / len(predictions) * 100
|
||||
#输出字典类型 {'score':'', 'details'}
|
||||
return {'Accuracy': score, 'details': details}
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class MedBenchEvaluator_mcq(BaseEvaluator):
|
||||
|
||||
@ -109,16 +107,18 @@ class MedBenchEvaluator_mcq(BaseEvaluator):
|
||||
return {'score': score, 'details': details}
|
||||
|
||||
def process_generated_results_CMeEE(pred_file):
|
||||
# 实体每类占一行,每行格式为 "[类型名称]实体:实体名称1,实体名称2,实体名称3\n"
|
||||
# 多个实体,用 ,符号分割
|
||||
structured_output = []
|
||||
answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位', '医疗操作', '医学检验项目', '症状', '疾病']
|
||||
for pred in pred_file:
|
||||
list_entities = []
|
||||
for choice in answer_choices:
|
||||
for piece in re.split('[,|.|。|;|\n]', pred):
|
||||
for piece in re.split('[。|;|\n]', pred):
|
||||
if piece.startswith(f"{choice}"):
|
||||
mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").split(",")
|
||||
mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").replace(f'{choice}:', '').replace(f'{choice}:', '').split(",")
|
||||
for ment in mentions:
|
||||
list_entities.append({'entity':ment, 'type':choice})
|
||||
list_entities.append({'type':choice, 'entity':ment})
|
||||
structured_output.append(list_entities)
|
||||
return structured_output
|
||||
|
||||
@ -128,12 +128,15 @@ def process_generated_results_EMR(pred_file):
|
||||
for pred in pred_file:
|
||||
list_entities = []
|
||||
for choice in answer_choices:
|
||||
for piece in re.split('[,|.|?|;|,|。|;|\n]', pred):
|
||||
if piece.startswith(f"{choice}"):
|
||||
mentions = piece.replace(f"{choice}:", "").split(",")
|
||||
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
|
||||
for ment in mentions:
|
||||
list_entities.append({ment: choice})
|
||||
for piece in re.split('\n', pred):
|
||||
# if piece.startswith(f"{choice}"):
|
||||
if f"{choice}" in piece and len(piece.split(f"{choice}:"))>1:
|
||||
# mentions = piece.replace(f"{choice}:", "").split(",")
|
||||
mentions = piece.split(f"{choice}:")[1].strip()
|
||||
# mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
|
||||
list_entities.append({choice: mentions})
|
||||
# for ment in mentions:
|
||||
# list_entities.append({choice: ment})
|
||||
structured_output.append(list_entities)
|
||||
return structured_output
|
||||
|
||||
@ -141,7 +144,7 @@ def process_generated_results_CMeIE(pred_file):
|
||||
structured_output = []
|
||||
for line in pred_file:
|
||||
gen_output = line
|
||||
|
||||
|
||||
# 答案格式:
|
||||
# 每个关系类型占一行,格式为
|
||||
# "具有{lab}关系的头尾实体对如下:头实体为str,尾实体为str;头实体为str,尾实体为str;"
|
||||
@ -156,14 +159,17 @@ def process_generated_results_CMeIE(pred_file):
|
||||
# 首先是解析出label:
|
||||
predicate = line.split("关系的头尾实体对")[0][2: ].strip()
|
||||
line = line.replace(f"具有{predicate}关系的头尾实体对如下:", "")
|
||||
for spo_str in line.split("。"):
|
||||
if len(spo_str.split(",尾实体为")) < 2:
|
||||
# for spo_str in line.split("。"):
|
||||
for spo_str in re.split(';|。', line):
|
||||
|
||||
if len(spo_str.split(",尾实体:")) < 2:
|
||||
continue
|
||||
|
||||
head_mention_str, tail_mention_str = spo_str.split(",尾实体为")[:2]
|
||||
head_mention_str = head_mention_str.replace("头实体为", "").strip()
|
||||
tail_mention_str = tail_mention_str.replace("尾实体为", "").strip()
|
||||
|
||||
head_mention_str, tail_mention_str = spo_str.split(",尾实体:")[:2]
|
||||
|
||||
head_mention_str = head_mention_str.replace("头实体:", "").strip()
|
||||
tail_mention_str = tail_mention_str.replace("尾实体:", "").strip()
|
||||
|
||||
list_spos.append(
|
||||
{
|
||||
"predicate": predicate,
|
||||
@ -176,10 +182,10 @@ def process_generated_results_CMeIE(pred_file):
|
||||
|
||||
def process_generated_results_CDN(pred_file):
|
||||
structured_output = []
|
||||
answer_choices = json.load(open('./data/MedBench/CHIP_CDN/CHIP-CDN_entity.json', 'r'))
|
||||
answer_choices = json.load(open('./opencompass/datasets/medbench/entity_list.jsonl', 'r'))
|
||||
for line in pred_file:
|
||||
gen_output = line
|
||||
|
||||
|
||||
# 答案格式:
|
||||
# 多个选中的标准化实体,用 , 符号分割
|
||||
|
||||
@ -211,15 +217,17 @@ def process_generated_results_CDEE(pred_file):
|
||||
keys = ["主体词", "发生状态", "描述词", "解剖部位"]
|
||||
|
||||
list_answer_strs = gen_output.split("\n")
|
||||
# list_answer_strs: ['主题词:饮食,描述词:差;', '主题词:消瘦']
|
||||
list_events = []
|
||||
for ans_str in list_answer_strs:
|
||||
if '主体词' in ans_str:
|
||||
event_info = {}
|
||||
ans_attrs = ans_str.split(";")
|
||||
ans_attrs = ans_str.split(",")
|
||||
|
||||
for a_attr in ans_attrs:
|
||||
for key in keys:
|
||||
if a_attr.startswith(f"{key}:"):
|
||||
a_attr = a_attr.replace(f"{key}:", "").strip()
|
||||
a_attr = a_attr.replace(f"{key}:", "").strip().strip(';')
|
||||
if key in ["描述词", "解剖部位"]:
|
||||
a_attr_split = a_attr.split(",")
|
||||
a_attr_split = [w.strip() for w in a_attr_split if len(w.strip()) > 0]
|
||||
@ -239,7 +247,7 @@ def process_generated_results_CDEE(pred_file):
|
||||
structured_output.append(list_events)
|
||||
return structured_output
|
||||
|
||||
def process_generated_results_CTC(pred_file, task_dataset):
|
||||
def process_generated_results_CTC(pred_file):
|
||||
structured_output = []
|
||||
|
||||
for line in pred_file:
|
||||
@ -252,60 +260,60 @@ def process_generated_results_CTC(pred_file, task_dataset):
|
||||
def process_generated_results_doc_parsing(pred_file):
|
||||
output = []
|
||||
for line in pred_file:
|
||||
structured_output = {'体温':'', '脉搏':'', '心率':'', '收缩压':'', '舒张压':'', '呼吸':'', '上腹部深压痛':'', '腹部反跳痛':'', '上腹部肿块':''}
|
||||
sentence_list = line.strip().split(',|。|\n')
|
||||
structured_output = []
|
||||
sentence_list = line.strip().split('\n')
|
||||
for sentence in sentence_list:
|
||||
if '体温' in sentence:
|
||||
temp_value = re.search('[0-9]+', sentence)
|
||||
temp_value = re.search('[0-9]+.[0-9]', sentence)
|
||||
if temp_value:
|
||||
structured_output['体温'] = temp_value.group(0)
|
||||
structured_output.append({'type':'体温', 'entity':temp_value.group(0)})
|
||||
else:
|
||||
structured_output['体温'] = '未扪及'
|
||||
structured_output.append({'type':'体温', 'entity':'未扪及'})
|
||||
elif '脉搏' in sentence:
|
||||
temp_value = re.search('[0-9]+', sentence)
|
||||
temp_value = re.search('[0-9]+.[0-9]', sentence)
|
||||
if temp_value:
|
||||
structured_output['脉搏'] = temp_value.group(0)
|
||||
structured_output.append({'type':'脉搏', 'entity':temp_value.group(0)})
|
||||
else:
|
||||
structured_output['脉搏'] = '未扪及'
|
||||
structured_output.append({'type':'脉搏', 'entity':'未扪及'})
|
||||
elif '心率' in sentence:
|
||||
temp_value = re.search('[0-9]+', sentence)
|
||||
temp_value = re.search('[0-9]+.[0-9]', sentence)
|
||||
if temp_value:
|
||||
structured_output['心率'] = temp_value.group(0)
|
||||
structured_output.append({'type':'心率', 'entity':temp_value.group(0)})
|
||||
else:
|
||||
structured_output['心率'] = '未扪及'
|
||||
structured_output.append({'type':'心率', 'entity':'未扪及'})
|
||||
elif '收缩压' in sentence:
|
||||
temp_value = re.search('[0-9]+', sentence)
|
||||
temp_value = re.search('[0-9]+.[0-9]', sentence)
|
||||
if temp_value:
|
||||
structured_output['收缩压'] = temp_value.group(0)
|
||||
structured_output.append({'type':'收缩压', 'entity':temp_value.group(0)})
|
||||
else:
|
||||
structured_output['收缩压'] = '未扪及'
|
||||
structured_output.append({'type':'收缩压', 'entity':'未扪及'})
|
||||
elif '舒张压' in sentence:
|
||||
temp_value = re.search('[0-9]+', sentence)
|
||||
temp_value = re.search('[0-9]+.[0-9]', sentence)
|
||||
if temp_value:
|
||||
structured_output['舒张压'] = temp_value.group(0)
|
||||
structured_output.append({'type':'舒张压', 'entity':temp_value.group(0)})
|
||||
else:
|
||||
structured_output['舒张压'] = '未扪及'
|
||||
structured_output.append({'type':'舒张压', 'entity':'未扪及'})
|
||||
elif '呼吸' in sentence:
|
||||
temp_value = re.search('[0-9]+', sentence)
|
||||
temp_value = re.search('[0-9]+.[0-9]', sentence)
|
||||
if temp_value:
|
||||
structured_output['呼吸'] = temp_value.group(0)
|
||||
structured_output.append({'type':'呼吸', 'entity':temp_value.group(0)})
|
||||
else:
|
||||
structured_output['呼吸'] = '未扪及'
|
||||
structured_output.append({'type':'呼吸', 'entity':'未扪及'})
|
||||
elif '上腹部深压痛' in sentence:
|
||||
if re.search('是|存在|有', sentence):
|
||||
structured_output['是否上腹部深压痛'] = '是'
|
||||
if re.search('未|不|没|无', sentence):
|
||||
structured_output.append({'type':'上腹部深压痛', 'entity':'否是'})
|
||||
else:
|
||||
structured_output['是否上腹部深压痛'] = '否'
|
||||
structured_output.append({'type':'上腹部深压痛', 'entity':'是'})
|
||||
elif '腹部反跳痛' in sentence:
|
||||
if re.search('是|存在|有', sentence):
|
||||
structured_output['是否腹部反跳痛'] = '是'
|
||||
if re.search('未|不|没|无', sentence):
|
||||
structured_output.append({'type':'腹部反跳痛', 'entity':'否'})
|
||||
else:
|
||||
structured_output['是否腹部反跳痛'] = '否'
|
||||
structured_output.append({'type':'腹部反跳痛', 'entity':'是'})
|
||||
elif '上腹部肿块' in sentence:
|
||||
if re.search('是|存在|有', sentence):
|
||||
structured_output['上腹部肿块'] = '扪及'
|
||||
if re.search('未|不|没|无', sentence):
|
||||
structured_output.append({'type':'上腹部肿块', 'entity':'未扪及'})
|
||||
else:
|
||||
structured_output['上腹部肿块'] = '未扪及'
|
||||
structured_output.append({'type':'上腹部肿块', 'entity':'扪及'})
|
||||
output.append(structured_output)
|
||||
return output
|
||||
|
||||
@ -315,18 +323,22 @@ def process_generated_results_mrg(pred_file):
|
||||
for pred in pred_file:
|
||||
list_entities = []
|
||||
for choice in answer_choices:
|
||||
for piece in re.split('[,|.|?|;|,|。|;|\n]', pred):
|
||||
if piece.startswith(f"{choice}实体"):
|
||||
mentions = piece.replace(f"{choice}实体:", "").split(",")
|
||||
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
|
||||
for ment in mentions:
|
||||
list_entities.append({ment: choice})
|
||||
if '\n\n' in pred['answer']:
|
||||
for piece in re.split('\n\n', pred['answer']):
|
||||
if f"{choice}" in piece and len(piece.split(f"{choice}:"))>1:
|
||||
mentions = piece.split(f"{choice}:")[1].strip()
|
||||
list_entities.append({choice:mentions})
|
||||
else:
|
||||
for piece in re.split('\n', pred):
|
||||
if piece.startswith(f"{choice}:"):
|
||||
mentions = piece.replace(f"{choice}:", "").split(",")
|
||||
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
|
||||
for ment in mentions:
|
||||
list_entities.append({choice:ment})
|
||||
structured_output.append(list_entities)
|
||||
return structured_output
|
||||
|
||||
|
||||
def calc_info_extract_task_scores(list_structured_golden,
|
||||
list_structured_predict):
|
||||
def calc_info_extract_task_scores(list_structured_predict, list_structured_golden):
|
||||
|
||||
assert len(list_structured_golden) == len(list_structured_predict)
|
||||
|
||||
@ -334,12 +346,11 @@ def calc_info_extract_task_scores(list_structured_golden,
|
||||
fp = 0
|
||||
fn = 0
|
||||
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
|
||||
|
||||
# samp_golden: [[{}]]
|
||||
answer_golden = samp_golden
|
||||
answer_predict = samp_predict
|
||||
|
||||
assert isinstance(answer_golden, list)
|
||||
assert isinstance(answer_predict, list), "sample format is wrong!"
|
||||
# assert isinstance(answer_golden, list)
|
||||
# assert isinstance(answer_predict, list), "sample format is wrong!"
|
||||
|
||||
set_golden = set()
|
||||
for inst in answer_golden:
|
||||
@ -356,18 +367,11 @@ def calc_info_extract_task_scores(list_structured_golden,
|
||||
for inst in answer_predict:
|
||||
assert isinstance(inst, dict)
|
||||
keys = sorted(list(inst.keys()))
|
||||
# inst = tuple([inst[w] for w in keys])
|
||||
|
||||
inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys])
|
||||
|
||||
# inst = list(inst.items())
|
||||
# inst.sort()
|
||||
# inst = tuple(inst)
|
||||
|
||||
set_predict.add(inst)
|
||||
|
||||
# print("set_predict: ", set_predict)
|
||||
# print("set_golden: ", set_golden)
|
||||
|
||||
tp += len(set_golden.intersection(set_predict))
|
||||
fp += len(set_predict.difference(set_golden))
|
||||
fn += len(set_golden.difference(set_predict))
|
||||
@ -402,7 +406,9 @@ def calc_cls_task_scores(list_structured_golden,
|
||||
|
||||
pred_label = pred_samp
|
||||
gt_label = gt_samp
|
||||
assert gt_label != ""
|
||||
# assert gt_label != ""
|
||||
if gt_label == "":
|
||||
get_label = list_labels[0]
|
||||
if pred_label == "":
|
||||
pred_label = list_labels[0]
|
||||
|
||||
@ -434,16 +440,10 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
|
||||
references = []
|
||||
details = []
|
||||
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
|
||||
# print("samp_golden: ", samp_golden)
|
||||
# print("samp_predict: ", samp_predict)
|
||||
|
||||
# assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!"
|
||||
answer_golden = samp_golden
|
||||
answer_predict = samp_predict
|
||||
|
||||
print('#')
|
||||
print(answer_golden)
|
||||
print(answer_predict)
|
||||
if not (answer_predict and answer_golden):
|
||||
continue
|
||||
|
||||
@ -456,8 +456,6 @@ def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
|
||||
answer_golden = "无 。"
|
||||
if answer_predict.strip() == "":
|
||||
answer_predict = "无 。"
|
||||
# print("answer_predict: ", answer_predict)
|
||||
# print("answer_golden: ", answer_golden)
|
||||
|
||||
predictions.append(answer_predict)
|
||||
references.append(answer_golden)
|
||||
@ -487,7 +485,7 @@ def calc_scores_f1(dict_gt, dict_pred):
|
||||
details = []
|
||||
for gt, pred in zip(dict_gt, dict_pred):
|
||||
details.append({'pred':pred, 'answer':gt, 'correct':None})
|
||||
|
||||
|
||||
precision, recall, f1 = calc_info_extract_task_scores(dict_gt, dict_pred)
|
||||
return {'F1':f1, 'details':details}
|
||||
|
||||
@ -498,7 +496,7 @@ def calc_scores_ctc(dict_gt, dict_pred):
|
||||
|
||||
gts = dict_gt
|
||||
preds = dict_pred
|
||||
|
||||
|
||||
precision, recall, f1 = calc_cls_task_scores(
|
||||
gts,
|
||||
preds,
|
||||
@ -520,9 +518,9 @@ def calc_scores_ctc(dict_gt, dict_pred):
|
||||
return_macro=True,
|
||||
)
|
||||
return {'Macro-F1':f1, 'details':details}
|
||||
|
||||
|
||||
def calc_scores_nlg(dict_gt, dict_pred):
|
||||
|
||||
|
||||
# scores = {}
|
||||
scores = {'score':0, 'details':[]}
|
||||
success_flag = 1
|
||||
@ -532,7 +530,7 @@ def calc_scores_nlg(dict_gt, dict_pred):
|
||||
# if not len(gts) == len(preds):
|
||||
# success_flag = 0
|
||||
# try:
|
||||
return calc_nlg_task_scores(gts, preds)
|
||||
return calc_nlg_task_scores(gts, preds)
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class MedBenchEvaluator_CMeEE(BaseEvaluator):
|
||||
@ -542,14 +540,14 @@ class MedBenchEvaluator_CMeEE(BaseEvaluator):
|
||||
return calc_scores_f1(predictions, references)
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class MedBenchEvaluator_EMR(BaseEvaluator):
|
||||
class MedBenchEvaluator_DBMHG(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, references):
|
||||
predictions = process_generated_results_EMR(predictions)
|
||||
return calc_scores_f1(predictions, references)
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class MedBenchEvaluator_MRG(BaseEvaluator):
|
||||
class MedBenchEvaluator_IMCS_V2_MRG(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, references):
|
||||
predictions = process_generated_results_mrg(predictions)
|
||||
@ -581,10 +579,10 @@ class MedBenchEvaluator_CHIP_CTC(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, references):
|
||||
predictions = process_generated_results_CTC(predictions)
|
||||
return calc_scores_ctc(predictions, references)[0]
|
||||
return calc_scores_ctc(predictions, references)
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class MedBenchEvaluator_Doc_parsing(BaseEvaluator):
|
||||
class MedBenchEvaluator_SMDoc(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, references):
|
||||
predictions = process_generated_results_doc_parsing(predictions)
|
||||
@ -594,23 +592,36 @@ class MedBenchEvaluator_Doc_parsing(BaseEvaluator):
|
||||
class MedBenchEvaluator_NLG(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, references):
|
||||
# predictions = process_generated_results_med(predictions)
|
||||
return calc_scores_nlg(predictions, references)
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class MedBenchEvaluator_Cloze(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, references):
|
||||
# predictions: [[]]
|
||||
# references: [[]]
|
||||
# predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
|
||||
erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "小儿风湿免疫科", "放射科", "小儿内分泌代谢科", "急诊科", "心血管内科", "小儿神经内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "小儿心脏中心", "产科", "血液内科", "小儿普外科", "小儿泌尿外科", "小儿感染科", "临床营养科", "小儿骨科", "发育行为儿童保健科", "小儿呼吸内科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "小儿神经外科", "康复医学科", "骨科", "风湿免疫科", "小儿内科", "眼科", "心胸外科", "小儿肾脏内科", "乳腺外科", "小儿血液肿瘤科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "小儿消化内科", "呼吸内科", "核医学科", "肾脏内科"]
|
||||
no_erke_list = ["血管外科", "临床心理科", "生殖医学中心", "肿瘤科", "妇科", "放射科", "急诊科", "心血管内科", "感染科", "整形外科", "全科医学科", "泌尿外科", "皮肤科", "消化内科", "口腔科", "产科", "血液内科", "临床营养科", "神经外科", "内分泌代谢科", "普外科", "肛肠外科", "康复医学科", "骨科", "风湿免疫科", "眼科", "心胸外科", "乳腺外科", "体检中心", "神经内科", "耳鼻咽喉头颈外科", "呼吸内科", "核医学科", "肾脏内科"]
|
||||
|
||||
cross_erke_list = [item for item in erke_list if '小儿' in item and item.replace('小儿', '') in no_erke_list]
|
||||
cross_list = [item[2:] for item in cross_erke_list]
|
||||
|
||||
details = []
|
||||
cnt = 0
|
||||
|
||||
for pred, ref in zip(predictions, references):
|
||||
detail = {'pred':pred, 'answer':ref, 'correct':False}
|
||||
current_pred = []
|
||||
for x in cross_list:
|
||||
if '小儿' + x in predictions:
|
||||
current_pred.append('小儿' + x)
|
||||
elif x in predictions:
|
||||
current_pred.append(x)
|
||||
|
||||
if sum([item in pred for item in ref]) == len(ref):
|
||||
for x in (set(erke_list + no_erke_list) - set(cross_erke_list) - set(cross_list)):
|
||||
if x in predictions:
|
||||
current_pred.append(x)
|
||||
|
||||
# if set([x for x in erke_list + no_erke_list if x in pred]) == set(ref):
|
||||
if set(current_pred) == set(ref):
|
||||
cnt += 1
|
||||
detail['correct'] = True
|
||||
details.append(detail)
|
||||
@ -628,7 +639,7 @@ class MedBenchEvaluator_TF(BaseEvaluator):
|
||||
cnt = 0
|
||||
|
||||
for pred, ref in zip(predictions, references):
|
||||
|
||||
|
||||
if '不' in pred or '否' in pred:
|
||||
cur_pred = '不可以'
|
||||
else:
|
||||
@ -639,8 +650,8 @@ class MedBenchEvaluator_TF(BaseEvaluator):
|
||||
if cur_pred == ref:
|
||||
cnt += 1
|
||||
detail['correct'] = True
|
||||
|
||||
|
||||
details.append(detail)
|
||||
|
||||
score = cnt / len(predictions) * 100
|
||||
return {'Accuracy': score, 'details': details}
|
||||
return {'Accuracy': score, 'details': details}
|
@ -148,8 +148,8 @@ def parse_math_answer(setting_name, raw_string):
|
||||
last_match = None
|
||||
if '=' in s:
|
||||
last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
|
||||
if '\\n' in last_match:
|
||||
last_match = last_match.split('\\n')[0]
|
||||
if '\n' in last_match:
|
||||
last_match = last_match.split('\n')[0]
|
||||
else:
|
||||
pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
|
||||
matches = re.findall(pattern, s)
|
||||
@ -170,6 +170,8 @@ def parse_math_answer(setting_name, raw_string):
|
||||
def parse_qa_multiple_answer(string):
|
||||
# if setting_name == 'few-shot-CoT':
|
||||
# string = extract_last_line(string)
|
||||
for x in ['CC', 'CA', 'AC', 'POMES', 'AI', 'MIBG', 'CF', 'CTE', 'AD', 'CB', 'BG', 'BD', 'BE', 'BH', 'CTB', 'BI', 'CE', 'Pugh', 'Child', 'CTI', 'CTA', 'TACE', 'PPD', 'Castleman', 'BA', 'CH', 'AB', 'CTC', 'CT', 'CTH', 'CD', 'AH', 'AE', 'AA', 'AF', 'BC', 'CG', 'BB', 'CI', 'BF', 'CTF', 'CTG', 'AG', 'CTD', '分级C', '分级A', 'I131', '分级B', '分级D', '131I‐MIBG', 'NYHA', 'IPF', 'DIP', 'Lambert-Eaton', 'Graves', 'IIA期', 'CKD', 'FDA', 'A级', 'B级', 'C级', 'D级', '维生素D']:
|
||||
string = string.replace(x, '')
|
||||
pattern = '\(*([A-Z])\)*'
|
||||
match = re.findall(pattern, string)
|
||||
if match:
|
||||
|
Loading…
Reference in New Issue
Block a user