diff --git a/configs/datasets/mmlu/mmlu_gen_general_postprocess.py b/configs/datasets/mmlu/mmlu_gen_general_postprocess.py new file mode 100644 index 00000000..3f876c66 --- /dev/null +++ b/configs/datasets/mmlu/mmlu_gen_general_postprocess.py @@ -0,0 +1,124 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import FixKRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import MMLUDataset +from opencompass.utils.text_postprocessors import general_post_capital_postprocess + +# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader +# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar + +mmlu_reader_cfg = dict( + input_columns=["input", "A", "B", "C", "D"], + output_column="target", + train_split='dev') + +mmlu_all_sets = [ + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_physics", + "electrical_engineering", + "astronomy", + "anatomy", + "abstract_algebra", + "machine_learning", + "clinical_knowledge", + "global_facts", + "management", + "nutrition", + "marketing", + "professional_accounting", + "high_school_geography", + "international_law", + "moral_scenarios", + "computer_security", + "high_school_microeconomics", + "professional_law", + "medical_genetics", + "professional_psychology", + "jurisprudence", + "world_religions", + "philosophy", + "virology", + "high_school_chemistry", + "public_relations", + "high_school_macroeconomics", + "human_sexuality", + "elementary_mathematics", + "high_school_physics", + "high_school_computer_science", + "high_school_european_history", + "business_ethics", + "moral_disputes", + "high_school_statistics", + "miscellaneous", + "formal_logic", + "high_school_government_and_politics", + "prehistory", + "security_studies", + "high_school_biology", + "logical_fallacies", + "high_school_world_history", + "professional_medicine", + "high_school_mathematics", + "college_medicine", + "high_school_us_history", + "sociology", + "econometrics", + "high_school_psychology", + "human_aging", + "us_foreign_policy", + "conceptual_physics", +] + +mmlu_datasets = [] +for _name in mmlu_all_sets: + _hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.' + mmlu_infer_cfg = dict( + ice_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict( + role="HUMAN", + prompt= + f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: " + ), + dict(role="BOT", prompt="{target}\n") + ]), + ), + prompt_template=dict( + type=PromptTemplate, + template=dict( + begin="", + round=[ + dict( + role="HUMAN", + prompt= + f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: " + ), + ], + ), + ice_token="", + ), + retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]), + inferencer=dict(type=GenInferencer), + ) + + mmlu_eval_cfg = dict( + evaluator=dict(type=AccEvaluator), + pred_postprocessor=dict(type=general_post_capital_postprocess)) + + mmlu_datasets.append( + dict( + abbr=f"lukaemon_mmlu_{_name}", + type=MMLUDataset, + path="./data/mmlu/", + name=_name, + reader_cfg=mmlu_reader_cfg, + infer_cfg=mmlu_infer_cfg, + eval_cfg=mmlu_eval_cfg, + )) + +del _name, _hint diff --git a/opencompass/utils/text_postprocessors.py b/opencompass/utils/text_postprocessors.py index 7b478402..8039e935 100644 --- a/opencompass/utils/text_postprocessors.py +++ b/opencompass/utils/text_postprocessors.py @@ -56,6 +56,28 @@ def last_capital_postprocess(text: str) -> str: return t return '' +@TEXT_POSTPROCESSORS.register_module('general_post_capital') +def general_post_capital_postprocess(text: str) -> str: + patterns = [ + "the answer is ([A-E])", + "the answer is([A-E])", + "the correct answer is ([A-E])", + "the correct answer is([A-E])", + "Answer: ([A-E])", + "Answer: \(([A-E])\)", + "Option \(([A-E])\)", + "Answer:([A-E])", + "Option ([A-E])", + "Opt ([A-E])" + ] + for pattern in patterns: + match = re.search(pattern,text,re.IGNORECASE) + if match: + return match.group(1) + match = re.findall("[A-D]", text) + if match: + return match[0] + return "" def first_option_postprocess(text: str, options: str, cushion=True) -> str: """Find first valid option for text."""