Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Add medbench (#678)
* update medbench
* medbench update
* format medbench
* format
---------
Co-authored-by: 施晓明 <PJLAB\shixiaoming@pjnl104220118l.pjlab.org>
Co-authored-by: Leymore <zfz-960727@163.com>

This commit is contained in: parent 7cb53a95fa, commit 1bf85949ef
@@ -5,7 +5,8 @@ exclude: |
     opencompass/utils/internal/|
     opencompass/openicl/icl_evaluator/hf_metrics/|
     opencompass/datasets/lawbench/utils|
-    opencompass/datasets/lawbench/evaluation_functions/
+    opencompass/datasets/lawbench/evaluation_functions/|
+    opencompass/datasets/medbench
     )
 repos:
   - repo: https://gitee.com/openmmlab/mirrors-flake8

@@ -5,7 +5,8 @@ exclude: |
     opencompass/utils/internal/|
     opencompass/openicl/icl_evaluator/hf_metrics/|
     opencompass/datasets/lawbench/utils|
-    opencompass/datasets/lawbench/evaluation_functions/
+    opencompass/datasets/lawbench/evaluation_functions/|
+    opencompass/datasets/medbench/
     )
 repos:
   - repo: https://github.com/PyCQA/flake8
configs/datasets/MedBench/medbench_gen.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .medbench_gen_d44f24 import medbench_datasets  # noqa: F401, F403
configs/datasets/MedBench/medbench_gen_d44f24.py (new file, 160 lines)
@@ -0,0 +1,160 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import (
    MedBenchDataset,
    MedBenchEvaluator,
    MedBenchEvaluator_Cloze,
    MedBenchEvaluator_IE,
    MedBenchEvaluator_mcq,
    MedBenchEvaluator_CMeEE,
    MedBenchEvaluator_CMeIE,
    MedBenchEvaluator_CHIP_CDEE,
    MedBenchEvaluator_CHIP_CDN,
    MedBenchEvaluator_CHIP_CTC,
    MedBenchEvaluator_NLG,
    MedBenchEvaluator_TF,
    MedBenchEvaluator_EMR,
)
from opencompass.utils.text_postprocessors import first_capital_postprocess

medbench_reader_cfg = dict(
    input_columns=['problem_input'], output_column='label')

medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench']  # multiple-choice questions, scored with accuracy

medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin']  # open-ended QA with reference answers

medbench_cloze_sets = ['Triage']  # closed-domain QA with reference answers

medbench_single_choice_sets = ['Medicine_attack']  # true/false judgement with reference answers

medbench_ie_sets = ['EMR', 'CMeEE']  # checks whether the extracted entities match, scored with F1

# , 'CMeIE', 'CHIP_CDEE', 'CHIP_CDN', 'CHIP_CTC', 'Doc_parsing', 'MRG'

medbench_datasets = []

for name in medbench_single_choice_sets:
    medbench_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[dict(role="HUMAN", prompt='{problem_input}')])),
        retriever=dict(type=ZeroRetriever),  # the retriever has no effect here; zero-shot / few-shot follows the input arguments
        inferencer=dict(type=GenInferencer))

    medbench_eval_cfg = dict(
        evaluator=dict(type=MedBenchEvaluator_TF), pred_role="BOT")

    medbench_datasets.append(
        dict(
            type=MedBenchDataset,
            path='./data/MedBench/' + name,
            name=name,
            abbr='medbench-' + name,
            setting_name='zero-shot',
            reader_cfg=medbench_reader_cfg,
            infer_cfg=medbench_infer_cfg.copy(),
            eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_multiple_choices_sets:
    medbench_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[dict(role="HUMAN", prompt='{problem_input}')])),
        retriever=dict(type=ZeroRetriever),  # the retriever has no effect here; zero-shot / few-shot follows the input arguments
        inferencer=dict(type=GenInferencer))

    medbench_eval_cfg = dict(
        evaluator=dict(type=MedBenchEvaluator), pred_role="BOT")

    medbench_datasets.append(
        dict(
            type=MedBenchDataset,
            path='./data/MedBench/' + name,
            name=name,
            abbr='medbench-' + name,
            setting_name='zero-shot',
            reader_cfg=medbench_reader_cfg,
            infer_cfg=medbench_infer_cfg.copy(),
            eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_qa_sets:
    medbench_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[dict(role="HUMAN", prompt='{problem_input}')])),
        retriever=dict(type=ZeroRetriever),  # the retriever has no effect here; zero-shot / few-shot follows the input arguments
        inferencer=dict(type=GenInferencer))

    medbench_eval_cfg = dict(
        evaluator=dict(type=MedBenchEvaluator_NLG), pred_role="BOT")

    medbench_datasets.append(
        dict(
            type=MedBenchDataset,
            path='./data/MedBench/' + name,
            name=name,
            abbr='medbench-' + name,
            setting_name='zero-shot',
            reader_cfg=medbench_reader_cfg,
            infer_cfg=medbench_infer_cfg.copy(),
            eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_cloze_sets:
    medbench_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[dict(role="HUMAN", prompt='{problem_input}')])),
        retriever=dict(type=ZeroRetriever),  # the retriever has no effect here; zero-shot / few-shot follows the input arguments
        inferencer=dict(type=GenInferencer))

    medbench_eval_cfg = dict(
        evaluator=dict(type=MedBenchEvaluator_Cloze), pred_role="BOT")

    medbench_datasets.append(
        dict(
            type=MedBenchDataset,
            path='./data/MedBench/' + name,
            name=name,
            abbr='medbench-' + name,
            setting_name='zero-shot',
            reader_cfg=medbench_reader_cfg,
            infer_cfg=medbench_infer_cfg.copy(),
            eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_ie_sets:
    medbench_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[dict(role="HUMAN", prompt='{problem_input}')])),
        retriever=dict(type=ZeroRetriever),  # the retriever has no effect here; zero-shot / few-shot follows the input arguments
        inferencer=dict(type=GenInferencer))

    medbench_eval_cfg = dict(
        evaluator=dict(type=eval('MedBenchEvaluator_' + name)), pred_role="BOT")  # picks MedBenchEvaluator_EMR / MedBenchEvaluator_CMeEE by name

    medbench_datasets.append(
        dict(
            type=MedBenchDataset,
            path='./data/MedBench/' + name,
            name=name,
            abbr='medbench-' + name,
            setting_name='zero-shot',
            reader_cfg=medbench_reader_cfg,
            infer_cfg=medbench_infer_cfg.copy(),
            eval_cfg=medbench_eval_cfg.copy()))

del name, medbench_infer_cfg, medbench_eval_cfg
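For context, a minimal top-level evaluation config that picks up these dataset definitions could look roughly like the sketch below; the file name, the model import path, and the use of a HuggingFace chat model are illustrative assumptions, not part of this commit.

# hypothetical configs/eval_medbench_demo.py -- illustrative only, not added by this commit
from mmengine.config import read_base

with read_base():
    # dataset definitions added by this commit
    from .datasets.MedBench.medbench_gen import medbench_datasets
    # any chat model config shipped with OpenCompass would do; this path is an assumption
    from .models.hf_internlm.hf_internlm_chat_7b import models

datasets = medbench_datasets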
@@ -56,6 +56,7 @@ from .longbench import *  # noqa: F401, F403
 from .math import *  # noqa: F401, F403
 from .mathbench import *  # noqa: F401, F403
 from .mbpp import *  # noqa: F401, F403
+from .medbench import *  # noqa: F401, F403
 from .mmlu import *  # noqa: F401, F403
 from .multirc import *  # noqa: F401, F403
 from .narrativeqa import *  # noqa: F401, F403
opencompass/datasets/medbench/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# flake8: noqa

from .medbench import *  # noqa: F401, F403
opencompass/datasets/medbench/constructions.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# flake8: noqa
import pandas as pd


class TaskSchema(object):

    def __init__(self,
                 passage=None,
                 question=None,
                 options=None,
                 label=None,
                 answer=None,
                 other=None):
        self.passage = passage
        self.question = question
        self.options = options
        self.label = label
        self.answer = answer
        self.other = other

    def to_dict(self):
        return {
            'passage': self.passage,
            'question': self.question,
            'options': self.options,
            'label': self.label,
            'answer': self.answer,
            'other': self.other
        }


# define README.json
class MedBenchInstance(object):

    def __init__(self, task_description, data_source, task_schema, output,
                 evaluation_metric, task_example):
        self.task_description = task_description
        self.data_source = data_source
        self.task_schema = task_schema
        self.output = output
        self.evaluation_metric = evaluation_metric
        self.task_example = task_example

    def to_dict(self):
        return {
            'task description': self.task_description,
            'data source': self.data_source,
            'task schema': self.task_schema.to_dict(),
            'output': self.output,
            'evaluation metric': self.evaluation_metric,
            'task example': self.task_example
        }


class ChatGPTSchema(object):

    def __init__(self, context=None, metadata=''):
        self.context = context
        self.metadata = metadata

    def to_dict(self):
        return {'context': self.context, 'metadata': self.metadata}


class ResultsForHumanSchema(object):

    def __init__(self,
                 index,
                 problem_input,
                 label,
                 model_input='',
                 model_output='',
                 parse_result='',
                 first_stage_output='',
                 second_stage_input='',
                 is_correct=False):
        self.index = index
        self.problem_input = problem_input
        self.model_input = model_input
        self.model_output = model_output
        self.parse_result = parse_result
        self.label = label
        self.first_stage_output = first_stage_output
        self.second_stage_input = second_stage_input
        self.is_correct = is_correct

    def to_dict(self):
        return {
            'index': self.index,
            'problem_input': self.problem_input,
            'model_input': self.model_input,
            'model_output': self.model_output,
            'parse_result': self.parse_result,
            'label': self.label,
            'is_correct': self.is_correct,
            'first_stage_output': self.first_stage_output,
            'second_stage_input': self.second_stage_input,
        }

    @staticmethod
    def to_tsv(result_list, path):
        result_json = [item.to_dict() for item in result_list]
        table = pd.json_normalize(result_json)
        table.to_excel(path, index=False)
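Despite its name, ResultsForHumanSchema.to_tsv writes an Excel workbook via pandas' to_excel. A minimal usage sketch (the output file name is illustrative, and openpyxl must be installed for to_excel to work):

from opencompass.datasets.medbench.constructions import ResultsForHumanSchema

rows = [ResultsForHumanSchema(index=0, problem_input='问题:……', label='A')]
# The method name says TSV, but the call below actually produces an .xlsx file.
ResultsForHumanSchema.to_tsv(rows, 'medbench_results.xlsx')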
opencompass/datasets/medbench/dataset_loader.py (new file, 338 lines)
@@ -0,0 +1,338 @@
# flake8: noqa
import ast
import json
import os

import pandas as pd
import tiktoken
from tqdm import tqdm

from .constructions import ChatGPTSchema, ResultsForHumanSchema
from .utils import extract_answer, read_jsonl, save_jsonl

# define the datasets
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench']  # multiple-choice questions, scored with accuracy

medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin']  # open-ended QA with reference answers

medbench_cloze_sets = ['Triage']  # closed-domain QA with reference answers

medbench_single_choice_sets = ['Medicine_attack']  # true/false judgement with reference answers

medbench_ie_sets = ['EMR', 'CMeEE']  # checks whether the extracted entities match, scored with F1


def convert_zero_shot(line, dataset_name):
    # passage = line['passage'] if line['passage'] is not None else ''
    if dataset_name in medbench_qa_sets:
        return line['question']
    elif dataset_name in medbench_cloze_sets:
        return '问题:' + line['question'] + '\n答案:'
    elif dataset_name in medbench_multiple_choices_sets:
        return '问题:' + line['question'] + ' ' \
            + '选项:' + ' '.join(line['options']) + '\n从A到G,我们应该选择'
    else:
        return line['question']


prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n'


# def convert_zero_shot_CoT_stage1(line, dataset_name):
#     try:
#         passage = line['passage'] if line['passage'] is not None else ''
#         if dataset_name in english_qa_datasets:
#             return passage + 'Q: ' + line['question'] + ' ' \
#                 + 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \
#                 "Let's think step by step."
#         elif dataset_name in chinese_qa_datasets:
#             option_string = 'ABCDEFG'
#             count = len(line['options'])
#             if count == 1:
#                 count = 4
#             return passage + '问题:' + line['question'] + ' ' \
#                 + '选项:' + ' '.join(line['options']) + '\n' + \
#                 '从A到{}, 我们应选择什么?让我们逐步思考:'.format(option_string[count - 1])
#         elif dataset_name in english_cloze_datasets:
#             return passage + 'Q: ' + line['question'] + '\n' \
#                 "A: Let's think step by step."
#         elif dataset_name in chinese_cloze_datasets:
#             return passage + '问题:' + line['question'] + '\n' \
#                 '答案:让我们逐步思考:'
#     except NameError:
#         print('Dataset not defined.')


# process few-shot raw_prompts
# NOTE: qa_datasets / cloze_datasets and the english_* / chinese_* groups referenced
# below come from the original AGIEval-style loader and are not defined in this module.
def combine_prompt(prompt_path,
                   dataset_name,
                   load_explanation=True,
                   chat_mode=False):
    skip_passage = False
    if dataset_name == 'sat-en-without-passage':
        skip_passage = True
        dataset_name = 'sat-en'
    demostrations = []
    # read the prompts by context and explanation
    context_row = [0, 1, 3, 5, 7, 9]
    explanation_row = [0, 2, 4, 6, 8, 10]
    raw_prompts_context = pd.read_csv(prompt_path,
                                      header=0,
                                      skiprows=lambda x: x not in context_row,
                                      keep_default_na=False)
    raw_prompts_explanation = pd.read_csv(
        prompt_path,
        header=0,
        skiprows=lambda x: x not in explanation_row,
        keep_default_na=False).replace(r'\n\n', '\n', regex=True)
    contexts = []
    for line in list(raw_prompts_context[dataset_name]):
        if line:
            # print(line)
            contexts.append(ast.literal_eval(line))
    explanations = [
        exp for exp in raw_prompts_explanation[dataset_name] if exp
    ]

    for idx, (con, exp) in enumerate(zip(contexts, explanations)):
        passage = con['passage'] if con['passage'] is not None and not skip_passage else ''
        question = con['question']
        options = con['options'] if con['options'] is not None else ''
        label = con['label'] if con['label'] is not None else ''
        answer = con['answer'] if 'answer' in con and con['answer'] is not None else ''

        if dataset_name in qa_datasets:
            question_input = '问题 {}. '.format(idx + 1) + passage + ' ' + question + '\n' \
                + '从以下选项中选择: ' + ' '.join(options) + '\n'
            question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
                + '答案是 {}'.format(label)

        elif dataset_name in cloze_datasets:
            question_input = '问题 {}. '.format(idx + 1) + question + '\n'
            question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
                + '答案是 {}'.format(answer)
        else:
            raise ValueError(
                f'During loading few-shot examples, found unknown dataset: {dataset_name}'
            )
        if chat_mode:
            demostrations.append((question_input, question_output))
        else:
            demostrations.append(question_input + question_output + '\n')

    return demostrations


enc = None


def _lazy_load_enc():
    global enc
    if enc is None:
        enc = tiktoken.encoding_for_model('gpt-4')


# cut prompt if reach max token length
def concat_prompt(demos,
                  dataset_name,
                  max_tokens,
                  end_of_example='\n',
                  verbose=False):
    _lazy_load_enc()
    demostration_en = 'Here are the answers for the problems in the exam.\n'
    demostration_zh = '以下是考试中各个问题的答案。\n'

    for i in range(len(demos)):
        # print(len(enc.encode(demostration_en)), len(enc.encode(demostration_zh)))
        if dataset_name in english_qa_datasets:
            demostration_en = demostration_en + demos[i] + end_of_example
        elif dataset_name in chinese_qa_datasets:
            demostration_zh = demostration_zh + demos[i] + end_of_example
        elif dataset_name in english_cloze_datasets:
            demostration_en = demostration_en + demos[i] + end_of_example
        elif dataset_name in chinese_cloze_datasets:
            demostration_zh = demostration_zh + demos[i] + end_of_example
        # break if reach max token limit
        if len(enc.encode(demostration_en)) < max_tokens and len(
                enc.encode(demostration_zh)) < max_tokens:
            output = demostration_en if len(demostration_en) > len(
                demostration_zh) else demostration_zh
            prompt_num = i + 1
        else:
            break
    if verbose:
        print('max_tokens set as ', max_tokens, 'actual_tokens is',
              len(enc.encode(output)), 'num_shot is', prompt_num)
    return output, prompt_num


def concat_prompt_chat_mode(demos,
                            dataset_name,
                            max_tokens,
                            end_of_example='\n',
                            verbose=False):
    _lazy_load_enc()
    answers = []
    sentences = ''
    for i in range(len(demos)):
        answers += [
            {
                'role': 'user',
                'content': demos[i][0]
            },
            {
                'role': 'assistant',
                'content': demos[i][1]
            },
        ]
        sentences += json.dumps(answers[-1])
        # break if reach max token limit
        if len(enc.encode(sentences)) > max_tokens:
            answers.pop()
            answers.pop()
            break
    if verbose:
        print('max_tokens set as ', max_tokens, 'actual_tokens is',
              len(enc.encode(sentences)), 'num_shot is',
              len(answers) // 2)
    return answers, len(answers) // 2


def convert_few_shot(line, dataset_name, demo, n_shot, chat_mode=False):
    passage = line['passage'] if line['passage'] is not None else ''
    question = line['question']
    options = line['options'] if line['options'] is not None else ''

    if dataset_name in qa_datasets:
        question_input = '问题 {}. '.format(n_shot + 1) + passage + ' ' + question + '\n' \
            + '从以下选项中选择: ' + ' '.join(options) + '\n'
        # + "问题 {}的解析: ".format(n_shot + 1)

    if dataset_name in cloze_datasets:
        question_input = '问题 {}. '.format(n_shot + 1) + question + '\n'
        # + "问题 {}的解析: ".format(n_shot + 1)
    if chat_mode:
        return demo + [
            {
                'role': 'user',
                'content': question_input
            },
        ]
    else:
        return demo + question_input


def load_dataset(dataset_name,
                 setting_name,
                 parent_path,
                 prompt_path=None,
                 max_tokens=None,
                 end_of_example='\n',
                 chat_mode=False,
                 verbose=False):
    test_path = os.path.join(parent_path, dataset_name + '.jsonl')
    loaded_jsonl = read_jsonl(test_path)
    processed = []
    if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
        # process demo once if it is few-shot-CoT
        processed_demos = combine_prompt(
            prompt_path,
            dataset_name,
            load_explanation=setting_name == 'few-shot-CoT',
            chat_mode=chat_mode)
        if chat_mode:
            chosen_prompt, n_shot = concat_prompt_chat_mode(processed_demos,
                                                            dataset_name,
                                                            max_tokens,
                                                            end_of_example,
                                                            verbose=verbose)
        else:
            chosen_prompt, n_shot = concat_prompt(processed_demos,
                                                  dataset_name,
                                                  max_tokens,
                                                  end_of_example,
                                                  verbose=verbose)

    if verbose:
        loaded_jsonl = tqdm(loaded_jsonl)
    for meta_idx, line in enumerate(loaded_jsonl):
        # correct
        if setting_name == 'zero-shot':
            ctxt = convert_zero_shot(line, dataset_name)
        elif setting_name == 'zero-shot-CoT':
            ctxt = convert_zero_shot_CoT_stage1(line, dataset_name)
        elif setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
            ctxt = convert_few_shot(line, dataset_name, chosen_prompt, n_shot,
                                    chat_mode)
        try:
            new_instance = ChatGPTSchema(context=ctxt, metadata=meta_idx)
            processed.append(new_instance.to_dict())
        except NameError:
            print('Dataset not defined.')
    return processed


def generate_second_stage_input(dataset_name,
                                input_list,
                                output_list,
                                with_format_prompt=False):
    try:
        chinese_format_prompt = '根据以上内容,你的任务是把最终的答案提取出来并填在【】中,例如【0】或者【A】。'
        if dataset_name in qa_datasets:
            prompt_suffix = '因此,从A到D, 我们应选择'
            if with_format_prompt:
                prompt_suffix = chinese_format_prompt + prompt_suffix
        elif dataset_name in cloze_datasets:
            prompt_suffix = '因此,答案是'
            if with_format_prompt:
                prompt_suffix = chinese_format_prompt + prompt_suffix
    except NameError:
        print('Dataset not defined.')
    processed = []
    for i in range(len(input_list)):
        ctxt = '{0}\n{1}\n{2}'.format(input_list[i]['context'],
                                      extract_answer(output_list[i]),
                                      prompt_suffix)
        new_instance = ChatGPTSchema(context=ctxt,
                                     metadata=input_list[i]['metadata'])
        processed.append(new_instance.to_dict())
    return processed


def load_dataset_as_result_schema(dataset_name, parent_path):
    test_path = os.path.join(parent_path, dataset_name + '.jsonl')
    loaded_jsonl = read_jsonl(test_path)

    processed = []
    for i, line in enumerate(loaded_jsonl):
        problem_input = convert_zero_shot(line, dataset_name)
        processed.append(
            ResultsForHumanSchema(
                index=i,
                problem_input=problem_input,
                # label=line['label'] if line['label'] else line['answer']
                label=line['answer'],
            ))
    return processed


if __name__ == '__main__':
    # set variables
    parent_dir = '../../data/exam_guidance'

    # set dataset name to process
    setting_name = 'zero-shot'  # setting_name can be chosen from ["zero-shot", "zero-shot-CoT", "few-shot-CoT"]
    data_name = 'health_exam'
    save_dir = '../../experiment_input/{}/'.format(setting_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    processed_data = load_dataset(data_name,
                                  setting_name,
                                  parent_dir,
                                  prompt_path=raw_prompt_path,  # NOTE: raw_prompt_path is not defined in this file
                                  max_tokens=2048)
    save_jsonl(processed_data,
               os.path.join(save_dir, '{}.jsonl'.format(data_name)))
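As a rough orientation, the zero-shot path that MedBenchDataset.load (added later in this commit) goes through could be sketched as follows; the 'Triage' name and the './data/MedBench/<name>' directory layout follow the configs above, and the corresponding '<name>.jsonl' file must already be present for this to run:

from opencompass.datasets.medbench.dataset_loader import (
    load_dataset, load_dataset_as_result_schema)

# inputs: list of {'context': ..., 'metadata': ...} dicts fed to the model
inputs = load_dataset('Triage', 'zero-shot', './data/MedBench/Triage')
# rows: ResultsForHumanSchema objects carrying the reference answers
rows = load_dataset_as_result_schema('Triage', './data/MedBench/Triage')
print(inputs[0]['context'], rows[0].label)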
opencompass/datasets/medbench/evaluation.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# flake8: noqa
from . import dataset_loader, utils
from .math_equivalence import is_equiv


def convert_to_set(item):
    if isinstance(item, list):
        return set(item)
    if isinstance(item, str):
        return {item}
    if item is None:
        return set()  # an empty set, not {} (which would be an empty dict)
    raise ValueError("Input can't parse:", item)


def evaluate_single_sample(dataset_name, prediction, label):
    if dataset_name in dataset_loader.multi_choice_datasets:
        p = convert_to_set(prediction)
        l = convert_to_set(label)
        return p == l
    elif dataset_name in dataset_loader.math_output_datasets:
        return is_equiv(prediction, label)
    else:
        return prediction == label


# def evaluate(dataset_name, prediction_list, label_list):
#     correct = 0
#     if dataset_name in multi_choice_datasets:
#         for prediction, label in zip(prediction_list, label_list):
#             p = convert_to_set(prediction)
#             l = convert_to_set(label)
#             if p == l:
#                 correct += 1
#     elif dataset_name in math_output_datasets:
#         for prediction, label in zip(prediction_list, label_list):
#             if is_equiv(prediction, label):
#                 correct += 1
#     else:
#         for prediction, label in zip(prediction_list, label_list):
#             if prediction == label:
#                 correct += 1
#     return "{0:.2%}".format(correct / len(label_list))
opencompass/datasets/medbench/math_equivalence.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# flake8: noqa


# code from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
def _fix_fracs(string):
    substrs = string.split('\\frac')
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += '\\frac'
            if substr[0] == '{':
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != '{':
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += '{' + a + '}{' + b + '}' + post_substr
                    else:
                        new_str += '{' + a + '}{' + b + '}'
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += '{' + a + '}' + b + post_substr
                    else:
                        new_str += '{' + a + '}' + b
    string = new_str
    return string


def _fix_a_slash_b(string):
    if len(string.split('/')) != 2:
        return string
    a = string.split('/')[0]
    b = string.split('/')[1]
    try:
        a = int(a)
        b = int(b)
        assert string == '{}/{}'.format(a, b)
        new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
        return new_string
    except:
        return string


def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if '\\text{ ' in string:
        splits = string.split('\\text{ ')
        assert len(splits) == 2
        return splits[0]
    else:
        return string


def _fix_sqrt(string):
    if '\\sqrt' not in string:
        return string
    splits = string.split('\\sqrt')
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != '{':
            a = split[0]
            new_substr = '\\sqrt{' + a + '}' + split[1:]
        else:
            new_substr = '\\sqrt' + split
        new_string += new_substr
    return new_string


def _strip_string(string):
    # linebreaks
    string = string.replace('\n', '')

    # remove inverse spaces
    string = string.replace('\\!', '')

    # replace \\ with \
    string = string.replace('\\\\', '\\')

    # replace tfrac and dfrac with frac
    string = string.replace('tfrac', 'frac')
    string = string.replace('dfrac', 'frac')

    # remove \left and \right
    string = string.replace('\\left', '')
    string = string.replace('\\right', '')

    # Remove circ (degrees)
    string = string.replace('^{\\circ}', '')
    string = string.replace('^\\circ', '')

    # remove dollar signs
    string = string.replace('\\$', '')

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace('\\%', '')
    string = string.replace('\%', '')

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(' .', ' 0.')
    string = string.replace('{.', '{0.')
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == '.':
        string = '0' + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split('=')) == 2:
        if len(string.split('=')[0]) <= 2:
            string = string.split('=')[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(' ', '')

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == '0.5':
        string = '\\frac{1}{2}'

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string


def is_equiv(str1, str2, verbose=False):
    if str1 is None and str2 is None:
        print('WARNING: Both None')
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except:
        return str1 == str2
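A quick sanity check of the equivalence helper; the strings are made-up examples rather than MedBench data:

from opencompass.datasets.medbench.math_equivalence import is_equiv

# '0.5' is normalised to '\frac{1}{2}', so the two are judged equivalent
print(is_equiv('0.5', '\\frac{1}{2}'))           # True
print(is_equiv('\\frac{1}{3}', '\\frac{1}{2}'))  # False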
opencompass/datasets/medbench/medbench.py (new file, 646 lines)
@@ -0,0 +1,646 @@
import json
import os.path as osp
import sys
from datasets import Dataset
from sklearn.metrics import classification_report
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET

from ..base import BaseDataset
from .math_equivalence import is_equiv
from .post_process import parse_math_answer, parse_qa_multiple_answer

import evaluate
from nltk.translate.bleu_score import sentence_bleu
# from bert_score import score
import re
from transformers import BasicTokenizer
from rouge_chinese import Rouge

basic_tokenizer = BasicTokenizer(tokenize_chinese_chars=True)


@LOAD_DATASET.register_module()
class MedBenchDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str, setting_name: str):
        from .dataset_loader import load_dataset, load_dataset_as_result_schema

        assert setting_name in 'zero-shot', 'only support zero-shot setting'
        dataset_wo_label = load_dataset(name, setting_name, path)
        dataset_with_label = load_dataset_as_result_schema(name, path)
        dataset = []
        for d1, d2 in zip(dataset_wo_label, dataset_with_label):
            dataset.append({
                'id': d2.index,
                'problem_input': d1['context'],
                'label': d2.label,
            })
        dataset = Dataset.from_list(dataset)
        return dataset


@LOAD_DATASET.register_module()
class MedBenchDataset_v2(BaseDataset):

    @staticmethod
    def load(path: str, name: str, setting_name: str):
        assert setting_name in 'zero-shot', 'only support zero-shot setting'
        filename = osp.join(path, name + '.jsonl')
        with open(filename, encoding='utf-8') as f:
            data = [json.loads(line.strip()) for line in f]
        dataset = []
        for item in data:
            passage = item['passage'] if item['passage'] else ''
            question = passage + item['question']
            options = '\n'.join(item['options']) if item['options'] else ''
            if item['label']:
                if isinstance(item['label'], list):
                    label = ''.join(item['label'])
                else:
                    label = item['label']
            else:
                label = item['answer']
            d = {'question': question, 'options': options, 'label': label}
            dataset.append(d)
        dataset = Dataset.from_list(dataset)
        return dataset


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator(BaseEvaluator):

    def score(self, predictions, references):
        # predictions: [[]]
        # references: [[]]
        predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
        details = []
        cnt = 0
        for pred, ref in zip(predictions, references):
            detail = {'pred': pred, 'answer': ref, 'correct': False}
            if is_equiv(pred, ref):
                cnt += 1
                detail['correct'] = True
            details.append(detail)
        score = cnt / len(predictions) * 100
        # returns a dict of the form {'score': ..., 'details': ...}
        return {'Accuracy': score, 'details': details}


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_mcq(BaseEvaluator):

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        details = []
        cnt = 0
        for pred, ref in zip(predictions, references):
            detail = {'pred': pred, 'answer': ref, 'correct': False}
            if pred == ref:
                cnt += 1
                detail['correct'] = True
            details.append(detail)

        score = cnt / len(predictions) * 100

        return {'score': score, 'details': details}


def process_generated_results_CMeEE(pred_file):
    structured_output = []
    answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位', '医疗操作', '医学检验项目', '症状', '疾病']
    for pred in pred_file:
        list_entities = []
        for choice in answer_choices:
            for piece in re.split('[,|.|。|;|\n]', pred):
                if piece.startswith(f"{choice}"):
                    mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").split(",")
                    for ment in mentions:
                        list_entities.append({'entity': ment, 'type': choice})
        structured_output.append(list_entities)
    return structured_output


def process_generated_results_EMR(pred_file):
    structured_output = []
    answer_choices = ['主诉', '现病史', '既往史', '个人史', '婚育史', '家族史']
    for pred in pred_file:
        list_entities = []
        for choice in answer_choices:
            for piece in re.split('[,|.|?|;|,|。|;|\n]', pred):
                if piece.startswith(f"{choice}"):
                    mentions = piece.replace(f"{choice}:", "").split(",")
                    mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
                    for ment in mentions:
                        list_entities.append({ment: choice})
        structured_output.append(list_entities)
    return structured_output


def process_generated_results_CMeIE(pred_file):
    structured_output = []
    for line in pred_file:
        gen_output = line

        # Expected answer format: one line per relation type, in the form
        # "具有{lab}关系的头尾实体对如下:头实体为str,尾实体为str;头实体为str,尾实体为str;"

        answer_choices = "相关(导致)、鉴别诊断、遗传因素、发病性别倾向、相关(症状)、手术治疗、预防、辅助检查、筛查、阶段、临床表现、风险评估因素、同义词、发病年龄、预后生存率、病史、传播途径、治疗后症状、药物治疗、辅助治疗、化疗、死亡率、放射治疗、病因、组织学检查、内窥镜检查、多发群体、并发症、实验室检查、就诊科室、病理生理、高危因素、发病率、多发地区、病理分型、影像学检查、转移部位、发病部位、相关(转化)、外侵部位、预后状况、发病机制、多发季节"
        answer_choices = answer_choices.split('、')
        list_spos = []
        assert isinstance(answer_choices, list)
        list_answer_strs = gen_output.split("\n")

        for line in list_answer_strs:
            # first parse out the relation label
            predicate = line.split("关系的头尾实体对")[0][2:].strip()
            line = line.replace(f"具有{predicate}关系的头尾实体对如下:", "")
            for spo_str in line.split("。"):
                if len(spo_str.split(",尾实体为")) < 2:
                    continue

                head_mention_str, tail_mention_str = spo_str.split(",尾实体为")[:2]
                head_mention_str = head_mention_str.replace("头实体为", "").strip()
                tail_mention_str = tail_mention_str.replace("尾实体为", "").strip()

                list_spos.append(
                    {
                        "predicate": predicate,
                        "subject": head_mention_str,
                        "object": tail_mention_str,
                    }
                )
        structured_output.append(list_spos)
    return structured_output


def process_generated_results_CDN(pred_file):
    structured_output = []
    answer_choices = json.load(open('./data/MedBench/CHIP_CDN/CHIP-CDN_entity.json', 'r'))
    for line in pred_file:
        gen_output = line

        # Expected answer format: the selected normalized entities, separated by ','

        answer_str = gen_output.split("\n")[-1]
        answers = answer_str.split(",")
        answers = [w.strip() for w in answers if len(w.strip()) > 0]
        answers = [w for w in answers if w in answer_choices]
        answers = list(set(answers))
        answers = [
            {
                "entity": w,
                "type": "normalization",
            }
            for w in answers
        ]

        structured_output.append(answers)
    return structured_output


def process_generated_results_CDEE(pred_file):

    structured_output = []
    for line in pred_file:
        gen_output = line
        # Expected answer format: the first line is a lead-in sentence;
        # each event takes one line, fields separated by ';', each field written as "field name:field value";
        # multiple values within one field are separated by ','
        keys = ["主体词", "发生状态", "描述词", "解剖部位"]

        list_answer_strs = gen_output.split("\n")
        list_events = []
        for ans_str in list_answer_strs:
            if '主体词' in ans_str:
                event_info = {}
                ans_attrs = ans_str.split(";")
                for a_attr in ans_attrs:
                    for key in keys:
                        if a_attr.startswith(f"{key}:"):
                            a_attr = a_attr.replace(f"{key}:", "").strip()
                            if key in ["描述词", "解剖部位"]:
                                a_attr_split = a_attr.split(",")
                                a_attr_split = [w.strip() for w in a_attr_split if len(w.strip()) > 0]
                                event_info[key] = a_attr_split
                            else:
                                event_info[key] = a_attr

                for key in keys:
                    if key not in event_info:
                        if key in ["描述词", "解剖部位"]:
                            event_info[key] = []
                        else:
                            event_info[key] = ""

                list_events.append(event_info)

        structured_output.append(list_events)
    return structured_output


def process_generated_results_CTC(pred_file, task_dataset=None):  # task_dataset is not used
    structured_output = []

    for line in pred_file:
        gen_output = line
        # Expected answer format: the classification label itself
        answer_str = gen_output.strip()
        structured_output.append(answer_str)
    return structured_output


def process_generated_results_doc_parsing(pred_file):
    output = []
    for line in pred_file:
        structured_output = {'体温': '', '脉搏': '', '心率': '', '收缩压': '', '舒张压': '', '呼吸': '', '上腹部深压痛': '', '腹部反跳痛': '', '上腹部肿块': ''}
        sentence_list = re.split(',|。|\n', line.strip())
        for sentence in sentence_list:
            if '体温' in sentence:
                temp_value = re.search('[0-9]+', sentence)
                if temp_value:
                    structured_output['体温'] = temp_value.group(0)
                else:
                    structured_output['体温'] = '未扪及'
            elif '脉搏' in sentence:
                temp_value = re.search('[0-9]+', sentence)
                if temp_value:
                    structured_output['脉搏'] = temp_value.group(0)
                else:
                    structured_output['脉搏'] = '未扪及'
            elif '心率' in sentence:
                temp_value = re.search('[0-9]+', sentence)
                if temp_value:
                    structured_output['心率'] = temp_value.group(0)
                else:
                    structured_output['心率'] = '未扪及'
            elif '收缩压' in sentence:
                temp_value = re.search('[0-9]+', sentence)
                if temp_value:
                    structured_output['收缩压'] = temp_value.group(0)
                else:
                    structured_output['收缩压'] = '未扪及'
            elif '舒张压' in sentence:
                temp_value = re.search('[0-9]+', sentence)
                if temp_value:
                    structured_output['舒张压'] = temp_value.group(0)
                else:
                    structured_output['舒张压'] = '未扪及'
            elif '呼吸' in sentence:
                temp_value = re.search('[0-9]+', sentence)
                if temp_value:
                    structured_output['呼吸'] = temp_value.group(0)
                else:
                    structured_output['呼吸'] = '未扪及'
            elif '上腹部深压痛' in sentence:
                if re.search('是|存在|有', sentence):
                    structured_output['是否上腹部深压痛'] = '是'
                else:
                    structured_output['是否上腹部深压痛'] = '否'
            elif '腹部反跳痛' in sentence:
                if re.search('是|存在|有', sentence):
                    structured_output['是否腹部反跳痛'] = '是'
                else:
                    structured_output['是否腹部反跳痛'] = '否'
            elif '上腹部肿块' in sentence:
                if re.search('是|存在|有', sentence):
                    structured_output['上腹部肿块'] = '扪及'
                else:
                    structured_output['上腹部肿块'] = '未扪及'
        output.append(structured_output)
    return output


def process_generated_results_mrg(pred_file):
    structured_output = []
    answer_choices = ['主诉', '现病史', '既往史', '辅助检查', '诊断']
    for pred in pred_file:
        list_entities = []
        for choice in answer_choices:
            for piece in re.split('[,|.|?|;|,|。|;|\n]', pred):
                if piece.startswith(f"{choice}实体"):
                    mentions = piece.replace(f"{choice}实体:", "").split(",")
                    mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
                    for ment in mentions:
                        list_entities.append({ment: choice})
        structured_output.append(list_entities)
    return structured_output


def calc_info_extract_task_scores(list_structured_golden,
                                  list_structured_predict):

    assert len(list_structured_golden) == len(list_structured_predict)

    tp = 0
    fp = 0
    fn = 0
    for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):

        answer_golden = samp_golden
        answer_predict = samp_predict

        assert isinstance(answer_golden, list)
        assert isinstance(answer_predict, list), "sample format is wrong!"

        set_golden = set()
        for inst in answer_golden:
            assert isinstance(inst, dict)
            keys = sorted(list(inst.keys()))
            inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys])
            # inst = list(inst.items())
            # inst.sort()
            # inst = tuple(inst)

            set_golden.add(inst)

        set_predict = set()
        for inst in answer_predict:
            assert isinstance(inst, dict)
            keys = sorted(list(inst.keys()))
            # inst = tuple([inst[w] for w in keys])
            inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys])

            # inst = list(inst.items())
            # inst.sort()
            # inst = tuple(inst)

            set_predict.add(inst)

        # print("set_predict: ", set_predict)
        # print("set_golden: ", set_golden)

        tp += len(set_golden.intersection(set_predict))
        fp += len(set_predict.difference(set_golden))
        fn += len(set_golden.difference(set_predict))

    if tp:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
    else:
        precision, recall, f1 = 0, 0, 0

    return precision, recall, f1


def calc_cls_task_scores(list_structured_golden,
                         list_structured_predict,
                         list_labels=None,
                         return_macro=False,
                         ):
    # types = list_labels
    # scores = {c: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for c in list_labels + ["ALL"]}

    predictions = []
    ground_truths = []

    # Count GT relations and Predicted relations
    assert len(list_structured_golden) == len(list_structured_predict)
    n_sents = len(list_structured_golden)

    # Count TP, FP and FN per type
    for pred_samp, gt_samp in zip(list_structured_predict, list_structured_golden):

        pred_label = pred_samp
        gt_label = gt_samp
        assert gt_label != ""
        if pred_label == "":
            pred_label = list_labels[0]

        predictions.append(pred_label)
        ground_truths.append(gt_label)

    # metric
    cls_report = classification_report(
        ground_truths, predictions,
        output_dict=True,
        zero_division=0,
    )

    if return_macro:
        return cls_report["macro avg"]["precision"], \
            cls_report["macro avg"]["recall"], \
            cls_report["macro avg"]["f1-score"]
    else:
        return cls_report["weighted avg"]["precision"], \
            cls_report["weighted avg"]["recall"], \
            cls_report["weighted avg"]["f1-score"]


def calc_nlg_task_scores(list_structured_golden, list_structured_predict):

    assert len(list_structured_golden) == len(list_structured_predict)

    scores = []
    predictions = []
    references = []
    details = []
    for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
        # print("samp_golden: ", samp_golden)
        # print("samp_predict: ", samp_predict)

        # assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!"
        answer_golden = samp_golden
        answer_predict = samp_predict

        print('#')
        print(answer_golden)
        print(answer_predict)
        if not (answer_predict and answer_golden):
            continue

        # basic tokenizer: split Chinese into single characters, keep English words intact
        answer_predict = basic_tokenizer.tokenize(answer_predict)
        answer_golden = basic_tokenizer.tokenize(answer_golden)
        answer_predict = " ".join(answer_predict).strip()
        answer_golden = " ".join(answer_golden).strip()
        if answer_golden.strip() == "":
            answer_golden = "无 。"
        if answer_predict.strip() == "":
            answer_predict = "无 。"
        # print("answer_predict: ", answer_predict)
        # print("answer_golden: ", answer_golden)

        predictions.append(answer_predict)
        references.append(answer_golden)

        details.append({'pred': answer_predict, 'answer': answer_golden, 'correct': False})

    rouge = Rouge()
    # bleu = evaluate.load('sacrebleu')
    scores = rouge.get_scores(predictions, references, avg=True)
    # scores_bleu = bleu.compute(predictions=predictions, references=references)

    rouge1 = scores["rouge-1"]["f"]
    rouge2 = scores["rouge-2"]["f"]
    rougeL = scores["rouge-l"]["f"]

    # bleu = sentence_bleu(references, predictions)

    # bert_score = []
    # for id in range(len(predictions)):
    #     P, R, F1 = score([predictions[i]], [references[i]], model_type='bert-base-chinese', lang="zh", verbose=True)
    #     bert_score.append(F1)
    # bert_score = float(sum(bert_score)) / float(len(bert_score))
    # return rougeL, bleu, bert_score
    return {'RougeL': rougeL, 'details': details}


def calc_scores_f1(dict_gt, dict_pred):
    details = []
    for gt, pred in zip(dict_gt, dict_pred):
        details.append({'pred': pred, 'answer': gt, 'correct': None})

    precision, recall, f1 = calc_info_extract_task_scores(dict_gt, dict_pred)
    return {'F1': f1, 'details': details}


def calc_scores_ctc(dict_gt, dict_pred):
    details = []
    for gt, pred in zip(dict_gt, dict_pred):
        details.append({'pred': pred, 'answer': gt, 'correct': None})

    gts = dict_gt
    preds = dict_pred

    precision, recall, f1 = calc_cls_task_scores(
        gts,
        preds,
        list_labels=['非上述类型', '疾病', '症状(患者感受)',
                     '体征(医生检测)', '怀孕相关', '肿瘤进展',
                     '疾病分期', '过敏耐受', '器官组织状态',
                     '预期寿命', '口腔相关', '药物',
                     '治疗或手术', '设备', '护理',
                     '诊断', '实验室检查', '风险评估',
                     '受体状态', '年龄', '特殊病人特征',
                     '读写能力', '性别', '教育情况',
                     '居住情况', '种族', '知情同意',
                     '参与其它试验', '研究者决定', '能力',
                     '伦理审查', '依存性', '成瘾行为',
                     '睡眠', '锻炼', '饮食', '酒精使用',
                     '性取向', '吸烟状况', '献血',
                     '病例来源', '残疾群体', '健康群体',
                     '数据可及性', "含有多个类别"],
        return_macro=True,
    )
    return {'Macro-F1': f1, 'details': details}


def calc_scores_nlg(dict_gt, dict_pred):

    # scores = {}
    scores = {'score': 0, 'details': []}
    success_flag = 1

    gts = dict_gt
    preds = dict_pred
    # if not len(gts) == len(preds):
    #     success_flag = 0
    # try:
    return calc_nlg_task_scores(gts, preds)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeEE(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CMeEE(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_EMR(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_EMR(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_MRG(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_mrg(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeIE(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CMeIE(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CDEE(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CDEE(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CDN(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CDN(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CTC(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_CTC(predictions)
        return calc_scores_ctc(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Doc_parsing(BaseEvaluator):

    def score(self, predictions, references):
        predictions = process_generated_results_doc_parsing(predictions)
        return calc_scores_f1(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_NLG(BaseEvaluator):

    def score(self, predictions, references):
        # predictions = process_generated_results_med(predictions)
        return calc_scores_nlg(predictions, references)


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Cloze(BaseEvaluator):

    def score(self, predictions, references):
        # predictions: [[]]
        # references: [[]]
        # predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
        details = []
        cnt = 0

        for pred, ref in zip(predictions, references):
            detail = {'pred': pred, 'answer': ref, 'correct': False}

            if sum([item in pred for item in ref]) == len(ref):
                cnt += 1
                detail['correct'] = True
            details.append(detail)
        score = cnt / len(predictions) * 100
        return {'Accuracy': score, 'details': details}


@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_TF(BaseEvaluator):

    def score(self, predictions, references):
        # predictions: [[]]
        # references: [[]]
        # predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
        details = []
        cnt = 0

        for pred, ref in zip(predictions, references):

            if '不' in pred or '否' in pred:
                cur_pred = '不可以'
            else:
                cur_pred = '可以'

            detail = {'pred': cur_pred, 'answer': ref, 'correct': False}

            if cur_pred == ref:
                cnt += 1
                detail['correct'] = True

            details.append(detail)

        score = cnt / len(predictions) * 100
        return {'Accuracy': score, 'details': details}
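The evaluators above are ordinary OpenCompass BaseEvaluator subclasses, so they can also be exercised directly. A toy check of the true/false evaluator; the prediction and reference strings are invented examples, and it assumes MedBenchEvaluator_TF can be instantiated without arguments:

evaluator = MedBenchEvaluator_TF()
result = evaluator.score(
    predictions=['不可以,因为存在药物相互作用', '可以'],
    references=['不可以', '可以'])
print(result['Accuracy'])  # 100.0 for this toy input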
opencompass/datasets/medbench/post_process.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# flake8: noqa
import json
import re

from . import dataset_loader


def extract_last_line(string):
    lines = string.split('\n')
    for item in lines[::-1]:
        if item.strip() != '':
            string = item
            break
    return string


def remove_few_shot_prefix(string: str):
    prefix_list = ['The answer is therefore', '答案是']
    for prefix in prefix_list:
        if string.startswith(prefix):
            string = string[len(prefix):].strip()
        elif prefix in string:
            index = string.rfind(prefix)
            if index >= 0:
                string = string[index + len(prefix):].strip()
    return string


def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)
    if language == 'en':
        pattern = 'answer is .*?([A-G])'
        match = re.search(pattern, string)
    elif language == 'zh':
        pattern = '答案是.*?([A-G])'
        match = re.search(pattern, string)
    else:
        raise ValueError('Unknown language {0}'.format(language))
    if match:
        return match.group(1)
    else:
        return None


def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)
    if dataset_name in dataset_loader.chinese_cloze_datasets:
        return string.startswith('答案是')
    elif dataset_name in dataset_loader.english_cloze_datasets:
        return string.startswith('The answer is therefore')
    elif dataset_name in dataset_loader.chinese_qa_datasets:
        pattern = '答案是.*?([A-G])'
        match = re.search(pattern, string)
        return match is not None
    elif dataset_name in dataset_loader.english_qa_datasets:
        pattern = 'answer is .*?([A-G])'
        match = re.search(pattern, string)
        return match is not None
    return False


def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    answer = try_parse_few_shot_qa_single_answer(string, setting_name,
                                                 language)
    if answer is None:
        return find_first_capital_letter(string)
    else:
        return answer


def find_first_capital_letter(answer):
    letter_set = {'A', 'B', 'C', 'D', 'E', 'F'}
    for c in answer:
        if c in letter_set:
            return c
    # print("Can't find capital letter in:", answer)
    return ''


def extract_answer_in_bracket(answer, prefix='【', suffix='】'):
    if prefix not in answer and suffix not in answer:
        # print("didn't find the special tokens in:", answer)
        return ''
    s = answer.index(prefix) + len(prefix)
    t = answer.index(suffix)
    ret = answer[s:t]
    return ret


def parse_math_answer(setting_name, raw_string):
    if setting_name == 'few-shot-CoT':
        raw_string = extract_last_line(raw_string)
    if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
        raw_string = remove_few_shot_prefix(raw_string)
        return raw_string

    def remove_boxed(s):
        left = '\\boxed{'
        try:
            assert s[:len(left)] == left
            assert s[-1] == '}'
            answer = s[len(left):-1]
            if '=' in answer:
                answer = answer.split('=')[-1].lstrip(' ')
            return answer
        except:
            return None

    def last_boxed_only_string(string):
        idx = string.rfind('\\boxed')
        if idx < 0:
            idx = string.rfind('\\fbox')
            if idx < 0:
                return None
        i = idx
        right_brace_idx = None
        num_left_braces_open = 0
        while i < len(string):
            if string[i] == '{':
                num_left_braces_open += 1
            if string[i] == '}':
                num_left_braces_open -= 1
                if num_left_braces_open == 0:
                    right_brace_idx = i
                    break
            i += 1

        if right_brace_idx == None:
            retval = None
        else:
            retval = string[idx:right_brace_idx + 1]

        return retval

    def get_answer_with_dollar_sign(s):
        first_pattern = '\$(.*)\$'
        last_match = None
        matches = re.findall(first_pattern, s)
        if matches:
            last_match = matches[-1]
            if '=' in last_match:
                last_match = last_match.split('=')[-1].lstrip(' ')
        return last_match

    def get_answer_without_dollar_sign(s):
        last_match = None
        if '=' in s:
            last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
            if '\\n' in last_match:
                last_match = last_match.split('\\n')[0]
        else:
            pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
            matches = re.findall(pattern, s)
            if matches:
                last_match = matches[-1]
        return last_match

    raw_string = remove_few_shot_prefix(raw_string)
    if '\\boxed' in raw_string:
        answer = remove_boxed(last_boxed_only_string(raw_string))
    else:
        answer = get_answer_with_dollar_sign(raw_string)
        if not answer:
            answer = get_answer_without_dollar_sign(raw_string)
    return answer


def parse_qa_multiple_answer(string):
    # if setting_name == 'few-shot-CoT':
    #     string = extract_last_line(string)
    pattern = '\(*([A-Z])\)*'
    match = re.findall(pattern, string)
    if match:
        return match
    return []


def post_process(dataset_name, setting_name, prediction):
    if dataset_name in dataset_loader.english_cloze_datasets or dataset_name in dataset_loader.chinese_cloze_datasets:
        return parse_math_answer(setting_name, prediction)

    if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
        return parse_qa_multiple_answer(prediction, setting_name)

    # all other datasets are QA problems with a single answer
    if 'zero-shot' in setting_name:
        answer = find_first_capital_letter(prediction)
        return answer

    # all other datasets are QA problems with a single answer and setting_name is few-shot
    language = 'en' if dataset_name in dataset_loader.english_qa_datasets else 'zh'
    if dataset_name in dataset_loader.english_qa_datasets or dataset_name in dataset_loader.chinese_qa_datasets:
        return parse_few_shot_qa_single_answer(prediction, setting_name,
                                               language)
    else:
        raise ValueError(f'Unsupported dataset name {dataset_name}')
opencompass/datasets/medbench/utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# flake8: noqa
import json


def read_jsonl(path):
    with open(path, encoding='utf8') as fh:
        results = []
        for line in fh:
            if line is None:
                continue
            try:
                results.append(json.loads(line) if line != 'null' else line)
            except Exception as e:
                print(e)
                print(path)
                print(line)
                raise e
    return results


def save_jsonl(lines, directory):
    with open(directory, 'w', encoding='utf8') as f:
        for line in lines:
            f.write(json.dumps(line, ensure_ascii=False) + '\n')


def extract_answer(js):
    try:
        if js is None or js == 'null':
            return ''
        answer = ''
        if isinstance(js, str):
            answer = js
        elif 'text' in js['choices'][0]:
            answer = js['choices'][0]['text']
        else:
            answer = js['choices'][0]['message']['content']
        # answer = js['']
        return answer
    except Exception as e:
        # print(e)
        # print(js)
        return ''
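A small round trip through the helpers above; the temporary path and record contents are illustrative:

from opencompass.datasets.medbench.utils import read_jsonl, save_jsonl

records = [{'question': '问题:患者主诉……', 'answer': 'A'}]
save_jsonl(records, '/tmp/medbench_demo.jsonl')  # one JSON object per line
assert read_jsonl('/tmp/medbench_demo.jsonl') == records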