From 1bf85949ef1789c31458a4b036c90d1aecb792e8 Mon Sep 17 00:00:00 2001 From: Xiaoming Shi <1669490794@qq.com> Date: Sat, 9 Dec 2023 16:05:46 +0800 Subject: [PATCH] [Feature] Add medbench (#678) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update medbench * medbench update * format medbench * format --------- Co-authored-by: 施晓明 Co-authored-by: Leymore --- .pre-commit-config-zh-cn.yaml | 3 +- .pre-commit-config.yaml | 3 +- configs/datasets/MedBench/medbench_gen.py | 4 + .../datasets/MedBench/medbench_gen_d44f24.py | 160 +++++ opencompass/datasets/__init__.py | 1 + opencompass/datasets/medbench/__init__.py | 3 + .../datasets/medbench/constructions.py | 104 +++ .../datasets/medbench/dataset_loader.py | 338 +++++++++ opencompass/datasets/medbench/evaluation.py | 43 ++ .../datasets/medbench/math_equivalence.py | 161 +++++ opencompass/datasets/medbench/medbench.py | 646 ++++++++++++++++++ opencompass/datasets/medbench/post_process.py | 198 ++++++ opencompass/datasets/medbench/utils.py | 43 ++ 13 files changed, 1705 insertions(+), 2 deletions(-) create mode 100644 configs/datasets/MedBench/medbench_gen.py create mode 100644 configs/datasets/MedBench/medbench_gen_d44f24.py create mode 100644 opencompass/datasets/medbench/__init__.py create mode 100644 opencompass/datasets/medbench/constructions.py create mode 100644 opencompass/datasets/medbench/dataset_loader.py create mode 100644 opencompass/datasets/medbench/evaluation.py create mode 100644 opencompass/datasets/medbench/math_equivalence.py create mode 100644 opencompass/datasets/medbench/medbench.py create mode 100644 opencompass/datasets/medbench/post_process.py create mode 100644 opencompass/datasets/medbench/utils.py diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index 6b9be079..b1817ca0 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -5,7 +5,8 @@ exclude: | opencompass/utils/internal/| opencompass/openicl/icl_evaluator/hf_metrics/| opencompass/datasets/lawbench/utils| - opencompass/datasets/lawbench/evaluation_functions/ + opencompass/datasets/lawbench/evaluation_functions/| + opencompass/datasets/medbench ) repos: - repo: https://gitee.com/openmmlab/mirrors-flake8 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a67ade99..d4326160 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,8 @@ exclude: | opencompass/utils/internal/| opencompass/openicl/icl_evaluator/hf_metrics/| opencompass/datasets/lawbench/utils| - opencompass/datasets/lawbench/evaluation_functions/ + opencompass/datasets/lawbench/evaluation_functions/| + opencompass/datasets/medbench/ ) repos: - repo: https://github.com/PyCQA/flake8 diff --git a/configs/datasets/MedBench/medbench_gen.py b/configs/datasets/MedBench/medbench_gen.py new file mode 100644 index 00000000..7c1aa0c9 --- /dev/null +++ b/configs/datasets/MedBench/medbench_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .medbench_gen_d44f24 import medbench_datasets # noqa: F401, F403 diff --git a/configs/datasets/MedBench/medbench_gen_d44f24.py b/configs/datasets/MedBench/medbench_gen_d44f24.py new file mode 100644 index 00000000..1bfa5bb6 --- /dev/null +++ b/configs/datasets/MedBench/medbench_gen_d44f24.py @@ -0,0 +1,160 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from 
opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import ( + MedBenchDataset, + MedBenchEvaluator, + MedBenchEvaluator_Cloze, + MedBenchEvaluator_IE, + MedBenchEvaluator_mcq, + MedBenchEvaluator_CMeEE, + MedBenchEvaluator_CMeIE, + MedBenchEvaluator_CHIP_CDEE, + MedBenchEvaluator_CHIP_CDN, + MedBenchEvaluator_CHIP_CTC, + MedBenchEvaluator_NLG, + MedBenchEvaluator_TF, + MedBenchEvaluator_EMR, +) +from opencompass.utils.text_postprocessors import first_capital_postprocess + +medbench_reader_cfg = dict( + input_columns=['problem_input'], output_column='label') + +medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断 + +medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答 + +medbench_cloze_sets = ['Triage'] # 限定域QA,有标答 + +medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答 + +medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价 + +#, 'CMeIE', 'CHIP_CDEE', 'CHIP_CDN', 'CHIP_CTC', 'Doc_parsing', 'MRG' + +medbench_datasets = [] + + +for name in medbench_single_choice_sets: + medbench_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[dict(role="HUMAN", prompt='{problem_input}')])), + retriever=dict(type=ZeroRetriever + ), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot) + inferencer=dict(type=GenInferencer)) + + medbench_eval_cfg = dict( + evaluator=dict(type=MedBenchEvaluator_TF), pred_role="BOT") + + medbench_datasets.append( + dict( + type=MedBenchDataset, + path='./data/MedBench/' + name, + name=name, + abbr='medbench-' + name, + setting_name='zero-shot', + reader_cfg=medbench_reader_cfg, + infer_cfg=medbench_infer_cfg.copy(), + eval_cfg=medbench_eval_cfg.copy())) + +for name in medbench_multiple_choices_sets: + medbench_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[dict(role="HUMAN", prompt='{problem_input}')])), + retriever=dict(type=ZeroRetriever + ), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot) + inferencer=dict(type=GenInferencer)) + + medbench_eval_cfg = dict( + evaluator=dict(type=MedBenchEvaluator), pred_role="BOT") + + medbench_datasets.append( + dict( + type=MedBenchDataset, + path='./data/MedBench/' + name, + name=name, + abbr='medbench-' + name, + setting_name='zero-shot', + reader_cfg=medbench_reader_cfg, + infer_cfg=medbench_infer_cfg.copy(), + eval_cfg=medbench_eval_cfg.copy())) + +for name in medbench_qa_sets: + medbench_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[dict(role="HUMAN", prompt='{problem_input}')])), + retriever=dict(type=ZeroRetriever + ), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot) + inferencer=dict(type=GenInferencer)) + + medbench_eval_cfg = dict( + evaluator=dict(type=MedBenchEvaluator_NLG), pred_role="BOT") + + medbench_datasets.append( + dict( + type=MedBenchDataset, + path='./data/MedBench/' + name, + name=name, + abbr='medbench-' + name, + setting_name='zero-shot', + reader_cfg=medbench_reader_cfg, + infer_cfg=medbench_infer_cfg.copy(), + eval_cfg=medbench_eval_cfg.copy())) + +for name in medbench_cloze_sets: + medbench_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[dict(role="HUMAN", prompt='{problem_input}')])), + retriever=dict(type=ZeroRetriever + ), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot) + inferencer=dict(type=GenInferencer)) + + medbench_eval_cfg = dict( + 
evaluator=dict(type=MedBenchEvaluator_Cloze), pred_role="BOT") + + medbench_datasets.append( + dict( + type=MedBenchDataset, + path='./data/MedBench/' + name, + name=name, + abbr='medbench-' + name, + setting_name='zero-shot', + reader_cfg=medbench_reader_cfg, + infer_cfg=medbench_infer_cfg.copy(), + eval_cfg=medbench_eval_cfg.copy())) + +for name in medbench_ie_sets: + medbench_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[dict(role="HUMAN", prompt='{problem_input}')])), + retriever=dict(type=ZeroRetriever + ), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot) + inferencer=dict(type=GenInferencer)) + + medbench_eval_cfg = dict( + evaluator=dict(type=eval('MedBenchEvaluator_'+name)), pred_role="BOT") + + medbench_datasets.append( + dict( + type=MedBenchDataset, + path='./data/MedBench/' + name, + name=name, + abbr='medbench-' + name, + setting_name='zero-shot', + reader_cfg=medbench_reader_cfg, + infer_cfg=medbench_infer_cfg.copy(), + eval_cfg=medbench_eval_cfg.copy())) + +del name, medbench_infer_cfg, medbench_eval_cfg diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index a15c2e6f..70d875ab 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -56,6 +56,7 @@ from .longbench import * # noqa: F401, F403 from .math import * # noqa: F401, F403 from .mathbench import * # noqa: F401, F403 from .mbpp import * # noqa: F401, F403 +from .medbench import * # noqa: F401, F403 from .mmlu import * # noqa: F401, F403 from .multirc import * # noqa: F401, F403 from .narrativeqa import * # noqa: F401, F403 diff --git a/opencompass/datasets/medbench/__init__.py b/opencompass/datasets/medbench/__init__.py new file mode 100644 index 00000000..5d0d8ccd --- /dev/null +++ b/opencompass/datasets/medbench/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .medbench import * # noqa: F401, F403 diff --git a/opencompass/datasets/medbench/constructions.py b/opencompass/datasets/medbench/constructions.py new file mode 100644 index 00000000..c3302173 --- /dev/null +++ b/opencompass/datasets/medbench/constructions.py @@ -0,0 +1,104 @@ +# flake8: noqa +import pandas as pd + + +class TaskSchema(object): + + def __init__(self, + passage=None, + question=None, + options=None, + label=None, + answer=None, + other=None): + self.passage = passage + self.question = question + self.options = options + self.label = label + self.answer = answer + self.other = other + + def to_dict(self): + return { + 'passage': self.passage, + 'question': self.question, + 'options': self.options, + 'label': self.label, + 'answer': self.answer, + 'other': self.other + } + + +# define README.json +class MedBenchInstance(object): + + def __init__(self, task_description, data_source, task_schema, output, + evaluation_metric, task_example): + self.task_description = task_description + self.data_source = data_source + self.task_schema = task_schema + self.output = output + self.evaluation_metric = evaluation_metric + self.task_example = task_example + + def to_dict(self): + return { + 'task description': self.task_description, + 'data source': self.data_source, + 'task schema': self.task_schema.to_dict(), + 'output': self.output, + 'evaluation metric': self.evaluation_metric, + 'task example': self.task_example + } + + +class ChatGPTSchema(object): + + def __init__(self, context=None, metadata=''): + self.context = context + self.metadata = metadata + + def to_dict(self): + return {'context': self.context, 'metadata': self.metadata} + + +class 
ResultsForHumanSchema(object): + + def __init__(self, + index, + problem_input, + label, + model_input='', + model_output='', + parse_result='', + first_stage_output='', + second_stage_input='', + is_correct=False): + self.index = index + self.problem_input = problem_input + self.model_input = model_input + self.model_output = model_output + self.parse_result = parse_result + self.label = label + self.first_stage_output = first_stage_output + self.second_stage_input = second_stage_input + self.is_correct = is_correct + + def to_dict(self): + return { + 'index': self.index, + 'problem_input': self.problem_input, + 'model_input': self.model_input, + 'model_output': self.model_output, + 'parse_result': self.parse_result, + 'label': self.label, + 'is_correct': self.is_correct, + 'first_stage_output': self.first_stage_output, + 'second_stage_input': self.second_stage_input, + } + + @staticmethod + def to_tsv(result_list, path): + result_json = [item.to_dict() for item in result_list] + table = pd.json_normalize(result_json) + table.to_excel(path, index=False) diff --git a/opencompass/datasets/medbench/dataset_loader.py b/opencompass/datasets/medbench/dataset_loader.py new file mode 100644 index 00000000..32100fe5 --- /dev/null +++ b/opencompass/datasets/medbench/dataset_loader.py @@ -0,0 +1,338 @@ +# flake8: noqa +import ast +import json +import os + +import pandas as pd +import tiktoken +from tqdm import tqdm + +from .constructions import ChatGPTSchema, ResultsForHumanSchema +from .utils import extract_answer, read_jsonl, save_jsonl + +# define the datasets +medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断 + +medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答 + +medbench_cloze_sets = ['Triage'] # 限定域QA,有标答 + +medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答 + +medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价 + +def convert_zero_shot(line, dataset_name): + # passage = line['passage'] if line['passage'] is not None else '' + if dataset_name in medbench_qa_sets: + return line['question'] + elif dataset_name in medbench_cloze_sets: + return '问题:' + line['question'] + '\n答案:' + elif dataset_name in medbench_multiple_choices_sets: + return '问题:' + line['question'] + ' ' \ + + '选项:' + ' '.join(line['options']) + '\n从A到G,我们应该选择' + else: + return line['question'] + +prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n' + + +# def convert_zero_shot_CoT_stage1(line, dataset_name): +# try: +# passage = line['passage'] if line['passage'] is not None else '' +# if dataset_name in english_qa_datasets: +# return passage + 'Q: ' + line['question'] + ' ' \ +# + 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \ +# "Let's think step by step." + +# elif dataset_name in chinese_qa_datasets: +# option_string = 'ABCDEFG' +# count = len(line['options']) +# if count == 1: +# count = 4 +# return passage + '问题:' + line['question'] + ' ' \ +# + '选项:' + ' '.join(line['options']) + '\n' + \ +# '从A到{}, 我们应选择什么?让我们逐步思考:'.format(option_string[count - 1]) + +# elif dataset_name in english_cloze_datasets: +# return passage + 'Q: ' + line['question'] + '\n' \ +# "A: Let's think step by step." 
+ +# elif dataset_name in chinese_cloze_datasets: +# return passage + '问题:' + line['question'] + '\n' \ +# '答案:让我们逐步思考:' +# except NameError: +# print('Dataset not defined.') + + +# process few-shot raw_prompts +def combine_prompt(prompt_path, + dataset_name, + load_explanation=True, + chat_mode=False): + skip_passage = False + if dataset_name == 'sat-en-without-passage': + skip_passage = True + dataset_name = 'sat-en' + demostrations = [] + # read the prompts by context and explanation + context_row = [0, 1, 3, 5, 7, 9] + explanation_row = [0, 2, 4, 6, 8, 10] + raw_prompts_context = pd.read_csv(prompt_path, + header=0, + skiprows=lambda x: x not in context_row, + keep_default_na=False) + raw_prompts_explanation = pd.read_csv( + prompt_path, + header=0, + skiprows=lambda x: x not in explanation_row, + keep_default_na=False).replace(r'\n\n', '\n', regex=True) + contexts = [] + for line in list(raw_prompts_context[dataset_name]): + if line: + # print(line) + contexts.append(ast.literal_eval(line)) + explanations = [ + exp for exp in raw_prompts_explanation[dataset_name] if exp + ] + + for idx, (con, exp) in enumerate(zip(contexts, explanations)): + passage = con['passage'] if con[ + 'passage'] is not None and not skip_passage else '' + question = con['question'] + options = con['options'] if con['options'] is not None else '' + label = con['label'] if con['label'] is not None else '' + answer = con[ + 'answer'] if 'answer' in con and con['answer'] is not None else '' + + if dataset_name in qa_datasets: + question_input = '问题 {}. '.format(idx + 1) + passage + ' ' + question + '\n' \ + + '从以下选项中选择: ' + ' '.join(options) + '\n' + question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \ + + '答案是 {}'.format(label) + + elif dataset_name in cloze_datasets: + question_input = '问题 {}. 
'.format(idx + 1) + question + '\n' + question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \ + + '答案是 {}'.format(answer) + else: + raise ValueError( + f'During loading few-sot examples, found unknown dataset: {dataset_name}' + ) + if chat_mode: + demostrations.append((question_input, question_output)) + else: + demostrations.append(question_input + question_output + '\n') + + return demostrations + + +enc = None + + +def _lazy_load_enc(): + global enc + if enc is None: + enc = tiktoken.encoding_for_model('gpt-4') + + +# cut prompt if reach max token length +def concat_prompt(demos, + dataset_name, + max_tokens, + end_of_example='\n', + verbose=False): + _lazy_load_enc() + demostration_en = 'Here are the answers for the problems in the exam.\n' + demostration_zh = '以下是考试中各个问题的答案。\n' + + for i in range(len(demos)): + # print(len(enc.encode(demostration_en)), len(enc.encode(demostration_zh))) + if dataset_name in english_qa_datasets: + demostration_en = demostration_en + demos[i] + end_of_example + elif dataset_name in chinese_qa_datasets: + demostration_zh = demostration_zh + demos[i] + end_of_example + elif dataset_name in english_cloze_datasets: + demostration_en = demostration_en + demos[i] + end_of_example + elif dataset_name in chinese_cloze_datasets: + demostration_zh = demostration_zh + demos[i] + end_of_example + # break if reach max token limit + if len(enc.encode(demostration_en)) < max_tokens and len( + enc.encode(demostration_zh)) < max_tokens: + output = demostration_en if len(demostration_en) > len( + demostration_zh) else demostration_zh + prompt_num = i + 1 + else: + break + if verbose: + print('max_tokens set as ', max_tokens, 'actual_tokens is', + len(enc.encode(output)), 'num_shot is', prompt_num) + return output, prompt_num + + +def concat_prompt_chat_mode(demos, + dataset_name, + max_tokens, + end_of_example='\n', + verbose=False): + _lazy_load_enc() + answers = [] + sentences = '' + for i in range(len(demos)): + answers += [ + { + 'role': 'user', + 'content': demos[i][0] + }, + { + 'role': 'assistant', + 'content': demos[i][1] + }, + ] + sentences += json.dumps(answers[-1]) + # break if reach max token limit + if len(enc.encode(sentences)) > max_tokens: + answers.pop() + answers.pop() + break + if verbose: + print('max_tokens set as ', max_tokens, 'actual_tokens is', + len(enc.encode(sentences)), 'num_shot is', + len(answers) // 2) + return answers, len(answers) // 2 + + +def convert_few_shot(line, dataset_name, demo, n_shot, chat_mode=False): + passage = line['passage'] if line['passage'] is not None else '' + question = line['question'] + options = line['options'] if line['options'] is not None else '' + + if dataset_name in qa_datasets: + question_input = '问题 {}. '.format(n_shot + 1) + passage + ' ' + question + '\n' \ + + '从以下选项中选择: ' + ' '.join(options) + '\n' + # + "问题 {}的解析: ".format(n_shot + 1) + + if dataset_name in cloze_datasets: + question_input = '问题 {}. 
'.format(n_shot + 1) + question + '\n' + # + "问题 {}的解析: ".format(n_shot + 1) + if chat_mode: + return demo + [ + { + 'role': 'user', + 'content': question_input + }, + ] + else: + return demo + question_input + + +def load_dataset(dataset_name, + setting_name, + parent_path, + prompt_path=None, + max_tokens=None, + end_of_example='\n', + chat_mode=False, + verbose=False): + test_path = os.path.join(parent_path, dataset_name + '.jsonl') + loaded_jsonl = read_jsonl(test_path) + processed = [] + if setting_name == 'few-shot-CoT' or setting_name == 'few-shot': + # process demo once if it is few-shot-CoT + processed_demos = combine_prompt( + prompt_path, + dataset_name, + load_explanation=setting_name == 'few-shot-CoT', + chat_mode=chat_mode) + if chat_mode: + chosen_prompt, n_shot = concat_prompt_chat_mode(processed_demos, + dataset_name, + max_tokens, + end_of_example, + verbose=verbose) + else: + chosen_prompt, n_shot = concat_prompt(processed_demos, + dataset_name, + max_tokens, + end_of_example, + verbose=verbose) + + if verbose: + loaded_jsonl = tqdm(loaded_jsonl) + for meta_idx, line in enumerate(loaded_jsonl): + # 正确 + if setting_name == 'zero-shot': + ctxt = convert_zero_shot(line, dataset_name) + elif setting_name == 'zero-shot-CoT': + ctxt = convert_zero_shot_CoT_stage1(line, dataset_name) + elif setting_name == 'few-shot-CoT' or setting_name == 'few-shot': + ctxt = convert_few_shot(line, dataset_name, chosen_prompt, n_shot, + chat_mode) + try: + new_instance = ChatGPTSchema(context=ctxt, metadata=meta_idx) + processed.append(new_instance.to_dict()) + except NameError: + print('Dataset not defined.') + return processed + + +def generate_second_stage_input(dataset_name, + input_list, + output_list, + with_format_prompt=False): + try: + chinese_format_prompt = '根据以上内容,你的任务是把最终的答案提取出来并填在【】中,例如【0】或者【A】。' + if dataset_name in qa_datasets: + prompt_suffix = '因此,从A到D, 我们应选择' + if with_format_prompt: + prompt_suffix = chinese_format_prompt + prompt_suffix + elif dataset_name in cloze_datasets: + prompt_suffix = '因此,答案是' + if with_format_prompt: + prompt_suffix = chinese_format_prompt + prompt_suffix + except NameError: + print('Dataset not defined.') + processed = [] + for i in range(len(input_list)): + ctxt = '{0}\n{1}\n{2}'.format(input_list[i]['context'], + extract_answer(output_list[i]), + prompt_suffix) + new_instance = ChatGPTSchema(context=ctxt, + metadata=input_list[i]['metadata']) + processed.append(new_instance.to_dict()) + return processed + + +def load_dataset_as_result_schema(dataset_name, parent_path): + test_path = os.path.join(parent_path, dataset_name + '.jsonl') + loaded_jsonl = read_jsonl(test_path) + + processed = [] + for i, line in enumerate(loaded_jsonl): + problem_input = convert_zero_shot(line, dataset_name) + processed.append( + ResultsForHumanSchema( + index=i, + problem_input=problem_input, + # label=line['label'] if line['label'] else line['answer'] + label = line['answer'] + )) + return processed + + +if __name__ == '__main__': + # set variables + parent_dir = '../../data/exam_guidance' + + # set dataset name to process + setting_name = 'zero-shot' # setting_name can be chosen from ["zero-shot", "zero-shot-CoT", "few-shot-CoT"] + data_name = 'health_exam' + save_dir = '../../experiment_input/{}/'.format(setting_name) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + processed_data = load_dataset(data_name, + setting_name, + parent_dir, + prompt_path=raw_prompt_path, + max_tokens=2048) + save_jsonl(processed_data, + os.path.join(save_dir, 
'{}.jsonl'.format(data_name))) diff --git a/opencompass/datasets/medbench/evaluation.py b/opencompass/datasets/medbench/evaluation.py new file mode 100644 index 00000000..c5a9916a --- /dev/null +++ b/opencompass/datasets/medbench/evaluation.py @@ -0,0 +1,43 @@ +# flake8: noqa +from . import dataset_loader, utils +from .math_equivalence import is_equiv + + +def convert_to_set(item): + if isinstance(item, list): + return set(item) + if isinstance(item, str): + return {item} + if item is None: + return {} + raise ValueError("Input can't parse:", item) + + +def evaluate_single_sample(dataset_name, prediction, label): + if dataset_name in dataset_loader.multi_choice_datasets: + p = convert_to_set(prediction) + l = convert_to_set(label) + return p == l + elif dataset_name in dataset_loader.math_output_datasets: + return is_equiv(prediction, label) + else: + return prediction == label + + +# def evaluate(dataset_name, prediction_list, label_list): +# correct = 0 +# if dataset_name in multi_choice_datasets: +# for prediction, label in zip(prediction_list, label_list): +# p = convert_to_set(prediction) +# l = convert_to_set(label) +# if p == l: +# correct += 1 +# elif dataset_name in math_output_datasets: +# for prediction, label in zip(prediction_list, label_list): +# if is_equiv(prediction, label): +# correct += 1 +# else: +# for prediction, label in zip(prediction_list, label_list): +# if prediction == label: +# correct += 1 +# return "{0:.2%}".format(correct / len(label_list)) diff --git a/opencompass/datasets/medbench/math_equivalence.py b/opencompass/datasets/medbench/math_equivalence.py new file mode 100644 index 00000000..788900ea --- /dev/null +++ b/opencompass/datasets/medbench/math_equivalence.py @@ -0,0 +1,161 @@ +# flake8: noqa + + +# code from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py +def _fix_fracs(string): + substrs = string.split('\\frac') + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += '\\frac' + if substr[0] == '{': + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != '{': + if len(substr) > 2: + post_substr = substr[2:] + new_str += '{' + a + '}{' + b + '}' + post_substr + else: + new_str += '{' + a + '}{' + b + '}' + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += '{' + a + '}' + b + post_substr + else: + new_str += '{' + a + '}' + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split('/')) != 2: + return string + a = string.split('/')[0] + b = string.split('/')[1] + try: + a = int(a) + b = int(b) + assert string == '{}/{}'.format(a, b) + new_string = '\\frac{' + str(a) + '}{' + str(b) + '}' + return new_string + except: + return string + + +def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if '\\text{ ' in string: + splits = string.split('\\text{ ') + assert len(splits) == 2 + return splits[0] + else: + return string + + +def _fix_sqrt(string): + if '\\sqrt' not in string: + return string + splits = string.split('\\sqrt') + new_string = splits[0] + for split in splits[1:]: + if split[0] != '{': + a = split[0] + new_substr = '\\sqrt{' + a + '}' + split[1:] + else: + new_substr = '\\sqrt' + split + new_string += new_substr + return new_string + + +def _strip_string(string): + # linebreaks + string = string.replace('\n', '') + # print(string) + + # remove inverse spaces + string = 
string.replace('\\!', '') + # print(string) + + # replace \\ with \ + string = string.replace('\\\\', '\\') + # print(string) + + # replace tfrac and dfrac with frac + string = string.replace('tfrac', 'frac') + string = string.replace('dfrac', 'frac') + # print(string) + + # remove \left and \right + string = string.replace('\\left', '') + string = string.replace('\\right', '') + # print(string) + + # Remove circ (degrees) + string = string.replace('^{\\circ}', '') + string = string.replace('^\\circ', '') + + # remove dollar signs + string = string.replace('\\$', '') + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace('\\%', '') + string = string.replace('\%', '') + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(' .', ' 0.') + string = string.replace('{.', '{0.') + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == '.': + string = '0' + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split('=')) == 2: + if len(string.split('=')[0]) <= 2: + string = string.split('=')[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(' ', '') + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == '0.5': + string = '\\frac{1}{2}' + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def is_equiv(str1, str2, verbose=False): + if str1 is None and str2 is None: + print('WARNING: Both None') + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = _strip_string(str1) + ss2 = _strip_string(str2) + if verbose: + print(ss1, ss2) + return ss1 == ss2 + except: + return str1 == str2 diff --git a/opencompass/datasets/medbench/medbench.py b/opencompass/datasets/medbench/medbench.py new file mode 100644 index 00000000..9e8effe3 --- /dev/null +++ b/opencompass/datasets/medbench/medbench.py @@ -0,0 +1,646 @@ +import json +import os.path as osp +import sys +from datasets import Dataset +from sklearn.metrics import classification_report +from opencompass.openicl.icl_evaluator import BaseEvaluator +from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET + +from ..base import BaseDataset +from .math_equivalence import is_equiv +from .post_process import parse_math_answer, parse_qa_multiple_answer + +import evaluate +from nltk.translate.bleu_score import sentence_bleu +# from bert_score import score +import re +from transformers import BasicTokenizer +from rouge_chinese import Rouge +basic_tokenizer = BasicTokenizer(tokenize_chinese_chars=True) + +@LOAD_DATASET.register_module() +class MedBenchDataset(BaseDataset): + + @staticmethod + def load(path: str, name: str, setting_name: str): + from .dataset_loader import load_dataset, load_dataset_as_result_schema + + assert setting_name in 'zero-shot', 'only support zero-shot setting' + dataset_wo_label = load_dataset(name, setting_name, path) + dataset_with_label = load_dataset_as_result_schema(name, path) + dataset = [] + for d1, d2 in zip(dataset_wo_label, dataset_with_label): + dataset.append({ + 'id': d2.index, + 'problem_input': d1['context'], + 'label': d2.label, + }) + 
dataset = Dataset.from_list(dataset) + return dataset + + +@LOAD_DATASET.register_module() +class MedBenchDataset_v2(BaseDataset): + + @staticmethod + def load(path: str, name: str, setting_name: str): + assert setting_name in 'zero-shot', 'only support zero-shot setting' + filename = osp.join(path, name + '.jsonl') + with open(filename, encoding='utf-8') as f: + data = [json.loads(line.strip()) for line in f] + dataset = [] + for item in data: + passage = item['passage'] if item['passage'] else '' + question = passage + item['question'] + options = '\n'.join(item['options']) if item['options'] else '' + if item['label']: + if isinstance(item['label'], list): + label = ''.join(item['label']) + else: + label = item['label'] + else: + label = item['answer'] + d = {'question': question, 'options': options, 'label': label} + dataset.append(d) + dataset = Dataset.from_list(dataset) + return dataset + + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator(BaseEvaluator): + + def score(self, predictions, references): + # predictions: [[]] + # references: [[]] + predictions = [parse_qa_multiple_answer(pred) for pred in predictions] + details = [] + cnt = 0 + for pred, ref in zip(predictions, references): + detail = {'pred': pred, 'answer': ref, 'correct': False} + if is_equiv(pred, ref): + cnt += 1 + detail['correct'] = True + details.append(detail) + score = cnt / len(predictions) * 100 + #输出字典类型 {'score':'', 'details'} + return {'Accuracy': score, 'details': details} + + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_mcq(BaseEvaluator): + + def score(self, predictions, references): + if len(predictions) != len(references): + return { + 'error': 'predictions and references have different ' + 'length' + } + details = [] + cnt = 0 + for pred, ref in zip(predictions, references): + detail = {'pred': pred, 'answer': ref, 'correct': False} + if pred == ref: + cnt += 1 + detail['correct'] = True + details.append(detail) + + score = cnt / len(predictions) * 100 + + return {'score': score, 'details': details} + +def process_generated_results_CMeEE(pred_file): + structured_output = [] + answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位', '医疗操作', '医学检验项目', '症状', '疾病'] + for pred in pred_file: + list_entities = [] + for choice in answer_choices: + for piece in re.split('[,|.|。|;|\n]', pred): + if piece.startswith(f"{choice}"): + mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").split(",") + for ment in mentions: + list_entities.append({'entity':ment, 'type':choice}) + structured_output.append(list_entities) + return structured_output + +def process_generated_results_EMR(pred_file): + structured_output = [] + answer_choices = ['主诉', '现病史', '既往史', '个人史', '婚育史', '家族史'] + for pred in pred_file: + list_entities = [] + for choice in answer_choices: + for piece in re.split('[,|.|?|;|,|。|;|\n]', pred): + if piece.startswith(f"{choice}"): + mentions = piece.replace(f"{choice}:", "").split(",") + mentions = [w.strip() for w in mentions if len(w.strip()) > 0] + for ment in mentions: + list_entities.append({ment: choice}) + structured_output.append(list_entities) + return structured_output + +def process_generated_results_CMeIE(pred_file): + structured_output = [] + for line in pred_file: + gen_output = line + + # 答案格式: + # 每个关系类型占一行,格式为 + # "具有{lab}关系的头尾实体对如下:头实体为str,尾实体为str;头实体为str,尾实体为str;" + + answer_choices = 
"相关(导致)、鉴别诊断、遗传因素、发病性别倾向、相关(症状)、手术治疗、预防、辅助检查、筛查、阶段、临床表现、风险评估因素、同义词、发病年龄、预后生存率、病史、传播途径、治疗后症状、药物治疗、辅助治疗、化疗、死亡率、放射治疗、病因、组织学检查、内窥镜检查、多发群体、并发症、实验室检查、就诊科室、病理生理、高危因素、发病率、多发地区、病理分型、影像学检查、转移部位、发病部位、相关(转化)、外侵部位、预后状况、发病机制、多发季节" + answer_choices = answer_choices.split('、') + list_spos = [] + assert isinstance(answer_choices, list) + list_answer_strs = gen_output.split("\n") + + for line in list_answer_strs: + # 首先是解析出label: + predicate = line.split("关系的头尾实体对")[0][2: ].strip() + line = line.replace(f"具有{predicate}关系的头尾实体对如下:", "") + for spo_str in line.split("。"): + if len(spo_str.split(",尾实体为")) < 2: + continue + + head_mention_str, tail_mention_str = spo_str.split(",尾实体为")[:2] + head_mention_str = head_mention_str.replace("头实体为", "").strip() + tail_mention_str = tail_mention_str.replace("尾实体为", "").strip() + + list_spos.append( + { + "predicate": predicate, + "subject": head_mention_str, + "object": tail_mention_str, + } + ) + structured_output.append(list_spos) + return structured_output + +def process_generated_results_CDN(pred_file): + structured_output = [] + answer_choices = json.load(open('./data/MedBench/CHIP_CDN/CHIP-CDN_entity.json', 'r')) + for line in pred_file: + gen_output = line + + # 答案格式: + # 多个选中的标准化实体,用 , 符号分割 + + answer_str = gen_output.split("\n")[-1] + answers = answer_str.split(",") + answers = [w.strip() for w in answers if len(w.strip()) > 0] + answers = [w for w in answers if w in answer_choices] + answers = list(set(answers)) + answers = [ + { + "entity": w, + "type": "normalization", + } + for w in answers + ] + + structured_output.append(answers) + return structured_output + +def process_generated_results_CDEE(pred_file): + + structured_output = [] + for line in pred_file: + gen_output = line + # 答案格式: + # 第一行:引导词 + # 每个事件占一行,事件字段用 ; 分隔, 然后每个字段是 字段名:字段值的格式" + # 字段值有多个,则用 ,符号分隔 + keys = ["主体词", "发生状态", "描述词", "解剖部位"] + + list_answer_strs = gen_output.split("\n") + list_events = [] + for ans_str in list_answer_strs: + if '主体词' in ans_str: + event_info = {} + ans_attrs = ans_str.split(";") + for a_attr in ans_attrs: + for key in keys: + if a_attr.startswith(f"{key}:"): + a_attr = a_attr.replace(f"{key}:", "").strip() + if key in ["描述词", "解剖部位"]: + a_attr_split = a_attr.split(",") + a_attr_split = [w.strip() for w in a_attr_split if len(w.strip()) > 0] + event_info[key] = a_attr_split + else: + event_info[key] = a_attr + + for key in keys: + if key not in event_info: + if key in ["描述词", "解剖部位"]: + event_info[key] = [] + else: + event_info[key] = "" + + list_events.append(event_info) + + structured_output.append(list_events) + return structured_output + +def process_generated_results_CTC(pred_file, task_dataset): + structured_output = [] + + for line in pred_file: + gen_output = line + # 答案格式:直接回答分类标签 + answer_str = gen_output.strip() + structured_output.append(answer_str) + return structured_output + +def process_generated_results_doc_parsing(pred_file): + output = [] + for line in pred_file: + structured_output = {'体温':'', '脉搏':'', '心率':'', '收缩压':'', '舒张压':'', '呼吸':'', '上腹部深压痛':'', '腹部反跳痛':'', '上腹部肿块':''} + sentence_list = line.strip().split(',|。|\n') + for sentence in sentence_list: + if '体温' in sentence: + temp_value = re.search('[0-9]+', sentence) + if temp_value: + structured_output['体温'] = temp_value.group(0) + else: + structured_output['体温'] = '未扪及' + elif '脉搏' in sentence: + temp_value = re.search('[0-9]+', sentence) + if temp_value: + structured_output['脉搏'] = temp_value.group(0) + else: + structured_output['脉搏'] = '未扪及' + elif '心率' in sentence: + temp_value = 
re.search('[0-9]+', sentence) + if temp_value: + structured_output['心率'] = temp_value.group(0) + else: + structured_output['心率'] = '未扪及' + elif '收缩压' in sentence: + temp_value = re.search('[0-9]+', sentence) + if temp_value: + structured_output['收缩压'] = temp_value.group(0) + else: + structured_output['收缩压'] = '未扪及' + elif '舒张压' in sentence: + temp_value = re.search('[0-9]+', sentence) + if temp_value: + structured_output['舒张压'] = temp_value.group(0) + else: + structured_output['舒张压'] = '未扪及' + elif '呼吸' in sentence: + temp_value = re.search('[0-9]+', sentence) + if temp_value: + structured_output['呼吸'] = temp_value.group(0) + else: + structured_output['呼吸'] = '未扪及' + elif '上腹部深压痛' in sentence: + if re.search('是|存在|有', sentence): + structured_output['是否上腹部深压痛'] = '是' + else: + structured_output['是否上腹部深压痛'] = '否' + elif '腹部反跳痛' in sentence: + if re.search('是|存在|有', sentence): + structured_output['是否腹部反跳痛'] = '是' + else: + structured_output['是否腹部反跳痛'] = '否' + elif '上腹部肿块' in sentence: + if re.search('是|存在|有', sentence): + structured_output['上腹部肿块'] = '扪及' + else: + structured_output['上腹部肿块'] = '未扪及' + output.append(structured_output) + return output + +def process_generated_results_mrg(pred_file): + structured_output = [] + answer_choices = ['主诉', '现病史', '既往史', '辅助检查', '诊断'] + for pred in pred_file: + list_entities = [] + for choice in answer_choices: + for piece in re.split('[,|.|?|;|,|。|;|\n]', pred): + if piece.startswith(f"{choice}实体"): + mentions = piece.replace(f"{choice}实体:", "").split(",") + mentions = [w.strip() for w in mentions if len(w.strip()) > 0] + for ment in mentions: + list_entities.append({ment: choice}) + structured_output.append(list_entities) + return structured_output + + +def calc_info_extract_task_scores(list_structured_golden, + list_structured_predict): + + assert len(list_structured_golden) == len(list_structured_predict) + + tp = 0 + fp = 0 + fn = 0 + for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): + + answer_golden = samp_golden + answer_predict = samp_predict + + assert isinstance(answer_golden, list) + assert isinstance(answer_predict, list), "sample format is wrong!" 
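+        # Entity-level micro F1: each gold / predicted entity dict is flattened into a
+        # tuple of JSON-serialized values (keys in sorted order), so the comparison below
+        # is an order-insensitive set intersection and duplicate entities collapse.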
+ + set_golden = set() + for inst in answer_golden: + assert isinstance(inst, dict) + keys = sorted(list(inst.keys())) + inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys ]) + # inst = list(inst.items()) + # inst.sort() + # inst = tuple(inst) + + set_golden.add(inst) + + set_predict = set() + for inst in answer_predict: + assert isinstance(inst, dict) + keys = sorted(list(inst.keys())) + # inst = tuple([inst[w] for w in keys]) + inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys]) + + # inst = list(inst.items()) + # inst.sort() + # inst = tuple(inst) + + set_predict.add(inst) + + # print("set_predict: ", set_predict) + # print("set_golden: ", set_golden) + + tp += len(set_golden.intersection(set_predict)) + fp += len(set_predict.difference(set_golden)) + fn += len(set_golden.difference(set_predict)) + + if tp: + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1 = 2 * precision * recall / (precision + recall) + + else: + precision, recall, f1 = 0, 0, 0 + + return precision, recall, f1 + +def calc_cls_task_scores(list_structured_golden, + list_structured_predict, + list_labels=None, + return_macro=False, + ): + # types = list_labels + # scores = {c: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for c in list_labels + ["ALL"]} + + predictions = [] + ground_truths = [] + + # Count GT relations and Predicted relations + assert len(list_structured_golden) == len(list_structured_predict) + n_sents = len(list_structured_golden) + + # Count TP, FP and FN per type + for pred_samp, gt_samp in zip(list_structured_predict, list_structured_golden): + + pred_label = pred_samp + gt_label = gt_samp + assert gt_label != "" + if pred_label == "": + pred_label = list_labels[0] + + predictions.append(pred_label) + ground_truths.append(gt_label) + + # metric + cls_report = classification_report( + ground_truths, predictions, + output_dict=True, + zero_division=0, + ) + + if return_macro: + return cls_report["macro avg"]["precision"], \ + cls_report["macro avg"]["recall"], \ + cls_report["macro avg"]["f1-score"] + else: + return cls_report["weighted avg"]["precision"], \ + cls_report["weighted avg"]["recall"], \ + cls_report["weighted avg"]["f1-score"] + +def calc_nlg_task_scores(list_structured_golden, list_structured_predict): + + assert len(list_structured_golden) == len(list_structured_predict) + + scores = [] + predictions = [] + references = [] + details = [] + for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): + # print("samp_golden: ", samp_golden) + # print("samp_predict: ", samp_predict) + + # assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!" 
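+        # Predictions and references are paired by position; both sides are split into
+        # Chinese characters by BasicTokenizer below, so ROUGE-L from rouge_chinese is
+        # effectively computed at the character level for Chinese text.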
+ answer_golden = samp_golden + answer_predict = samp_predict + + print('#') + print(answer_golden) + print(answer_predict) + if not (answer_predict and answer_golden): + continue + + # basic tokenizer: 拆分中文字,保留英文单词 + answer_predict = basic_tokenizer.tokenize(answer_predict) + answer_golden = basic_tokenizer.tokenize(answer_golden) + answer_predict = " ".join(answer_predict).strip() + answer_golden = " ".join(answer_golden).strip() + if answer_golden.strip() == "": + answer_golden = "无 。" + if answer_predict.strip() == "": + answer_predict = "无 。" + # print("answer_predict: ", answer_predict) + # print("answer_golden: ", answer_golden) + + predictions.append(answer_predict) + references.append(answer_golden) + + details.append({'pred':answer_predict, 'answer':answer_golden, 'correct':False}) + + rouge = Rouge() + # bleu = evaluate.load('sacrebleu') + scores = rouge.get_scores(predictions, references, avg=True) + # scores_bleu = bleu.compute(predictions=predictions, references=references) + + rouge1 = scores["rouge-1"]["f"] + rouge2 = scores["rouge-2"]["f"] + rougeL = scores["rouge-l"]["f"] + + # bleu = sentence_bleu(references, predictions) + + # bert_score = [] + # for id in range(len(predictions)): + # P, R, F1 = score([predictions[i]], [references[i]], model_type='bert-base-chinese', lang="zh", verbose=True) + # bert_score.append(F1) + # bert_score = float(sum(bert_score)) / float(len(bert_score)) + # return rougeL, bleu, bert_score + return {'RougeL': rougeL, 'details':details} + +def calc_scores_f1(dict_gt, dict_pred): + details = [] + for gt, pred in zip(dict_gt, dict_pred): + details.append({'pred':pred, 'answer':gt, 'correct':None}) + + precision, recall, f1 = calc_info_extract_task_scores(dict_gt, dict_pred) + return {'F1':f1, 'details':details} + +def calc_scores_ctc(dict_gt, dict_pred): + details = [] + for gt, pred in zip(dict_gt, dict_pred): + details.append({'pred':pred, 'answer':gt, 'correct':None}) + + gts = dict_gt + preds = dict_pred + + precision, recall, f1 = calc_cls_task_scores( + gts, + preds, + list_labels=['非上述类型', '疾病', '症状(患者感受)', + '体征(医生检测)', '怀孕相关', '肿瘤进展', + '疾病分期', '过敏耐受', '器官组织状态', + '预期寿命', '口腔相关', '药物', + '治疗或手术', '设备', '护理', + '诊断', '实验室检查', '风险评估', + '受体状态', '年龄', '特殊病人特征', + '读写能力', '性别', '教育情况', + '居住情况', '种族', '知情同意', + '参与其它试验', '研究者决定', '能力', + '伦理审查', '依存性', '成瘾行为', + '睡眠', '锻炼', '饮食', '酒精使用', + '性取向', '吸烟状况', '献血', + '病例来源', '残疾群体', '健康群体', + '数据可及性', "含有多个类别"], + return_macro=True, + ) + return {'Macro-F1':f1, 'details':details} + +def calc_scores_nlg(dict_gt, dict_pred): + + # scores = {} + scores = {'score':0, 'details':[]} + success_flag = 1 + + gts = dict_gt + preds = dict_pred + # if not len(gts) == len(preds): + # success_flag = 0 + # try: + return calc_nlg_task_scores(gts, preds) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_CMeEE(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_CMeEE(predictions) + return calc_scores_f1(predictions, references) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_EMR(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_EMR(predictions) + return calc_scores_f1(predictions, references) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_MRG(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_mrg(predictions) + return calc_scores_f1(predictions, references) + +@ICL_EVALUATORS.register_module() +class 
MedBenchEvaluator_CMeIE(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_CMeIE(predictions) + return calc_scores_f1(predictions, references) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_CHIP_CDEE(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_CDEE(predictions) + return calc_scores_f1(predictions, references) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_CHIP_CDN(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_CDN(predictions) + return calc_scores_f1(predictions, references) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_CHIP_CTC(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_CTC(predictions) + return calc_scores_ctc(predictions, references)[0] + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_Doc_parsing(BaseEvaluator): + + def score(self, predictions, references): + predictions = process_generated_results_doc_parsing(predictions) + return calc_scores_f1(predictions, references) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_NLG(BaseEvaluator): + + def score(self, predictions, references): + # predictions = process_generated_results_med(predictions) + return calc_scores_nlg(predictions, references) + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_Cloze(BaseEvaluator): + + def score(self, predictions, references): + # predictions: [[]] + # references: [[]] + # predictions = [parse_qa_multiple_answer(pred) for pred in predictions] + details = [] + cnt = 0 + + for pred, ref in zip(predictions, references): + detail = {'pred':pred, 'answer':ref, 'correct':False} + + if sum([item in pred for item in ref]) == len(ref): + cnt += 1 + detail['correct'] = True + details.append(detail) + score = cnt / len(predictions) * 100 + return {'Accuracy': score, 'details': details} + +@ICL_EVALUATORS.register_module() +class MedBenchEvaluator_TF(BaseEvaluator): + + def score(self, predictions, references): + # predictions: [[]] + # references: [[]] + # predictions = [parse_qa_multiple_answer(pred) for pred in predictions] + details = [] + cnt = 0 + + for pred, ref in zip(predictions, references): + + if '不' in pred or '否' in pred: + cur_pred = '不可以' + else: + cur_pred = '可以' + + detail = {'pred':cur_pred, 'answer':ref, 'correct':False} + + if cur_pred == ref: + cnt += 1 + detail['correct'] = True + + details.append(detail) + + score = cnt / len(predictions) * 100 + return {'Accuracy': score, 'details': details} diff --git a/opencompass/datasets/medbench/post_process.py b/opencompass/datasets/medbench/post_process.py new file mode 100644 index 00000000..77d4fb68 --- /dev/null +++ b/opencompass/datasets/medbench/post_process.py @@ -0,0 +1,198 @@ +# flake8: noqa +import json +import re + +from . 
import dataset_loader + + +def extract_last_line(string): + lines = string.split('\n') + for item in lines[::-1]: + if item.strip() != '': + string = item + break + return string + + +def remove_few_shot_prefix(string: str): + prefix_list = ['The answer is therefore', '答案是'] + for prefix in prefix_list: + if string.startswith(prefix): + string = string[len(prefix):].strip() + elif prefix in string: + index = string.rfind(prefix) + if index >= 0: + string = string[index + len(prefix):].strip() + return string + + +def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'): + if setting_name == 'few-shot-CoT': + string = extract_last_line(string) + if language == 'en': + pattern = 'answer is .*?([A-G])' + match = re.search(pattern, string) + elif language == 'zh': + pattern = '答案是.*?([A-G])' + match = re.search(pattern, string) + else: + raise ValueError('Unknown language {0}'.format(language)) + if match: + return match.group(1) + else: + return None + + +def try_parse_few_shot_pattern(string: str, dataset_name, setting_name): + if setting_name == 'few-shot-CoT': + string = extract_last_line(string) + if dataset_name in dataset_loader.chinese_cloze_datasets: + return string.startswith('答案是') + elif dataset_name in dataset_loader.english_cloze_datasets: + return string.startswith('The answer is therefore') + elif dataset_name in dataset_loader.chinese_qa_datasets: + pattern = '答案是.*?([A-G])' + match = re.search(pattern, string) + return match is not None + elif dataset_name in dataset_loader.english_qa_datasets: + pattern = 'answer is .*?([A-G])' + match = re.search(pattern, string) + return match is not None + return False + + +def parse_few_shot_qa_single_answer(string, setting_name, language='en'): + answer = try_parse_few_shot_qa_single_answer(string, setting_name, + language) + if answer is None: + return find_first_capital_letter(string) + else: + return answer + + +def find_first_capital_letter(answer): + letter_set = {'A', 'B', 'C', 'D', 'E', 'F'} + for c in answer: + if c in letter_set: + return c + # print("Can't find capital letter in:", answer) + return '' + + +def extract_answer_in_bracket(answer, prefix='【', suffix='】'): + if prefix not in answer and suffix not in answer: + # print("doesn't found special tokens in:", answer) + return '' + s = answer.index(prefix) + len(prefix) + t = answer.index(suffix) + ret = answer[s:t] + return ret + + +def parse_math_answer(setting_name, raw_string): + if setting_name == 'few-shot-CoT': + raw_string = extract_last_line(raw_string) + if setting_name == 'few-shot-CoT' or setting_name == 'few-shot': + raw_string = remove_few_shot_prefix(raw_string) + return raw_string + + def remove_boxed(s): + left = '\\boxed{' + try: + assert s[:len(left)] == left + assert s[-1] == '}' + answer = s[len(left):-1] + if '=' in answer: + answer = answer.split('=')[-1].lstrip(' ') + return answer + except: + return None + + def last_boxed_only_string(string): + idx = string.rfind('\\boxed') + if idx < 0: + idx = string.rfind('\\fbox') + if idx < 0: + return None + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == '{': + num_left_braces_open += 1 + if string[i] == '}': + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx == None: + retval = None + else: + retval = string[idx:right_brace_idx + 1] + + return retval + + def get_answer_with_dollar_sign(s): + first_pattern = '\$(.*)\$' + last_match = None + matches = 
re.findall(first_pattern, s) + if matches: + last_match = matches[-1] + if '=' in last_match: + last_match = last_match.split('=')[-1].lstrip(' ') + return last_match + + def get_answer_without_dollar_sign(s): + last_match = None + if '=' in s: + last_match = s.split('=')[-1].lstrip(' ').rstrip('.') + if '\\n' in last_match: + last_match = last_match.split('\\n')[0] + else: + pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])' + matches = re.findall(pattern, s) + if matches: + last_match = matches[-1] + return last_match + + raw_string = remove_few_shot_prefix(raw_string) + if '\\boxed' in raw_string: + answer = remove_boxed(last_boxed_only_string(raw_string)) + else: + answer = get_answer_with_dollar_sign(raw_string) + if not answer: + answer = get_answer_without_dollar_sign(raw_string) + return answer + + +def parse_qa_multiple_answer(string): + # if setting_name == 'few-shot-CoT': + # string = extract_last_line(string) + pattern = '\(*([A-Z])\)*' + match = re.findall(pattern, string) + if match: + return match + return [] + + +def post_process(dataset_name, setting_name, prediction): + if dataset_name in dataset_loader.english_cloze_datasets or dataset_name in dataset_loader.chinese_cloze_datasets: + return parse_math_answer(setting_name, prediction) + + if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']: + return parse_qa_multiple_answer(prediction, setting_name) + + # all other datasets are QA problems with single answer + if 'zero-shot' in setting_name: + answer = find_first_capital_letter(prediction) + return answer + + # all other datasets are QA problems with single answer and setting_name are few-shot + language = 'en' if dataset_name in dataset_loader.english_qa_datasets else 'zh' + if dataset_name in dataset_loader.english_qa_datasets or dataset_name in dataset_loader.chinese_qa_datasets: + return parse_few_shot_qa_single_answer(prediction, setting_name, + language) + else: + raise ValueError(f'Unsupported dataset name {dataset_name}') diff --git a/opencompass/datasets/medbench/utils.py b/opencompass/datasets/medbench/utils.py new file mode 100644 index 00000000..fbb31105 --- /dev/null +++ b/opencompass/datasets/medbench/utils.py @@ -0,0 +1,43 @@ +# flake8: noqa +import json + + +def read_jsonl(path): + with open(path, encoding='utf8') as fh: + results = [] + for line in fh: + if line is None: + continue + try: + results.append(json.loads(line) if line != 'null' else line) + except Exception as e: + print(e) + print(path) + print(line) + raise e + return results + + +def save_jsonl(lines, directory): + with open(directory, 'w', encoding='utf8') as f: + for line in lines: + f.write(json.dumps(line, ensure_ascii=False) + '\n') + + +def extract_answer(js): + try: + if js is None or js == 'null': + return '' + answer = '' + if isinstance(js, str): + answer = js + elif 'text' in js['choices'][0]: + answer = js['choices'][0]['text'] + else: + answer = js['choices'][0]['message']['content'] + # answer = js[''] + return answer + except Exception as e: + # print(e) + # print(js) + return ''
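A minimal way to exercise the new datasets end to end when reviewing this locally is to compose medbench_datasets with any existing model config in a small evaluation entry point. The sketch below is not part of the patch: the file name eval_medbench.py and the hf_internlm_7b model import are only illustrative, and the MedBench data is expected under ./data/MedBench/ as hard-coded in medbench_gen_d44f24.py.

# configs/eval_medbench.py -- hypothetical local-testing config, not included in this PR
from mmengine.config import read_base

with read_base():
    # datasets added by this PR
    from .datasets.MedBench.medbench_gen import medbench_datasets
    # any existing HuggingFace model config works here; hf_internlm_7b is just an example
    from .models.hf_internlm.hf_internlm_7b import models

datasets = medbench_datasets

Running "python run.py configs/eval_medbench.py" should then evaluate every subset registered above with the evaluator it was assigned in medbench_gen_d44f24.py: MedBenchEvaluator_TF for the yes/no set (Medicine_attack), MedBenchEvaluator for the multiple-choice sets, MedBenchEvaluator_NLG for the open-ended QA sets, MedBenchEvaluator_Cloze for Triage, and the per-task IE evaluators for EMR and CMeEE.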