[Feature] Add medbench (#678)

* update medbench

* medbench update

* format medbench

* format

---------

Co-authored-by: 施晓明 <PJLAB\shixiaoming@pjnl104220118l.pjlab.org>
Co-authored-by: Leymore <zfz-960727@163.com>
Xiaoming Shi 2023-12-09 16:05:46 +08:00 committed by GitHub
parent 7cb53a95fa
commit 1bf85949ef
13 changed files with 1705 additions and 2 deletions


@@ -5,7 +5,8 @@ exclude: |
opencompass/utils/internal/|
opencompass/openicl/icl_evaluator/hf_metrics/|
opencompass/datasets/lawbench/utils|
opencompass/datasets/lawbench/evaluation_functions/
opencompass/datasets/lawbench/evaluation_functions/|
opencompass/datasets/medbench
)
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8


@@ -5,7 +5,8 @@ exclude: |
opencompass/utils/internal/|
opencompass/openicl/icl_evaluator/hf_metrics/|
opencompass/datasets/lawbench/utils|
opencompass/datasets/lawbench/evaluation_functions/
opencompass/datasets/lawbench/evaluation_functions/|
opencompass/datasets/medbench/
)
repos:
- repo: https://github.com/PyCQA/flake8


@@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .medbench_gen_d44f24 import medbench_datasets # noqa: F401, F403


@@ -0,0 +1,160 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import (
MedBenchDataset,
MedBenchEvaluator,
MedBenchEvaluator_Cloze,
MedBenchEvaluator_IE,
MedBenchEvaluator_mcq,
MedBenchEvaluator_CMeEE,
MedBenchEvaluator_CMeIE,
MedBenchEvaluator_CHIP_CDEE,
MedBenchEvaluator_CHIP_CDN,
MedBenchEvaluator_CHIP_CTC,
MedBenchEvaluator_NLG,
MedBenchEvaluator_TF,
MedBenchEvaluator_EMR,
)
from opencompass.utils.text_postprocessors import first_capital_postprocess
medbench_reader_cfg = dict(
input_columns=['problem_input'], output_column='label')
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # multiple-choice questions, scored by accuracy
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # open-ended QA with reference answers
medbench_cloze_sets = ['Triage'] # constrained-domain QA with reference answers
medbench_single_choice_sets = ['Medicine_attack'] # yes/no correctness judgement with reference answers
medbench_ie_sets = ['EMR', 'CMeEE'] # checks whether the recognized entities match, scored by F1
#, 'CMeIE', 'CHIP_CDEE', 'CHIP_CDN', 'CHIP_CTC', 'Doc_parsing', 'MRG'
medbench_datasets = []
for name in medbench_single_choice_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # the retriever has no effect here; zero-shot / few-shot is controlled by the input parameters
inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator_TF), pred_role="BOT")
medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))
for name in medbench_multiple_choices_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # the retriever has no effect here; zero-shot / few-shot is controlled by the input parameters
inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator), pred_role="BOT")
medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))
for name in medbench_qa_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # the retriever has no effect here; zero-shot / few-shot is controlled by the input parameters
inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator_NLG), pred_role="BOT")
medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))
for name in medbench_cloze_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # the retriever has no effect here; zero-shot / few-shot is controlled by the input parameters
inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator_Cloze), pred_role="BOT")
medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))
for name in medbench_ie_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # the retriever has no effect here; zero-shot / few-shot is controlled by the input parameters
inferencer=dict(type=GenInferencer))
medbench_eval_cfg = dict(
evaluator=dict(type=eval('MedBenchEvaluator_'+name)), pred_role="BOT")
medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))
del name, medbench_infer_cfg, medbench_eval_cfg
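Not part of the diff: a minimal sketch of how these dataset configs are typically pulled into an OpenCompass entry config. The relative import paths and the model config below are illustrative and depend on the actual configs/ layout in the repository:

from mmengine.config import read_base

with read_base():
    # illustrative paths; adjust to wherever medbench_gen.py and a model config live
    from .datasets.medbench.medbench_gen import medbench_datasets
    from .models.hf_internlm.hf_internlm_7b import models

datasets = medbench_datasets

Such a config can then be passed to the usual run.py entry point.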


@@ -56,6 +56,7 @@ from .longbench import * # noqa: F401, F403
from .math import * # noqa: F401, F403
from .mathbench import * # noqa: F401, F403
from .mbpp import * # noqa: F401, F403
from .medbench import * # noqa: F401, F403
from .mmlu import * # noqa: F401, F403
from .multirc import * # noqa: F401, F403
from .narrativeqa import * # noqa: F401, F403


@@ -0,0 +1,3 @@
# flake8: noqa
from .medbench import * # noqa: F401, F403


@@ -0,0 +1,104 @@
# flake8: noqa
import pandas as pd
class TaskSchema(object):
def __init__(self,
passage=None,
question=None,
options=None,
label=None,
answer=None,
other=None):
self.passage = passage
self.question = question
self.options = options
self.label = label
self.answer = answer
self.other = other
def to_dict(self):
return {
'passage': self.passage,
'question': self.question,
'options': self.options,
'label': self.label,
'answer': self.answer,
'other': self.other
}
# define README.json
class MedBenchInstance(object):
def __init__(self, task_description, data_source, task_schema, output,
evaluation_metric, task_example):
self.task_description = task_description
self.data_source = data_source
self.task_schema = task_schema
self.output = output
self.evaluation_metric = evaluation_metric
self.task_example = task_example
def to_dict(self):
return {
'task description': self.task_description,
'data source': self.data_source,
'task schema': self.task_schema.to_dict(),
'output': self.output,
'evaluation metric': self.evaluation_metric,
'task example': self.task_example
}
class ChatGPTSchema(object):
def __init__(self, context=None, metadata=''):
self.context = context
self.metadata = metadata
def to_dict(self):
return {'context': self.context, 'metadata': self.metadata}
class ResultsForHumanSchema(object):
def __init__(self,
index,
problem_input,
label,
model_input='',
model_output='',
parse_result='',
first_stage_output='',
second_stage_input='',
is_correct=False):
self.index = index
self.problem_input = problem_input
self.model_input = model_input
self.model_output = model_output
self.parse_result = parse_result
self.label = label
self.first_stage_output = first_stage_output
self.second_stage_input = second_stage_input
self.is_correct = is_correct
def to_dict(self):
return {
'index': self.index,
'problem_input': self.problem_input,
'model_input': self.model_input,
'model_output': self.model_output,
'parse_result': self.parse_result,
'label': self.label,
'is_correct': self.is_correct,
'first_stage_output': self.first_stage_output,
'second_stage_input': self.second_stage_input,
}
@staticmethod
def to_tsv(result_list, path):
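# NOTE: despite the name, this writes an Excel workbook via to_excel, not a TSV file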
result_json = [item.to_dict() for item in result_list]
table = pd.json_normalize(result_json)
table.to_excel(path, index=False)


@@ -0,0 +1,338 @@
# flake8: noqa
import ast
import json
import os
import pandas as pd
import tiktoken
from tqdm import tqdm
from .constructions import ChatGPTSchema, ResultsForHumanSchema
from .utils import extract_answer, read_jsonl, save_jsonl
# define the datasets
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # multiple-choice questions, scored by accuracy
medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # open-ended QA with reference answers
medbench_cloze_sets = ['Triage'] # constrained-domain QA with reference answers
medbench_single_choice_sets = ['Medicine_attack'] # yes/no correctness judgement with reference answers
medbench_ie_sets = ['EMR', 'CMeEE'] # checks whether the recognized entities match, scored by F1
def convert_zero_shot(line, dataset_name):
# passage = line['passage'] if line['passage'] is not None else ''
if dataset_name in medbench_qa_sets:
return line['question']
elif dataset_name in medbench_cloze_sets:
return '问题:' + line['question'] + '\n答案:'
elif dataset_name in medbench_multiple_choices_sets:
return '问题:' + line['question'] + ' ' \
+ '选项:' + ' '.join(line['options']) + '\n从A到G我们应该选择'
else:
return line['question']
prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n'
# def convert_zero_shot_CoT_stage1(line, dataset_name):
# try:
# passage = line['passage'] if line['passage'] is not None else ''
# if dataset_name in english_qa_datasets:
# return passage + 'Q: ' + line['question'] + ' ' \
# + 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \
# "Let's think step by step."
# elif dataset_name in chinese_qa_datasets:
# option_string = 'ABCDEFG'
# count = len(line['options'])
# if count == 1:
# count = 4
# return passage + '问题:' + line['question'] + ' ' \
# + '选项:' + ' '.join(line['options']) + '\n' + \
# '从A到{}, 我们应选择什么?让我们逐步思考:'.format(option_string[count - 1])
# elif dataset_name in english_cloze_datasets:
# return passage + 'Q: ' + line['question'] + '\n' \
# "A: Let's think step by step."
# elif dataset_name in chinese_cloze_datasets:
# return passage + '问题:' + line['question'] + '\n' \
# '答案:让我们逐步思考:'
# except NameError:
# print('Dataset not defined.')
# process few-shot raw_prompts
def combine_prompt(prompt_path,
dataset_name,
load_explanation=True,
chat_mode=False):
skip_passage = False
if dataset_name == 'sat-en-without-passage':
skip_passage = True
dataset_name = 'sat-en'
demostrations = []
# read the prompts by context and explanation
context_row = [0, 1, 3, 5, 7, 9]
explanation_row = [0, 2, 4, 6, 8, 10]
raw_prompts_context = pd.read_csv(prompt_path,
header=0,
skiprows=lambda x: x not in context_row,
keep_default_na=False)
raw_prompts_explanation = pd.read_csv(
prompt_path,
header=0,
skiprows=lambda x: x not in explanation_row,
keep_default_na=False).replace(r'\n\n', '\n', regex=True)
contexts = []
for line in list(raw_prompts_context[dataset_name]):
if line:
# print(line)
contexts.append(ast.literal_eval(line))
explanations = [
exp for exp in raw_prompts_explanation[dataset_name] if exp
]
for idx, (con, exp) in enumerate(zip(contexts, explanations)):
passage = con['passage'] if con[
'passage'] is not None and not skip_passage else ''
question = con['question']
options = con['options'] if con['options'] is not None else ''
label = con['label'] if con['label'] is not None else ''
answer = con[
'answer'] if 'answer' in con and con['answer'] is not None else ''
if dataset_name in qa_datasets:
question_input = '问题 {}. '.format(idx + 1) + passage + ' ' + question + '\n' \
+ '从以下选项中选择: ' + ' '.join(options) + '\n'
question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ '答案是 {}'.format(label)
elif dataset_name in cloze_datasets:
question_input = '问题 {}. '.format(idx + 1) + question + '\n'
question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ '答案是 {}'.format(answer)
else:
raise ValueError(
f'During loading few-shot examples, found unknown dataset: {dataset_name}'
)
if chat_mode:
demostrations.append((question_input, question_output))
else:
demostrations.append(question_input + question_output + '\n')
return demostrations
enc = None
def _lazy_load_enc():
global enc
if enc is None:
enc = tiktoken.encoding_for_model('gpt-4')
# cut prompt if reach max token length
def concat_prompt(demos,
dataset_name,
max_tokens,
end_of_example='\n',
verbose=False):
_lazy_load_enc()
demostration_en = 'Here are the answers for the problems in the exam.\n'
demostration_zh = '以下是考试中各个问题的答案。\n'
for i in range(len(demos)):
# print(len(enc.encode(demostration_en)), len(enc.encode(demostration_zh)))
if dataset_name in english_qa_datasets:
demostration_en = demostration_en + demos[i] + end_of_example
elif dataset_name in chinese_qa_datasets:
demostration_zh = demostration_zh + demos[i] + end_of_example
elif dataset_name in english_cloze_datasets:
demostration_en = demostration_en + demos[i] + end_of_example
elif dataset_name in chinese_cloze_datasets:
demostration_zh = demostration_zh + demos[i] + end_of_example
# break if reach max token limit
if len(enc.encode(demostration_en)) < max_tokens and len(
enc.encode(demostration_zh)) < max_tokens:
output = demostration_en if len(demostration_en) > len(
demostration_zh) else demostration_zh
prompt_num = i + 1
else:
break
if verbose:
print('max_tokens set as ', max_tokens, 'actual_tokens is',
len(enc.encode(output)), 'num_shot is', prompt_num)
return output, prompt_num
def concat_prompt_chat_mode(demos,
dataset_name,
max_tokens,
end_of_example='\n',
verbose=False):
_lazy_load_enc()
answers = []
sentences = ''
for i in range(len(demos)):
answers += [
{
'role': 'user',
'content': demos[i][0]
},
{
'role': 'assistant',
'content': demos[i][1]
},
]
sentences += json.dumps(answers[-1])
# break if reach max token limit
if len(enc.encode(sentences)) > max_tokens:
answers.pop()
answers.pop()
break
if verbose:
print('max_tokens set as ', max_tokens, 'actual_tokens is',
len(enc.encode(sentences)), 'num_shot is',
len(answers) // 2)
return answers, len(answers) // 2
def convert_few_shot(line, dataset_name, demo, n_shot, chat_mode=False):
passage = line['passage'] if line['passage'] is not None else ''
question = line['question']
options = line['options'] if line['options'] is not None else ''
if dataset_name in qa_datasets:
question_input = '问题 {}. '.format(n_shot + 1) + passage + ' ' + question + '\n' \
+ '从以下选项中选择: ' + ' '.join(options) + '\n'
# + "问题 {}的解析: ".format(n_shot + 1)
if dataset_name in cloze_datasets:
question_input = '问题 {}. '.format(n_shot + 1) + question + '\n'
# + "问题 {}的解析: ".format(n_shot + 1)
if chat_mode:
return demo + [
{
'role': 'user',
'content': question_input
},
]
else:
return demo + question_input
def load_dataset(dataset_name,
setting_name,
parent_path,
prompt_path=None,
max_tokens=None,
end_of_example='\n',
chat_mode=False,
verbose=False):
test_path = os.path.join(parent_path, dataset_name + '.jsonl')
loaded_jsonl = read_jsonl(test_path)
processed = []
if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
# process demo once if it is few-shot-CoT
processed_demos = combine_prompt(
prompt_path,
dataset_name,
load_explanation=setting_name == 'few-shot-CoT',
chat_mode=chat_mode)
if chat_mode:
chosen_prompt, n_shot = concat_prompt_chat_mode(processed_demos,
dataset_name,
max_tokens,
end_of_example,
verbose=verbose)
else:
chosen_prompt, n_shot = concat_prompt(processed_demos,
dataset_name,
max_tokens,
end_of_example,
verbose=verbose)
if verbose:
loaded_jsonl = tqdm(loaded_jsonl)
for meta_idx, line in enumerate(loaded_jsonl):
# correct
if setting_name == 'zero-shot':
ctxt = convert_zero_shot(line, dataset_name)
elif setting_name == 'zero-shot-CoT':
ctxt = convert_zero_shot_CoT_stage1(line, dataset_name)
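# NOTE: convert_zero_shot_CoT_stage1 is commented out above, so this branch would
# fail at runtime; MedBenchDataset only uses the 'zero-shot' setting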
elif setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
ctxt = convert_few_shot(line, dataset_name, chosen_prompt, n_shot,
chat_mode)
try:
new_instance = ChatGPTSchema(context=ctxt, metadata=meta_idx)
processed.append(new_instance.to_dict())
except NameError:
print('Dataset not defined.')
return processed
def generate_second_stage_input(dataset_name,
input_list,
output_list,
with_format_prompt=False):
try:
chinese_format_prompt = '根据以上内容你的任务是把最终的答案提取出来并填在【】中例如【0】或者【A】。'
if dataset_name in qa_datasets:
prompt_suffix = '因此从A到D, 我们应选择'
if with_format_prompt:
prompt_suffix = chinese_format_prompt + prompt_suffix
elif dataset_name in cloze_datasets:
prompt_suffix = '因此,答案是'
if with_format_prompt:
prompt_suffix = chinese_format_prompt + prompt_suffix
except NameError:
print('Dataset not defined.')
processed = []
for i in range(len(input_list)):
ctxt = '{0}\n{1}\n{2}'.format(input_list[i]['context'],
extract_answer(output_list[i]),
prompt_suffix)
new_instance = ChatGPTSchema(context=ctxt,
metadata=input_list[i]['metadata'])
processed.append(new_instance.to_dict())
return processed
def load_dataset_as_result_schema(dataset_name, parent_path):
test_path = os.path.join(parent_path, dataset_name + '.jsonl')
loaded_jsonl = read_jsonl(test_path)
processed = []
for i, line in enumerate(loaded_jsonl):
problem_input = convert_zero_shot(line, dataset_name)
processed.append(
ResultsForHumanSchema(
index=i,
problem_input=problem_input,
# label=line['label'] if line['label'] else line['answer']
label = line['answer']
))
return processed
if __name__ == '__main__':
# set variables
parent_dir = '../../data/exam_guidance'
# set dataset name to process
setting_name = 'zero-shot' # setting_name can be chosen from ["zero-shot", "zero-shot-CoT", "few-shot-CoT"]
data_name = 'health_exam'
save_dir = '../../experiment_input/{}/'.format(setting_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
processed_data = load_dataset(data_name,
setting_name,
parent_dir,
prompt_path=None,  # no few-shot prompt file is needed in the zero-shot setting
max_tokens=2048)
save_jsonl(processed_data,
os.path.join(save_dir, '{}.jsonl'.format(data_name)))


@@ -0,0 +1,43 @@
# flake8: noqa
from . import dataset_loader, utils
from .math_equivalence import is_equiv
def convert_to_set(item):
if isinstance(item, list):
return set(item)
if isinstance(item, str):
return {item}
if item is None:
return set()
raise ValueError("Input can't parse:", item)
def evaluate_single_sample(dataset_name, prediction, label):
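# NOTE: multi_choice_datasets / math_output_datasets do not appear in the medbench
# dataset_loader in this diff, so calling this helper would raise AttributeError
# unless those lists are added there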
if dataset_name in dataset_loader.multi_choice_datasets:
p = convert_to_set(prediction)
l = convert_to_set(label)
return p == l
elif dataset_name in dataset_loader.math_output_datasets:
return is_equiv(prediction, label)
else:
return prediction == label
# def evaluate(dataset_name, prediction_list, label_list):
# correct = 0
# if dataset_name in multi_choice_datasets:
# for prediction, label in zip(prediction_list, label_list):
# p = convert_to_set(prediction)
# l = convert_to_set(label)
# if p == l:
# correct += 1
# elif dataset_name in math_output_datasets:
# for prediction, label in zip(prediction_list, label_list):
# if is_equiv(prediction, label):
# correct += 1
# else:
# for prediction, label in zip(prediction_list, label_list):
# if prediction == label:
# correct += 1
# return "{0:.2%}".format(correct / len(label_list))


@@ -0,0 +1,161 @@
# flake8: noqa
# code from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
def _fix_fracs(string):
substrs = string.split('\\frac')
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += '\\frac'
if substr[0] == '{':
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != '{':
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}{' + b + '}' + post_substr
else:
new_str += '{' + a + '}{' + b + '}'
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}' + b + post_substr
else:
new_str += '{' + a + '}' + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split('/')) != 2:
return string
a = string.split('/')[0]
b = string.split('/')[1]
try:
a = int(a)
b = int(b)
assert string == '{}/{}'.format(a, b)
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
return new_string
except:
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if '\\text{ ' in string:
splits = string.split('\\text{ ')
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if '\\sqrt' not in string:
return string
splits = string.split('\\sqrt')
new_string = splits[0]
for split in splits[1:]:
if split[0] != '{':
a = split[0]
new_substr = '\\sqrt{' + a + '}' + split[1:]
else:
new_substr = '\\sqrt' + split
new_string += new_substr
return new_string
def _strip_string(string):
# linebreaks
string = string.replace('\n', '')
# print(string)
# remove inverse spaces
string = string.replace('\\!', '')
# print(string)
# replace \\ with \
string = string.replace('\\\\', '\\')
# print(string)
# replace tfrac and dfrac with frac
string = string.replace('tfrac', 'frac')
string = string.replace('dfrac', 'frac')
# print(string)
# remove \left and \right
string = string.replace('\\left', '')
string = string.replace('\\right', '')
# print(string)
# Remove circ (degrees)
string = string.replace('^{\\circ}', '')
string = string.replace('^\\circ', '')
# remove dollar signs
string = string.replace('\\$', '')
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace('\\%', '')
string = string.replace('\%', '')
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(' .', ' 0.')
string = string.replace('{.', '{0.')
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == '.':
string = '0' + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split('=')) == 2:
if len(string.split('=')[0]) <= 2:
string = string.split('=')[1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(' ', '')
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == '0.5':
string = '\\frac{1}{2}'
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print('WARNING: Both None')
return True
if str1 is None or str2 is None:
return False
try:
ss1 = _strip_string(str1)
ss2 = _strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except:
return str1 == str2
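Not part of the original file: a quick illustration of the normalization above (the module path assumes the layout introduced in this commit):

from opencompass.datasets.medbench.math_equivalence import is_equiv

assert is_equiv('\\frac12', '0.5')   # both normalize to '\\frac{1}{2}'
assert is_equiv('50\\%', '50')       # percent signs are stripped before comparison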


@@ -0,0 +1,646 @@
import json
import os.path as osp
import sys
from datasets import Dataset
from sklearn.metrics import classification_report
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from ..base import BaseDataset
from .math_equivalence import is_equiv
from .post_process import parse_math_answer, parse_qa_multiple_answer
import evaluate
from nltk.translate.bleu_score import sentence_bleu
# from bert_score import score
import re
from transformers import BasicTokenizer
from rouge_chinese import Rouge
basic_tokenizer = BasicTokenizer(tokenize_chinese_chars=True)
@LOAD_DATASET.register_module()
class MedBenchDataset(BaseDataset):
@staticmethod
def load(path: str, name: str, setting_name: str):
from .dataset_loader import load_dataset, load_dataset_as_result_schema
assert setting_name == 'zero-shot', 'only support zero-shot setting'
dataset_wo_label = load_dataset(name, setting_name, path)
dataset_with_label = load_dataset_as_result_schema(name, path)
dataset = []
for d1, d2 in zip(dataset_wo_label, dataset_with_label):
dataset.append({
'id': d2.index,
'problem_input': d1['context'],
'label': d2.label,
})
dataset = Dataset.from_list(dataset)
return dataset
@LOAD_DATASET.register_module()
class MedBenchDataset_v2(BaseDataset):
@staticmethod
def load(path: str, name: str, setting_name: str):
assert setting_name == 'zero-shot', 'only support zero-shot setting'
filename = osp.join(path, name + '.jsonl')
with open(filename, encoding='utf-8') as f:
data = [json.loads(line.strip()) for line in f]
dataset = []
for item in data:
passage = item['passage'] if item['passage'] else ''
question = passage + item['question']
options = '\n'.join(item['options']) if item['options'] else ''
if item['label']:
if isinstance(item['label'], list):
label = ''.join(item['label'])
else:
label = item['label']
else:
label = item['answer']
d = {'question': question, 'options': options, 'label': label}
dataset.append(d)
dataset = Dataset.from_list(dataset)
return dataset
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator(BaseEvaluator):
def score(self, predictions, references):
# predictions: [[]]
# references: [[]]
predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
details = []
cnt = 0
for pred, ref in zip(predictions, references):
detail = {'pred': pred, 'answer': ref, 'correct': False}
if is_equiv(pred, ref):
cnt += 1
detail['correct'] = True
details.append(detail)
score = cnt / len(predictions) * 100
# returns a dict of the form {'score': ..., 'details': ...}
return {'Accuracy': score, 'details': details}
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_mcq(BaseEvaluator):
def score(self, predictions, references):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
details = []
cnt = 0
for pred, ref in zip(predictions, references):
detail = {'pred': pred, 'answer': ref, 'correct': False}
if pred == ref:
cnt += 1
detail['correct'] = True
details.append(detail)
score = cnt / len(predictions) * 100
return {'score': score, 'details': details}
def process_generated_results_CMeEE(pred_file):
structured_output = []
answer_choices = ['药物', '设备', '医院科室', '微生物类', '身体部位', '医疗操作', '医学检验项目', '症状', '疾病']
for pred in pred_file:
list_entities = []
for choice in answer_choices:
for piece in re.split('[,|.|。||\n]', pred):
if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}实体为", "").replace(f"{choice}实体是", "").replace(f"{choice}实体:", "").split("、")
for ment in mentions:
list_entities.append({'entity':ment, 'type':choice})
structured_output.append(list_entities)
return structured_output
def process_generated_results_EMR(pred_file):
structured_output = []
answer_choices = ['主诉', '现病史', '既往史', '个人史', '婚育史', '家族史']
for pred in pred_file:
list_entities = []
for choice in answer_choices:
for piece in re.split('[,|.|?|;||。||\n]', pred):
if piece.startswith(f"{choice}"):
mentions = piece.replace(f"{choice}", "").split("、")
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
for ment in mentions:
list_entities.append({ment: choice})
structured_output.append(list_entities)
return structured_output
def process_generated_results_CMeIE(pred_file):
structured_output = []
for line in pred_file:
gen_output = line
# Answer format:
# each relation type takes one line, in the form
# "具有{lab}关系的头尾实体对如下:头实体为str,尾实体为str、头实体为str,尾实体为str"
answer_choices = "相关(导致)、鉴别诊断、遗传因素、发病性别倾向、相关(症状)、手术治疗、预防、辅助检查、筛查、阶段、临床表现、风险评估因素、同义词、发病年龄、预后生存率、病史、传播途径、治疗后症状、药物治疗、辅助治疗、化疗、死亡率、放射治疗、病因、组织学检查、内窥镜检查、多发群体、并发症、实验室检查、就诊科室、病理生理、高危因素、发病率、多发地区、病理分型、影像学检查、转移部位、发病部位、相关(转化)、外侵部位、预后状况、发病机制、多发季节"
answer_choices = answer_choices.split('、')
list_spos = []
assert isinstance(answer_choices, list)
list_answer_strs = gen_output.split("\n")
for line in list_answer_strs:
# first, parse out the relation label:
predicate = line.split("关系的头尾实体对")[0][2: ].strip()
line = line.replace(f"具有{predicate}关系的头尾实体对如下:", "")
for spo_str in line.split("、"):
if len(spo_str.split(",尾实体为")) < 2:
continue
head_mention_str, tail_mention_str = spo_str.split(",尾实体为")[:2]
head_mention_str = head_mention_str.replace("头实体为", "").strip()
tail_mention_str = tail_mention_str.replace("尾实体为", "").strip()
list_spos.append(
{
"predicate": predicate,
"subject": head_mention_str,
"object": tail_mention_str,
}
)
structured_output.append(list_spos)
return structured_output
def process_generated_results_CDN(pred_file):
structured_output = []
answer_choices = json.load(open('./data/MedBench/CHIP_CDN/CHIP-CDN_entity.json', 'r'))
for line in pred_file:
gen_output = line
# Answer format:
# multiple selected normalized entities, separated by '、'
answer_str = gen_output.split("\n")[-1]
answers = answer_str.split("、")
answers = [w.strip() for w in answers if len(w.strip()) > 0]
answers = [w for w in answers if w in answer_choices]
answers = list(set(answers))
answers = [
{
"entity": w,
"type": "normalization",
}
for w in answers
]
structured_output.append(answers)
return structured_output
def process_generated_results_CDEE(pred_file):
structured_output = []
for line in pred_file:
gen_output = line
# Answer format:
# first line: a lead-in phrase
# each event takes one line; event fields are separated by '、', and each field has the form "字段名:字段值"
# if a field has multiple values, they are separated by ','
keys = ["主体词", "发生状态", "描述词", "解剖部位"]
list_answer_strs = gen_output.split("\n")
list_events = []
for ans_str in list_answer_strs:
if '主体词' in ans_str:
event_info = {}
ans_attrs = ans_str.split("、")
for a_attr in ans_attrs:
for key in keys:
if a_attr.startswith(f"{key}"):
a_attr = a_attr.replace(f"{key}", "").strip()
if key in ["描述词", "解剖部位"]:
a_attr_split = a_attr.split(",")
a_attr_split = [w.strip() for w in a_attr_split if len(w.strip()) > 0]
event_info[key] = a_attr_split
else:
event_info[key] = a_attr
for key in keys:
if key not in event_info:
if key in ["描述词", "解剖部位"]:
event_info[key] = []
else:
event_info[key] = ""
list_events.append(event_info)
structured_output.append(list_events)
return structured_output
def process_generated_results_CTC(pred_file):
structured_output = []
for line in pred_file:
gen_output = line
# Answer format: the model answers with the classification label directly
answer_str = gen_output.strip()
structured_output.append(answer_str)
return structured_output
def process_generated_results_doc_parsing(pred_file):
output = []
for line in pred_file:
structured_output = {'体温':'', '脉搏':'', '心率':'', '收缩压':'', '舒张压':'', '呼吸':'', '上腹部深压痛':'', '腹部反跳痛':'', '上腹部肿块':''}
sentence_list = line.strip().split('|。|\n')
for sentence in sentence_list:
if '体温' in sentence:
temp_value = re.search('[0-9]+', sentence)
if temp_value:
structured_output['体温'] = temp_value.group(0)
else:
structured_output['体温'] = '未扪及'
elif '脉搏' in sentence:
temp_value = re.search('[0-9]+', sentence)
if temp_value:
structured_output['脉搏'] = temp_value.group(0)
else:
structured_output['脉搏'] = '未扪及'
elif '心率' in sentence:
temp_value = re.search('[0-9]+', sentence)
if temp_value:
structured_output['心率'] = temp_value.group(0)
else:
structured_output['心率'] = '未扪及'
elif '收缩压' in sentence:
temp_value = re.search('[0-9]+', sentence)
if temp_value:
structured_output['收缩压'] = temp_value.group(0)
else:
structured_output['收缩压'] = '未扪及'
elif '舒张压' in sentence:
temp_value = re.search('[0-9]+', sentence)
if temp_value:
structured_output['舒张压'] = temp_value.group(0)
else:
structured_output['舒张压'] = '未扪及'
elif '呼吸' in sentence:
temp_value = re.search('[0-9]+', sentence)
if temp_value:
structured_output['呼吸'] = temp_value.group(0)
else:
structured_output['呼吸'] = '未扪及'
elif '上腹部深压痛' in sentence:
if re.search('是|存在|有', sentence):
structured_output['是否上腹部深压痛'] = '是'
else:
structured_output['是否上腹部深压痛'] = '否'
elif '腹部反跳痛' in sentence:
if re.search('是|存在|有', sentence):
structured_output['是否腹部反跳痛'] = '是'
else:
structured_output['是否腹部反跳痛'] = '否'
elif '上腹部肿块' in sentence:
if re.search('是|存在|有', sentence):
structured_output['上腹部肿块'] = '扪及'
else:
structured_output['上腹部肿块'] = '未扪及'
output.append(structured_output)
return output
def process_generated_results_mrg(pred_file):
structured_output = []
answer_choices = ['主诉', '现病史', '既往史', '辅助检查', '诊断']
for pred in pred_file:
list_entities = []
for choice in answer_choices:
for piece in re.split('[,|.|?|;||。||\n]', pred):
if piece.startswith(f"{choice}实体"):
mentions = piece.replace(f"{choice}实体:", "").split("、")
mentions = [w.strip() for w in mentions if len(w.strip()) > 0]
for ment in mentions:
list_entities.append({ment: choice})
structured_output.append(list_entities)
return structured_output
def calc_info_extract_task_scores(list_structured_golden,
list_structured_predict):
assert len(list_structured_golden) == len(list_structured_predict)
tp = 0
fp = 0
fn = 0
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
answer_golden = samp_golden
answer_predict = samp_predict
assert isinstance(answer_golden, list)
assert isinstance(answer_predict, list), "sample format is wrong!"
set_golden = set()
for inst in answer_golden:
assert isinstance(inst, dict)
keys = sorted(list(inst.keys()))
inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys ])
# inst = list(inst.items())
# inst.sort()
# inst = tuple(inst)
set_golden.add(inst)
set_predict = set()
for inst in answer_predict:
assert isinstance(inst, dict)
keys = sorted(list(inst.keys()))
# inst = tuple([inst[w] for w in keys])
inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys])
# inst = list(inst.items())
# inst.sort()
# inst = tuple(inst)
set_predict.add(inst)
# print("set_predict: ", set_predict)
# print("set_golden: ", set_golden)
tp += len(set_golden.intersection(set_predict))
fp += len(set_predict.difference(set_golden))
fn += len(set_golden.difference(set_predict))
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
return precision, recall, f1
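# Illustration (not part of this file): with the single gold entity matched and one
# spurious prediction, the micro scores come out as precision 0.5, recall 1.0, F1 ~0.67:
#   calc_info_extract_task_scores(
#       [[{'entity': '头痛', 'type': '症状'}]],
#       [[{'entity': '头痛', 'type': '症状'}, {'entity': '阿司匹林', 'type': '药物'}]])
#   -> (0.5, 1.0, 0.666...)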
def calc_cls_task_scores(list_structured_golden,
list_structured_predict,
list_labels=None,
return_macro=False,
):
# types = list_labels
# scores = {c: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for c in list_labels + ["ALL"]}
predictions = []
ground_truths = []
# Count GT relations and Predicted relations
assert len(list_structured_golden) == len(list_structured_predict)
n_sents = len(list_structured_golden)
# Count TP, FP and FN per type
for pred_samp, gt_samp in zip(list_structured_predict, list_structured_golden):
pred_label = pred_samp
gt_label = gt_samp
assert gt_label != ""
if pred_label == "":
pred_label = list_labels[0]
predictions.append(pred_label)
ground_truths.append(gt_label)
# metric
cls_report = classification_report(
ground_truths, predictions,
output_dict=True,
zero_division=0,
)
if return_macro:
return cls_report["macro avg"]["precision"], \
cls_report["macro avg"]["recall"], \
cls_report["macro avg"]["f1-score"]
else:
return cls_report["weighted avg"]["precision"], \
cls_report["weighted avg"]["recall"], \
cls_report["weighted avg"]["f1-score"]
def calc_nlg_task_scores(list_structured_golden, list_structured_predict):
assert len(list_structured_golden) == len(list_structured_predict)
scores = []
predictions = []
references = []
details = []
for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict):
# print("samp_golden: ", samp_golden)
# print("samp_predict: ", samp_predict)
# assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!"
answer_golden = samp_golden
answer_predict = samp_predict
print('#')
print(answer_golden)
print(answer_predict)
if not (answer_predict and answer_golden):
continue
# basic tokenizer: split Chinese text into individual characters, keep English words whole
answer_predict = basic_tokenizer.tokenize(answer_predict)
answer_golden = basic_tokenizer.tokenize(answer_golden)
answer_predict = " ".join(answer_predict).strip()
answer_golden = " ".join(answer_golden).strip()
if answer_golden.strip() == "":
answer_golden = "无 。"
if answer_predict.strip() == "":
answer_predict = "无 。"
# print("answer_predict: ", answer_predict)
# print("answer_golden: ", answer_golden)
predictions.append(answer_predict)
references.append(answer_golden)
details.append({'pred':answer_predict, 'answer':answer_golden, 'correct':False})
rouge = Rouge()
# bleu = evaluate.load('sacrebleu')
scores = rouge.get_scores(predictions, references, avg=True)
# scores_bleu = bleu.compute(predictions=predictions, references=references)
rouge1 = scores["rouge-1"]["f"]
rouge2 = scores["rouge-2"]["f"]
rougeL = scores["rouge-l"]["f"]
# bleu = sentence_bleu(references, predictions)
# bert_score = []
# for id in range(len(predictions)):
# P, R, F1 = score([predictions[i]], [references[i]], model_type='bert-base-chinese', lang="zh", verbose=True)
# bert_score.append(F1)
# bert_score = float(sum(bert_score)) / float(len(bert_score))
# return rougeL, bleu, bert_score
return {'RougeL': rougeL, 'details':details}
def calc_scores_f1(dict_gt, dict_pred):
details = []
for gt, pred in zip(dict_gt, dict_pred):
details.append({'pred':pred, 'answer':gt, 'correct':None})
precision, recall, f1 = calc_info_extract_task_scores(dict_gt, dict_pred)
return {'F1':f1, 'details':details}
def calc_scores_ctc(dict_gt, dict_pred):
details = []
for gt, pred in zip(dict_gt, dict_pred):
details.append({'pred':pred, 'answer':gt, 'correct':None})
gts = dict_gt
preds = dict_pred
precision, recall, f1 = calc_cls_task_scores(
gts,
preds,
list_labels=['非上述类型', '疾病', '症状(患者感受)',
'体征(医生检测)', '怀孕相关', '肿瘤进展',
'疾病分期', '过敏耐受', '器官组织状态',
'预期寿命', '口腔相关', '药物',
'治疗或手术', '设备', '护理',
'诊断', '实验室检查', '风险评估',
'受体状态', '年龄', '特殊病人特征',
'读写能力', '性别', '教育情况',
'居住情况', '种族', '知情同意',
'参与其它试验', '研究者决定', '能力',
'伦理审查', '依存性', '成瘾行为',
'睡眠', '锻炼', '饮食', '酒精使用',
'性取向', '吸烟状况', '献血',
'病例来源', '残疾群体', '健康群体',
'数据可及性', "含有多个类别"],
return_macro=True,
)
return {'Macro-F1':f1, 'details':details}
def calc_scores_nlg(dict_gt, dict_pred):
# scores = {}
scores = {'score':0, 'details':[]}
success_flag = 1
gts = dict_gt
preds = dict_pred
# if not len(gts) == len(preds):
# success_flag = 0
# try:
return calc_nlg_task_scores(gts, preds)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeEE(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_CMeEE(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_EMR(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_EMR(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_MRG(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_mrg(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CMeIE(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_CMeIE(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CDEE(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_CDEE(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CDN(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_CDN(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_CHIP_CTC(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_CTC(predictions)
return calc_scores_ctc(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Doc_parsing(BaseEvaluator):
def score(self, predictions, references):
predictions = process_generated_results_doc_parsing(predictions)
return calc_scores_f1(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_NLG(BaseEvaluator):
def score(self, predictions, references):
# predictions = process_generated_results_med(predictions)
return calc_scores_nlg(predictions, references)
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_Cloze(BaseEvaluator):
def score(self, predictions, references):
# predictions: [[]]
# references: [[]]
# predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
details = []
cnt = 0
for pred, ref in zip(predictions, references):
detail = {'pred':pred, 'answer':ref, 'correct':False}
if sum([item in pred for item in ref]) == len(ref):
cnt += 1
detail['correct'] = True
details.append(detail)
score = cnt / len(predictions) * 100
return {'Accuracy': score, 'details': details}
@ICL_EVALUATORS.register_module()
class MedBenchEvaluator_TF(BaseEvaluator):
def score(self, predictions, references):
# predictions: [[]]
# references: [[]]
# predictions = [parse_qa_multiple_answer(pred) for pred in predictions]
details = []
cnt = 0
for pred, ref in zip(predictions, references):
if '不' in pred or '否' in pred:
cur_pred = '不可以'
else:
cur_pred = '可以'
detail = {'pred':cur_pred, 'answer':ref, 'correct':False}
if cur_pred == ref:
cnt += 1
detail['correct'] = True
details.append(detail)
score = cnt / len(predictions) * 100
return {'Accuracy': score, 'details': details}
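Not part of the diff: a minimal sketch of the evaluator contract shared by the classes above; predictions and references are parallel lists, and the returned dict carries one metric plus per-sample details:

evaluator = MedBenchEvaluator_Cloze()
result = evaluator.score(predictions=['建议分诊至急诊科'], references=[['急诊科']])
# -> {'Accuracy': 100.0, 'details': [{'pred': '建议分诊至急诊科', 'answer': ['急诊科'], 'correct': True}]}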


@@ -0,0 +1,198 @@
# flake8: noqa
import json
import re
from . import dataset_loader
def extract_last_line(string):
lines = string.split('\n')
for item in lines[::-1]:
if item.strip() != '':
string = item
break
return string
def remove_few_shot_prefix(string: str):
prefix_list = ['The answer is therefore', '答案是']
for prefix in prefix_list:
if string.startswith(prefix):
string = string[len(prefix):].strip()
elif prefix in string:
index = string.rfind(prefix)
if index >= 0:
string = string[index + len(prefix):].strip()
return string
def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
if setting_name == 'few-shot-CoT':
string = extract_last_line(string)
if language == 'en':
pattern = 'answer is .*?([A-G])'
match = re.search(pattern, string)
elif language == 'zh':
pattern = '答案是.*?([A-G])'
match = re.search(pattern, string)
else:
raise ValueError('Unknown language {0}'.format(language))
if match:
return match.group(1)
else:
return None
def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
if setting_name == 'few-shot-CoT':
string = extract_last_line(string)
if dataset_name in dataset_loader.chinese_cloze_datasets:
return string.startswith('答案是')
elif dataset_name in dataset_loader.english_cloze_datasets:
return string.startswith('The answer is therefore')
elif dataset_name in dataset_loader.chinese_qa_datasets:
pattern = '答案是.*?([A-G])'
match = re.search(pattern, string)
return match is not None
elif dataset_name in dataset_loader.english_qa_datasets:
pattern = 'answer is .*?([A-G])'
match = re.search(pattern, string)
return match is not None
return False
def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
answer = try_parse_few_shot_qa_single_answer(string, setting_name,
language)
if answer is None:
return find_first_capital_letter(string)
else:
return answer
def find_first_capital_letter(answer):
letter_set = {'A', 'B', 'C', 'D', 'E', 'F'}
for c in answer:
if c in letter_set:
return c
# print("Can't find capital letter in:", answer)
return ''
def extract_answer_in_bracket(answer, prefix='', suffix=''):
if prefix not in answer and suffix not in answer:
# print("doesn't found special tokens in:", answer)
return ''
s = answer.index(prefix) + len(prefix)
t = answer.index(suffix)
ret = answer[s:t]
return ret
def parse_math_answer(setting_name, raw_string):
if setting_name == 'few-shot-CoT':
raw_string = extract_last_line(raw_string)
if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
raw_string = remove_few_shot_prefix(raw_string)
return raw_string
def remove_boxed(s):
left = '\\boxed{'
try:
assert s[:len(left)] == left
assert s[-1] == '}'
answer = s[len(left):-1]
if '=' in answer:
answer = answer.split('=')[-1].lstrip(' ')
return answer
except:
return None
def last_boxed_only_string(string):
idx = string.rfind('\\boxed')
if idx < 0:
idx = string.rfind('\\fbox')
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == '{':
num_left_braces_open += 1
if string[i] == '}':
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx == None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
def get_answer_with_dollar_sign(s):
first_pattern = '\$(.*)\$'
last_match = None
matches = re.findall(first_pattern, s)
if matches:
last_match = matches[-1]
if '=' in last_match:
last_match = last_match.split('=')[-1].lstrip(' ')
return last_match
def get_answer_without_dollar_sign(s):
last_match = None
if '=' in s:
last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
if '\\n' in last_match:
last_match = last_match.split('\\n')[0]
else:
pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
matches = re.findall(pattern, s)
if matches:
last_match = matches[-1]
return last_match
raw_string = remove_few_shot_prefix(raw_string)
if '\\boxed' in raw_string:
answer = remove_boxed(last_boxed_only_string(raw_string))
else:
answer = get_answer_with_dollar_sign(raw_string)
if not answer:
answer = get_answer_without_dollar_sign(raw_string)
return answer
def parse_qa_multiple_answer(string):
# if setting_name == 'few-shot-CoT':
# string = extract_last_line(string)
pattern = '\(*([A-Z])\)*'
match = re.findall(pattern, string)
if match:
return match
return []
def post_process(dataset_name, setting_name, prediction):
if dataset_name in dataset_loader.english_cloze_datasets or dataset_name in dataset_loader.chinese_cloze_datasets:
return parse_math_answer(setting_name, prediction)
if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
return parse_qa_multiple_answer(prediction)
# all other datasets are QA problems with single answer
if 'zero-shot' in setting_name:
answer = find_first_capital_letter(prediction)
return answer
# all other datasets are QA problems with single answer and setting_name are few-shot
language = 'en' if dataset_name in dataset_loader.english_qa_datasets else 'zh'
if dataset_name in dataset_loader.english_qa_datasets or dataset_name in dataset_loader.chinese_qa_datasets:
return parse_few_shot_qa_single_answer(prediction, setting_name,
language)
else:
raise ValueError(f'Unsupported dataset name {dataset_name}')
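Not part of the original file: parse_qa_multiple_answer simply collects every capital letter (parenthesized or not), so stray capitals in prose are picked up as well:

parse_qa_multiple_answer('答案是(A)和(C)')  # -> ['A', 'C']
parse_qa_multiple_answer('I would pick B')  # -> ['I', 'B']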


@@ -0,0 +1,43 @@
# flake8: noqa
import json
def read_jsonl(path):
with open(path, encoding='utf8') as fh:
results = []
for line in fh:
if line is None:
continue
try:
results.append(json.loads(line) if line != 'null' else line)
except Exception as e:
print(e)
print(path)
print(line)
raise e
return results
def save_jsonl(lines, directory):
with open(directory, 'w', encoding='utf8') as f:
for line in lines:
f.write(json.dumps(line, ensure_ascii=False) + '\n')
def extract_answer(js):
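# js may be a plain string or an OpenAI-style response dict
# ({'choices': [{'text': ...}]} or {'choices': [{'message': {'content': ...}}]})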
try:
if js is None or js == 'null':
return ''
answer = ''
if isinstance(js, str):
answer = js
elif 'text' in js['choices'][0]:
answer = js['choices'][0]['text']
else:
answer = js['choices'][0]['message']['content']
# answer = js['']
return answer
except Exception as e:
# print(e)
# print(js)
return ''