[Feat] update postprocessor to get first option more accurately (#193)

* [Feat] update postprocessor to get first option

* minor fix

* minor fix
This commit is contained in:
Hubert 2023-08-11 17:33:00 +08:00 committed by GitHub
parent 14332e08fd
commit 8d9cee060f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 72 additions and 43 deletions

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import ARCDataset from opencompass.datasets import ARCDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
ARC_c_reader_cfg = dict( ARC_c_reader_cfg = dict(
input_columns=["question", "textA", "textB", "textC", "textD"], input_columns=["question", "textA", "textB", "textC", "textD"],
@ -28,7 +28,7 @@ ARC_c_infer_cfg = dict(
ARC_c_eval_cfg = dict( ARC_c_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
ARC_c_datasets = [ ARC_c_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import ARCDataset from opencompass.datasets import ARCDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
ARC_e_reader_cfg = dict( ARC_e_reader_cfg = dict(
input_columns=["question", "textA", "textB", "textC", "textD"], input_columns=["question", "textA", "textB", "textC", "textD"],
@ -28,7 +28,7 @@ ARC_e_infer_cfg = dict(
ARC_e_eval_cfg = dict( ARC_e_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
ARC_e_datasets = [ ARC_e_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AXDataset_V2 from opencompass.datasets import AXDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
AX_b_reader_cfg = dict( AX_b_reader_cfg = dict(
input_columns=["sentence1", "sentence2"], input_columns=["sentence1", "sentence2"],
@ -28,7 +28,7 @@ AX_b_infer_cfg = dict(
AX_b_eval_cfg = dict( AX_b_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
AX_b_datasets = [ AX_b_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AXDataset_V2 from opencompass.datasets import AXDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
AX_g_reader_cfg = dict( AX_g_reader_cfg = dict(
input_columns=["hypothesis", "premise"], input_columns=["hypothesis", "premise"],
@ -28,7 +28,7 @@ AX_g_infer_cfg = dict(
AX_g_eval_cfg = dict( AX_g_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
AX_g_datasets = [ AX_g_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import CBDataset_V2 from opencompass.datasets import CBDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
CB_reader_cfg = dict( CB_reader_cfg = dict(
input_columns=["premise", "hypothesis"], input_columns=["premise", "hypothesis"],
@ -29,7 +29,7 @@ CB_infer_cfg = dict(
CB_eval_cfg = dict( CB_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABC'),
) )
CB_datasets = [ CB_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import COPADataset_V2 from opencompass.datasets import COPADataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
COPA_reader_cfg = dict( COPA_reader_cfg = dict(
input_columns=["question", "premise", "choice1", "choice2"], input_columns=["question", "premise", "choice1", "choice2"],
@ -29,7 +29,7 @@ COPA_infer_cfg = dict(
COPA_eval_cfg = dict( COPA_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
COPA_datasets = [ COPA_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import MultiRCDataset_V2 from opencompass.datasets import MultiRCDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
MultiRC_reader_cfg = dict( MultiRC_reader_cfg = dict(
input_columns=["question", "text", "answer"], input_columns=["question", "text", "answer"],
@ -28,7 +28,7 @@ MultiRC_infer_cfg = dict(
MultiRC_eval_cfg = dict( MultiRC_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
MultiRC_datasets = [ MultiRC_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AXDataset_V2 from opencompass.datasets import AXDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
RTE_reader_cfg = dict( RTE_reader_cfg = dict(
input_columns=["hypothesis", "premise"], input_columns=["hypothesis", "premise"],
@ -28,7 +28,7 @@ RTE_infer_cfg = dict(
RTE_eval_cfg = dict( RTE_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
RTE_datasets = [ RTE_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AGIEvalDataset_v2, AGIEvalEvaluator from opencompass.datasets import AGIEvalDataset_v2, AGIEvalEvaluator
from opencompass.utils.text_postprocessors import first_capital_postprocess, first_capital_postprocess_multi from opencompass.utils.text_postprocessors import first_option_postprocess, first_capital_postprocess_multi
agieval_reader_cfg = dict( agieval_reader_cfg = dict(
input_columns=['question', 'options'], output_column='label') input_columns=['question', 'options'], output_column='label')
@ -76,14 +76,16 @@ for _name in agieval_single_choice_sets:
prompt_template=dict( prompt_template=dict(
type=PromptTemplate, type=PromptTemplate,
template=dict(round=[ template=dict(round=[
dict(role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}') dict(
role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}')
])), ])),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024)) inferencer=dict(type=GenInferencer, max_out_len=1024))
agieval_eval_cfg = dict( agieval_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_postprocessor=dict(type=first_capital_postprocess)) pred_postprocessor=dict(
type=first_option_postprocess, options='ABCDE'))
agieval_datasets.append( agieval_datasets.append(
dict( dict(
@ -105,7 +107,8 @@ for _name in agieval_multiple_choices_sets:
prompt_template=dict( prompt_template=dict(
type=PromptTemplate, type=PromptTemplate,
template=dict(round=[ template=dict(round=[
dict(role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}') dict(
role='HUMAN', prompt=f'{{question}}\n{{options}}\n{_hint}')
])), ])),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024)) inferencer=dict(type=GenInferencer, max_out_len=1024))

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import hellaswagDataset_V2 from opencompass.datasets import hellaswagDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
hellaswag_reader_cfg = dict( hellaswag_reader_cfg = dict(
input_columns=["ctx", "A", "B", "C", "D"], input_columns=["ctx", "A", "B", "C", "D"],
@ -16,11 +16,10 @@ hellaswag_infer_cfg = dict(
template=dict(round=[ template=dict(round=[
dict( dict(
role="HUMAN", role="HUMAN",
prompt=( prompt=("{ctx}\nQuestion: Which ending makes the most sense?\n"
"{ctx}\nQuestion: Which ending makes the most sense?\n" "A. {A}\nB. {B}\nC. {C}\nD. {D}\n"
"A. {A}\nB. {B}\nC. {C}\nD. {D}\n" "You may choose from 'A', 'B', 'C', 'D'.\n"
"You may choose from 'A', 'B', 'C', 'D'.\n" "Answer:"),
"Answer:"),
), ),
]), ]),
), ),
@ -31,7 +30,7 @@ hellaswag_infer_cfg = dict(
hellaswag_eval_cfg = dict( hellaswag_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
hellaswag_datasets = [ hellaswag_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import OBQADataset from opencompass.datasets import OBQADataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
_input_columns = [ _input_columns = [
["question_stem", "A", "B", "C", "D"], ["question_stem", "A", "B", "C", "D"],
@ -14,14 +14,16 @@ _template = [
round=[ round=[
dict( dict(
role="HUMAN", role="HUMAN",
prompt="Question: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:" prompt=
"Question: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:"
), ),
], ), ], ),
dict( dict(
round=[ round=[
dict( dict(
role="HUMAN", role="HUMAN",
prompt="Given the fact: {fact1}\nQuestion: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:", prompt=
"Given the fact: {fact1}\nQuestion: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:",
), ),
], ), ], ),
] ]
@ -46,16 +48,14 @@ for _i in range(2):
obqa_reader_cfg = dict( obqa_reader_cfg = dict(
input_columns=_input_columns[_i], output_column="answerKey") input_columns=_input_columns[_i], output_column="answerKey")
obqa_infer_cfg = dict( obqa_infer_cfg = dict(
prompt_template=dict( prompt_template=dict(type=PromptTemplate, template=_template[_i]),
type=PromptTemplate,
template=_template[_i]),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer), inferencer=dict(type=GenInferencer),
) )
obqa_eval_cfg = dict( obqa_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
) )
obqa_datasets[_i]["reader_cfg"] = obqa_reader_cfg obqa_datasets[_i]["reader_cfg"] = obqa_reader_cfg

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import piqaDataset_V2 from opencompass.datasets import piqaDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
piqa_reader_cfg = dict( piqa_reader_cfg = dict(
input_columns=["goal", "sol1", "sol2"], input_columns=["goal", "sol1", "sol2"],
@ -15,7 +15,9 @@ piqa_infer_cfg = dict(
type=PromptTemplate, type=PromptTemplate,
template=dict( template=dict(
round=[ round=[
dict(role="HUMAN", prompt="{goal}\nA. {sol1}\nB. {sol2}\nAnswer:") dict(
role="HUMAN",
prompt="{goal}\nA. {sol1}\nB. {sol2}\nAnswer:")
], ), ], ),
), ),
retriever=dict(type=ZeroRetriever), retriever=dict(type=ZeroRetriever),
@ -25,7 +27,7 @@ piqa_infer_cfg = dict(
piqa_eval_cfg = dict( piqa_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
piqa_datasets = [ piqa_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import RaceDataset from opencompass.datasets import RaceDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
race_reader_cfg = dict( race_reader_cfg = dict(
input_columns=['article', 'question', 'A', 'B', 'C', 'D'], input_columns=['article', 'question', 'A', 'B', 'C', 'D'],
@ -24,7 +24,7 @@ race_infer_cfg = dict(
race_eval_cfg = dict( race_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
pred_role='BOT') pred_role='BOT')
race_datasets = [ race_datasets = [

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import storyclozeDataset_V2 from opencompass.datasets import storyclozeDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
storycloze_reader_cfg = dict( storycloze_reader_cfg = dict(
input_columns=["context", "sentence_quiz1", "sentence_quiz2"], input_columns=["context", "sentence_quiz1", "sentence_quiz2"],
@ -28,7 +28,7 @@ storycloze_infer_cfg = dict(
storycloze_eval_cfg = dict( storycloze_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
# The original story cloze dataset and repo are not long maintaining. # The original story cloze dataset and repo are not long maintaining.

View File

@ -3,7 +3,7 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import winograndeDataset_V2 from opencompass.datasets import winograndeDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess from opencompass.utils.text_postprocessors import first_option_postprocess
winogrande_reader_cfg = dict( winogrande_reader_cfg = dict(
input_columns=["opt1", "opt2"], input_columns=["opt1", "opt2"],
@ -28,7 +28,7 @@ winogrande_infer_cfg = dict(
winogrande_eval_cfg = dict( winogrande_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
pred_role="BOT", pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess), pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
) )
winogrande_datasets = [ winogrande_datasets = [

View File

@ -48,6 +48,31 @@ def first_capital_postprocess(text: str) -> str:
return '' return ''
def first_option_postprocess(text: str, options) -> str:
"""Find first valid option for text."""
patterns = [
f'[Tt]he answer is [{options}]',
f'[Tt]he correct answer is [{options}]',
f'答案是(.*?)[{options}]',
f'答案为(.*?)[{options}]',
f'固选(.*?)[{options}]',
f'答案应该是(.*?)[{options}]',
f'(\s|^)[{options}][\s。,\.$]', # noqa
f'[{options}]',
]
regexes = [re.compile(pattern) for pattern in patterns]
for regex in regexes:
match = regex.search(text)
if match:
outputs = match.group(0)
for i in options:
if i in outputs:
return i
return ''
@TEXT_POSTPROCESSORS.register_module('first-capital-multi') @TEXT_POSTPROCESSORS.register_module('first-capital-multi')
def first_capital_postprocess_multi(text: str) -> str: def first_capital_postprocess_multi(text: str) -> str:
match = re.search(r'([A-D]+)', text) match = re.search(r'([A-D]+)', text)