# Canonical list of the 57 MMLU subject (task) names.
# Order is preserved exactly as originally listed; downstream configs iterate
# this sequence to build one dataset entry per subject.
mmlu_all_sets = [
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_physics',
    'electrical_engineering',
    'astronomy',
    'anatomy',
    'abstract_algebra',
    'machine_learning',
    'clinical_knowledge',
    'global_facts',
    'management',
    'nutrition',
    'marketing',
    'professional_accounting',
    'high_school_geography',
    'international_law',
    'moral_scenarios',
    'computer_security',
    'high_school_microeconomics',
    'professional_law',
    'medical_genetics',
    'professional_psychology',
    'jurisprudence',
    'world_religions',
    'philosophy',
    'virology',
    'high_school_chemistry',
    'public_relations',
    'high_school_macroeconomics',
    'human_sexuality',
    'elementary_mathematics',
    'high_school_physics',
    'high_school_computer_science',
    'high_school_european_history',
    'business_ethics',
    'moral_disputes',
    'high_school_statistics',
    'miscellaneous',
    'formal_logic',
    'high_school_government_and_politics',
    'prehistory',
    'security_studies',
    'high_school_biology',
    'logical_fallacies',
    'high_school_world_history',
    'professional_medicine',
    'high_school_mathematics',
    'college_medicine',
    'high_school_us_history',
    'sociology',
    'econometrics',
    'high_school_psychology',
    'human_aging',
    'us_foreign_policy',
    'conceptual_physics',
]
# OpenCompass generative-evaluation config for MMLU using the prompt format
# of OpenAI's simple-evals: zero-shot, think-step-by-step, with the final
# answer reported on the last line as "ANSWER: X".
from mmengine.config import read_base
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import MMLUDataset
from opencompass.utils.text_postprocessors import match_answer_pattern

with read_base():
    # Shared list of the 57 MMLU subject names.
    from .mmlu_all_sets import mmlu_all_sets

# None of the MMLU datasets on huggingface is parsed correctly, so we use our own dataset reader.
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar

# Prompt template: {input} is the question stem; {A}..{D} are the four
# choices. These placeholders are filled from the reader columns below.
QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{input}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

# Reader config: which dataset columns feed the prompt and which holds the
# reference answer ('target' — presumably the gold letter; confirm against
# MMLUDataset's parsing).
mmlu_reader_cfg = dict(
    input_columns=["input", "A", "B", "C", "D"],
    output_column="target",
    train_split='dev')

# One dataset entry per MMLU subject.
mmlu_datasets = []
for name in mmlu_all_sets:
    mmlu_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[
                    dict(role="HUMAN", prompt=QUERY_TEMPLATE),
                ],
            ),
        ),
        # ZeroRetriever: no in-context examples are retrieved (zero-shot).
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    mmlu_eval_cfg = dict(
        evaluator=dict(type=AccEvaluator),
        # Extracts the letter following "ANSWER:" (case-insensitive) from the
        # model's free-form response before accuracy scoring.
        pred_postprocessor=dict(type=match_answer_pattern, answer_pattern=r"(?i)ANSWER\s*:\s*([A-D])"))

    mmlu_datasets.append(
        dict(
            # NOTE(review): abbr kept as "lukaemon_mmlu_*" — presumably for
            # compatibility with existing result/summary naming.
            abbr=f"lukaemon_mmlu_{name}",
            type=MMLUDataset,
            path="./data/mmlu/",
            name=name,
            reader_cfg=mmlu_reader_cfg,
            infer_cfg=mmlu_infer_cfg,
            eval_cfg=mmlu_eval_cfg,
        ))
def match_answer_pattern(response_text: str, answer_pattern: str):
    """Extract the first capture group of *answer_pattern* from *response_text*.

    Args:
        response_text: The raw model response to search.
        answer_pattern: A regex with at least one capture group, e.g.
            r"(?i)ANSWER\s*:\s*([A-D])".

    Returns:
        The text captured by group 1 of the first match, or an empty string
        when the pattern does not match anywhere in the response.
    """
    found = re.search(answer_pattern, response_text)
    if found is None:
        return ''
    return found.group(1)