support supergpqa

This commit is contained in:
mkj3085003 2025-03-07 09:36:00 +00:00
parent 73c80953c6
commit 4e40563462
18 changed files with 2694 additions and 0 deletions

View File

@ -0,0 +1,14 @@
from mmengine import read_base
with read_base():
# from opencompass.configs.datasets.supergpqa.supergpqa_mixed_gen_d00bdd import \
# supergpqa_mixed_datasets as mixed_datasets
from opencompass.configs.datasets.supergpqa.supergpqa_single_0_shot_gen import \
supergpqa_0shot_single_datasets as zero_shot_datasets
# from opencompass.configs.datasets.supergpqa.supergpqa_single_3_shot_gen import \
# supergpqa_3shot_single_datasets as three_shot_datasets
from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b import \
models as hf_internlm2_5_7b
datasets = zero_shot_datasets
models = hf_internlm2_5_7b

View File

@ -0,0 +1,55 @@
from opencompass.datasets.supergpqa.supergpqa import SuperGPQADataset, SuperGPQAEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
supergpqa_0shot_single_datasets = []
prompt_template = dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='HUMAN',
prompt=''
)
],
round=[
dict(
role='HUMAN',
prompt='{infer_prompt}' # f-string
)
]
)
)
# Reader configuration
reader_cfg = dict(
input_columns=['infer_prompt'],
output_column='answer_letter',
)
# Inference configuration
infer_cfg = dict(
prompt_template=prompt_template,
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024),
)
# Evaluation configuration
eval_cfg = dict(
evaluator=dict(type=SuperGPQAEvaluator),
pred_role='BOT',
)
supergpqa_dataset = dict(
type=SuperGPQADataset,
abbr='supergpqa',
path='opencompass/supergpqa',
prompt_mode='zero-shot',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
# print(type(supergpqa_0shot_single_datasets))
supergpqa_0shot_single_datasets.append(supergpqa_dataset)

View File

@ -0,0 +1,347 @@
from opencompass.datasets.supergpqa.supergpqa import SuperGPQADataset, SuperGPQAEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
categories =[
"Power Systems and Automation",
"Anesthesiology",
"Oncology",
"Group Theory",
"Thermal Energy Engineering",
"Emergency Medicine",
"Systems Science",
"Geometry and Topology",
"Advanced Algebra",
"Electrical Theory and New Technologies",
"Engineering Thermophysics",
"Operating Systems",
"Guidance, Navigation and Control",
"Harmony",
"Marine Biology",
"Pediatrics",
"Road and Railway Engineering",
"Information Management and Communication",
"Water conservancy and Hydropower Engineering",
"Veterinary Medicine",
"Astronomical Observation and Technology",
"Special Number Theory",
"Philology and Bibliography",
"Textile Materials Science",
"Legal Theory and Legal History",
"Education Economics, Management and Social Security",
"Traditional Chinese Health Preservation",
"Epidemiology and Health Statistics",
"Pitch and Scales",
"Economic History",
"Marine Engineering",
"Labor Economics",
"Materials Processing Engineering",
"Demography and Anthropology",
"Preschool Education",
"Music History, Education, and Technology",
"Instrumentation and Performance",
"Cryptography",
"Mineralogy, Petrology, and Economic Geology",
"Microbiology and Biochemical Pharmacy",
"Poromechanics and Reservoir Physics",
"Imaging and Nuclear Medicine",
"Solid State Physics",
"Microelectronics and Solid-State Electronics",
"Zoology",
"Food Biochemistry",
"Traditional Chinese Pharmacy",
"Neurology",
"Hydrogeology",
"Criminal Law",
"Radiation Medicine",
"Relativity",
"Analytical Chemistry",
"Signal and Information Processing",
"Military Command and Information Systems",
"Literary Theory",
"Textile Chemistry and Dyeing Engineering",
"Urban Infrastructure Engineering",
"Stellar and Interstellar Evolution",
"Geological Resources and Geological Engineering",
"Pattern Recognition",
"Engineering Fluid Mechanics",
"Communication and Information Systems",
"Architectural History",
"Stochastic Processes",
"Microbiology",
"French Language and Literature",
"Principles of Computer Organization",
"Architectural Design and Theory",
"Animal Rearing and Breeding",
"Physical Oceanography",
"Acoustics",
"Organic Chemistry",
"Refrigeration and Cryogenic Engineering",
"Public Finance",
"Dermatology and Venereology",
"Religious Studies",
"Discrete Mathematics",
"Forest Cultivation and Genetic Breeding",
"Vehicle Operation Engineering",
"Physical Chemistry",
"Nutrition and Food Hygiene",
"Ship Mechanics and Design Principles",
"Power Electronics and Electrical Drives",
"Finance",
"Pharmacology",
"Environmental Engineering",
"Ecology",
"Aeronautical and Astronautical Science and Technology",
"Agricultural Mechanization Engineering",
"Computer Architecture",
"Political Economy",
"Principles of Seismic Exploration",
"Elements of Chemical Reaction Engineering",
"Digital Surveying and Remote Sensing Applications",
"History and Theory of Journalism and Media Management",
"Instrument Science and Technology",
"Structural Engineering",
"Computer Networks",
"Power Machinery and Engineering",
"Constitutional and Administrative Law",
"Law and Social Governance",
"Psychology",
"Urban Planning and Design",
"Thermodynamics and Statistical Physics",
"Chemical Transport Engineering",
"Environmental and Resource Protection",
"Fluid Machinery and Engineering",
"Cartography and Geographic Information Engineering",
"Computational Mathematics",
"Pathogen Biology",
"Human Geography",
"Theoretical Optics",
"Solid Mechanics",
"Electrochemistry",
"Aquaculture",
"Logic",
"Mechatronic Engineering",
"Modern and Contemporary Chinese Literature",
"Operations Research and Cybernetics",
"Circuits and Systems",
"Internal Combustion Engineering",
"Atomic and Molecular Physics",
"Marine Chemistry",
"Electromagnetic Field and Microwave Technology",
"Rigid Body Mechanics",
"Physiology",
"Military Chemistry and Pyrotechnics",
"Fundamentals of Dynamics and Control",
"Control Theory and Control Engineering",
"Historical Geography",
"Physical Geography",
"National and Defense Economics",
"Polymer Physics",
"Landscape Plants and Ornamental Horticulture",
"Solar System Science",
"Library and Archival Science",
"Internal Medicine",
"Physical Chemistry of Metallurgical Process",
"Antenna and Radio Communication",
"Genetics",
"Graph Theory",
"Principles of Metallurgy",
"Bridge and Tunnel Engineering",
"Combinatorial Mathematics",
"Otorhinolaryngology",
"Political Science",
"Medicinal Chemistry",
"Health Toxicology and Environmental Health",
"Archaeology and Museology",
"Geotechnical Engineering",
"Land Resource Management and Administrative Management",
"Thermodynamics",
"Atmospheric Physics and Atmospheric Environment",
"Broadcasting and Television Art",
"Numerical Analysis",
"Statistical Mechanics",
"Mineral Processing Engineering",
"Mathematical Analysis",
"Philosophy of Science and Technology",
"Western Economics",
"Data Structures",
"Fine Arts",
"Economic Statistics",
"Environmental Science",
"Military Thought and History",
"Drama and Opera Studies",
"Film Studies",
"High Voltage and Insulation Technology",
"Military Law",
"Wood Science and Technology",
"Obstetrics and Gynecology",
"Hydraulics and Hydrology",
"Cell Biology",
"Biochemistry and Molecular Biology",
"Fluid Flow and Heat Transfer in Chemical Engineering",
"Formal Languages",
"Optoelectronic Technology",
"Crop Science",
"Fundamental Mathematics",
"Immunology",
"Surgery",
"Ophthalmology",
"Social Medicine and Health Management",
"Industrial Economics",
"Traffic Information Engineering and Control",
"Traditional Chinese Medicine Theory",
"Polymer Chemistry and Physics",
"Maternal, Child and Adolescent Health",
"Radiation Protection and Nuclear Technology Applications",
"Food Processing and Storage Engineering",
"Fluid Physics",
"Materials Physics and Chemistry",
"Pharmaceutical Analysis",
"Semiconductor Physics",
"Optical Fiber Communication",
"Ethics",
"Psychiatry and Mental Health",
"Management Science and Engineering",
"Number Theory",
"Contract Law",
"Inorganic Chemistry",
"Design Arts",
"Human Anatomy and Histology-Embryology",
"Iron and Steel Metallurgy",
"Dance Studies",
"Structural Geology",
"Special Education",
"Musical Forms and Analysis",
"Philosophical Aesthetics",
"Astrophysics",
"Manufacturing Automation",
"Quantum Mechanics",
"Probability and Statistics",
"Military Logistics and Equipment",
"Heat Transfer",
"Classical Chinese Literature",
"Information Management Science",
"Cosmology",
"Educational Technology and Principles",
"Ordinary Differential Equations",
"Underwater Acoustics",
"Business and Accounting Management",
"Dynamic Meteorology",
"Military Management",
"Journalism and News Practice",
"Animal Nutrition and Feed Science",
"Applied Optics",
"Theoretical Fluid Mechanics",
"Communication Principles",
"Physical Education and Training",
"Geodesy and Surveying Engineering",
"Meteorology",
"Sports Science and Medicine",
"Solid Earth Geophysics",
"Particle and Nuclear Physics",
"International Law",
"Oil and Gas Field Development and Storage & Transportation Engineering",
"Basic Stomatology",
"Agricultural Environment and Soil-Water Engineering",
"Geochemistry",
"Procedural Law",
"Botany",
"Fuzzy Mathematics",
"Paleontology and Stratigraphy",
"Sports Humanities and Sociology",
"Civil and Commercial Law",
"Electrodynamics",
"Mining and Safety Engineering",
"Mass Transport and Separation Process in Chemical Engineering",
"Advanced Programming Languages",
"Laser Technology",
"Weapon Systems Science and Engineering",
"Quantitative Economics",
"Theoretical Mechanics",
"Nursing and Rehabilitation Medicine",
"Databases",
"Pharmaceutics",
"Space physics",
"Functions of Real Variables",
"Non-ferrous Metallurgy",
"Theory of Curriculum and Instruction",
"Clinical Laboratory Diagnostics",
"Clinical Stomatology",
"Literary History",
"Tourism Management and Technological Economics Management",
"Communication and Broadcasting",
"Pathology and Pathophysiology",
"Functions of Complex Variables",
"World History",
"Forest Engineering",
"Forensic Medicine",
"Linguistics and Applied Linguistics",
"Social and Folklore Studies",
"Computer Software and Theory",
"Subatomic and Atomic Physics",
"Biophysics",
"Radiochemistry",
"Russian Language and Literature",
"International Trade",
"Geriatric Medicine",
"Composition",
"Transportation Planning and Management",
"Polynomials and Series Expansions",
"Nuclear Energy and Reactor Technology"
]
supergpqa_0shot_single_datasets = []
for category in categories:
prompt_template = dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='HUMAN',
prompt=''
)
],
round=[
dict(
role='HUMAN',
prompt='{infer_prompt}' # f-string
)
]
)
)
# Reader configuration
reader_cfg = dict(
input_columns=['infer_prompt'],
output_column='answer_letter',
)
# Inference configuration
infer_cfg = dict(
prompt_template=prompt_template,
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024),
)
# Evaluation configuration
eval_cfg = dict(
evaluator=dict(type=SuperGPQAEvaluator),
pred_role='BOT',
)
supergpqa_dataset = dict(
type=SuperGPQADataset,
abbr=f'supergpqa_{category.replace(" ", "_")}',
# abbr='supergpqa',
path='opencompass/supergpqa',
prompt_mode='zero-shot',
category=category,
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
# print(type(supergpqa_0shot_single_datasets))
supergpqa_0shot_single_datasets.append(supergpqa_dataset)

View File

@ -0,0 +1,291 @@
categories =[
"Power Systems and Automation",
"Anesthesiology",
"Oncology",
"Group Theory",
"Thermal Energy Engineering",
"Emergency Medicine",
"Systems Science",
"Geometry and Topology",
"Advanced Algebra",
"Electrical Theory and New Technologies",
"Engineering Thermophysics",
"Operating Systems",
"Guidance, Navigation and Control",
"Harmony",
"Marine Biology",
"Pediatrics",
"Road and Railway Engineering",
"Information Management and Communication",
"Water conservancy and Hydropower Engineering",
"Veterinary Medicine",
"Astronomical Observation and Technology",
"Special Number Theory",
"Philology and Bibliography",
"Textile Materials Science",
"Legal Theory and Legal History",
"Education Economics, Management and Social Security",
"Traditional Chinese Health Preservation",
"Epidemiology and Health Statistics",
"Pitch and Scales",
"Economic History",
"Marine Engineering",
"Labor Economics",
"Materials Processing Engineering",
"Demography and Anthropology",
"Preschool Education",
"Music History, Education, and Technology",
"Instrumentation and Performance",
"Cryptography",
"Mineralogy, Petrology, and Economic Geology",
"Microbiology and Biochemical Pharmacy",
"Poromechanics and Reservoir Physics",
"Imaging and Nuclear Medicine",
"Solid State Physics",
"Microelectronics and Solid-State Electronics",
"Zoology",
"Food Biochemistry",
"Traditional Chinese Pharmacy",
"Neurology",
"Hydrogeology",
"Criminal Law",
"Radiation Medicine",
"Relativity",
"Analytical Chemistry",
"Signal and Information Processing",
"Military Command and Information Systems",
"Literary Theory",
"Textile Chemistry and Dyeing Engineering",
"Urban Infrastructure Engineering",
"Stellar and Interstellar Evolution",
"Geological Resources and Geological Engineering",
"Pattern Recognition",
"Engineering Fluid Mechanics",
"Communication and Information Systems",
"Architectural History",
"Stochastic Processes",
"Microbiology",
"French Language and Literature",
"Principles of Computer Organization",
"Architectural Design and Theory",
"Animal Rearing and Breeding",
"Physical Oceanography",
"Acoustics",
"Organic Chemistry",
"Refrigeration and Cryogenic Engineering",
"Public Finance",
"Dermatology and Venereology",
"Religious Studies",
"Discrete Mathematics",
"Forest Cultivation and Genetic Breeding",
"Vehicle Operation Engineering",
"Physical Chemistry",
"Nutrition and Food Hygiene",
"Ship Mechanics and Design Principles",
"Power Electronics and Electrical Drives",
"Finance",
"Pharmacology",
"Environmental Engineering",
"Ecology",
"Aeronautical and Astronautical Science and Technology",
"Agricultural Mechanization Engineering",
"Computer Architecture",
"Political Economy",
"Principles of Seismic Exploration",
"Elements of Chemical Reaction Engineering",
"Digital Surveying and Remote Sensing Applications",
"History and Theory of Journalism and Media Management",
"Instrument Science and Technology",
"Structural Engineering",
"Computer Networks",
"Power Machinery and Engineering",
"Constitutional and Administrative Law",
"Law and Social Governance",
"Psychology",
"Urban Planning and Design",
"Thermodynamics and Statistical Physics",
"Chemical Transport Engineering",
"Environmental and Resource Protection",
"Fluid Machinery and Engineering",
"Cartography and Geographic Information Engineering",
"Computational Mathematics",
"Pathogen Biology",
"Human Geography",
"Theoretical Optics",
"Solid Mechanics",
"Electrochemistry",
"Aquaculture",
"Logic",
"Mechatronic Engineering",
"Modern and Contemporary Chinese Literature",
"Operations Research and Cybernetics",
"Circuits and Systems",
"Internal Combustion Engineering",
"Atomic and Molecular Physics",
"Marine Chemistry",
"Electromagnetic Field and Microwave Technology",
"Rigid Body Mechanics",
"Physiology",
"Military Chemistry and Pyrotechnics",
"Fundamentals of Dynamics and Control",
"Control Theory and Control Engineering",
"Historical Geography",
"Physical Geography",
"National and Defense Economics",
"Polymer Physics",
"Landscape Plants and Ornamental Horticulture",
"Solar System Science",
"Library and Archival Science",
"Internal Medicine",
"Physical Chemistry of Metallurgical Process",
"Antenna and Radio Communication",
"Genetics",
"Graph Theory",
"Principles of Metallurgy",
"Bridge and Tunnel Engineering",
"Combinatorial Mathematics",
"Otorhinolaryngology",
"Political Science",
"Medicinal Chemistry",
"Health Toxicology and Environmental Health",
"Archaeology and Museology",
"Geotechnical Engineering",
"Land Resource Management and Administrative Management",
"Thermodynamics",
"Atmospheric Physics and Atmospheric Environment",
"Broadcasting and Television Art",
"Numerical Analysis",
"Statistical Mechanics",
"Mineral Processing Engineering",
"Mathematical Analysis",
"Philosophy of Science and Technology",
"Western Economics",
"Data Structures",
"Fine Arts",
"Economic Statistics",
"Environmental Science",
"Military Thought and History",
"Drama and Opera Studies",
"Film Studies",
"High Voltage and Insulation Technology",
"Military Law",
"Wood Science and Technology",
"Obstetrics and Gynecology",
"Hydraulics and Hydrology",
"Cell Biology",
"Biochemistry and Molecular Biology",
"Fluid Flow and Heat Transfer in Chemical Engineering",
"Formal Languages",
"Optoelectronic Technology",
"Crop Science",
"Fundamental Mathematics",
"Immunology",
"Surgery",
"Ophthalmology",
"Social Medicine and Health Management",
"Industrial Economics",
"Traffic Information Engineering and Control",
"Traditional Chinese Medicine Theory",
"Polymer Chemistry and Physics",
"Maternal, Child and Adolescent Health",
"Radiation Protection and Nuclear Technology Applications",
"Food Processing and Storage Engineering",
"Fluid Physics",
"Materials Physics and Chemistry",
"Pharmaceutical Analysis",
"Semiconductor Physics",
"Optical Fiber Communication",
"Ethics",
"Psychiatry and Mental Health",
"Management Science and Engineering",
"Number Theory",
"Contract Law",
"Inorganic Chemistry",
"Design Arts",
"Human Anatomy and Histology-Embryology",
"Iron and Steel Metallurgy",
"Dance Studies",
"Structural Geology",
"Special Education",
"Musical Forms and Analysis",
"Philosophical Aesthetics",
"Astrophysics",
"Manufacturing Automation",
"Quantum Mechanics",
"Probability and Statistics",
"Military Logistics and Equipment",
"Heat Transfer",
"Classical Chinese Literature",
"Information Management Science",
"Cosmology",
"Educational Technology and Principles",
"Ordinary Differential Equations",
"Underwater Acoustics",
"Business and Accounting Management",
"Dynamic Meteorology",
"Military Management",
"Journalism and News Practice",
"Animal Nutrition and Feed Science",
"Applied Optics",
"Theoretical Fluid Mechanics",
"Communication Principles",
"Physical Education and Training",
"Geodesy and Surveying Engineering",
"Meteorology",
"Sports Science and Medicine",
"Solid Earth Geophysics",
"Particle and Nuclear Physics",
"International Law",
"Oil and Gas Field Development and Storage & Transportation Engineering",
"Basic Stomatology",
"Agricultural Environment and Soil-Water Engineering",
"Geochemistry",
"Procedural Law",
"Botany",
"Fuzzy Mathematics",
"Paleontology and Stratigraphy",
"Sports Humanities and Sociology",
"Civil and Commercial Law",
"Electrodynamics",
"Mining and Safety Engineering",
"Mass Transport and Separation Process in Chemical Engineering",
"Advanced Programming Languages",
"Laser Technology",
"Weapon Systems Science and Engineering",
"Quantitative Economics",
"Theoretical Mechanics",
"Nursing and Rehabilitation Medicine",
"Databases",
"Pharmaceutics",
"Space physics",
"Functions of Real Variables",
"Non-ferrous Metallurgy",
"Theory of Curriculum and Instruction",
"Clinical Laboratory Diagnostics",
"Clinical Stomatology",
"Literary History",
"Tourism Management and Technological Economics Management",
"Communication and Broadcasting",
"Pathology and Pathophysiology",
"Functions of Complex Variables",
"World History",
"Forest Engineering",
"Forensic Medicine",
"Linguistics and Applied Linguistics",
"Social and Folklore Studies",
"Computer Software and Theory",
"Subatomic and Atomic Physics",
"Biophysics",
"Radiochemistry",
"Russian Language and Literature",
"International Trade",
"Geriatric Medicine",
"Composition",
"Transportation Planning and Management",
"Polynomials and Series Expansions",
"Nuclear Energy and Reactor Technology"
]
supergpqa_summary_groups = [
{'name': 'supergpqa', 'subsets': ['supergpqa_' + c.replace(' ', '_') for c in categories]},
]

View File

@ -0,0 +1,296 @@
from mmengine.config import read_base
with read_base():
from .groups.supergpqa import supergpqa_summary_groups
summarizer = dict(
dataset_abbrs=[
"supergpqa",
"supergpqa_Procedural_Law",
"supergpqa_Microbiology",
"supergpqa_World_History",
"supergpqa_Civil_and_Commercial_Law",
"supergpqa_Relativity",
"supergpqa_Discrete_Mathematics",
"supergpqa_Laser_Technology",
"supergpqa_Power_Machinery_and_Engineering",
"supergpqa_Geotechnical_Engineering",
"supergpqa_Mineralogy,_Petrology,_and_Economic_Geology",
"supergpqa_Fluid_Flow_and_Heat_Transfer_in_Chemical_Engineering",
"supergpqa_Composition",
"supergpqa_Biophysics",
"supergpqa_Geriatric_Medicine",
"supergpqa_Cell_Biology",
"supergpqa_Underwater_Acoustics",
"supergpqa_Political_Science",
"supergpqa_Atomic_and_Molecular_Physics",
"supergpqa_Industrial_Economics",
"supergpqa_Marine_Chemistry",
"supergpqa_Ophthalmology",
"supergpqa_Geochemistry",
"supergpqa_Anesthesiology",
"supergpqa_Sports_Science_and_Medicine",
"supergpqa_Forest_Cultivation_and_Genetic_Breeding",
"supergpqa_Philology_and_Bibliography",
"supergpqa_Cryptography",
"supergpqa_Road_and_Railway_Engineering",
"supergpqa_Literary_History",
"supergpqa_Mining_and_Safety_Engineering",
"supergpqa_Group_Theory",
"supergpqa_Crop_Science",
"supergpqa_Food_Biochemistry",
"supergpqa_Textile_Materials_Science",
"supergpqa_Fundamental_Mathematics",
"supergpqa_Microelectronics_and_Solid-State_Electronics",
"supergpqa_International_Law",
"supergpqa_Agricultural_Environment_and_Soil-Water_Engineering",
"supergpqa_Environmental_Science",
"supergpqa_Urban_Infrastructure_Engineering",
"supergpqa_Solid_State_Physics",
"supergpqa_Mechatronic_Engineering",
"supergpqa_Economic_History",
"supergpqa_Power_Electronics_and_Electrical_Drives",
"supergpqa_History_and_Theory_of_Journalism_and_Media_Management",
"supergpqa_Neurology",
"supergpqa_Computer_Networks",
"supergpqa_Animal_Nutrition_and_Feed_Science",
"supergpqa_Marine_Engineering",
"supergpqa_Materials_Physics_and_Chemistry",
"supergpqa_Business_and_Accounting_Management",
"supergpqa_Basic_Stomatology",
"supergpqa_Space_physics",
"supergpqa_Transportation_Planning_and_Management",
"supergpqa_Information_Management_and_Communication",
"supergpqa_Quantitative_Economics",
"supergpqa_Elements_of_Chemical_Reaction_Engineering",
"supergpqa_Library_and_Archival_Science",
"supergpqa_Electrodynamics",
"supergpqa_Fluid_Machinery_and_Engineering",
"supergpqa_Dynamic_Meteorology",
"supergpqa_Functions_of_Real_Variables",
"supergpqa_Pharmacology",
"supergpqa_Communication_Principles",
"supergpqa_Communication_and_Broadcasting",
"supergpqa_Musical_Forms_and_Analysis",
"supergpqa_Cartography_and_Geographic_Information_Engineering",
"supergpqa_Maternal,_Child_and_Adolescent_Health",
"supergpqa_Clinical_Stomatology",
"supergpqa_Data_Structures",
"supergpqa_Optoelectronic_Technology",
"supergpqa_Physiology",
"supergpqa_Thermodynamics_and_Statistical_Physics",
"supergpqa_Pediatrics",
"supergpqa_Geodesy_and_Surveying_Engineering",
"supergpqa_Theoretical_Mechanics",
"supergpqa_Hydraulics_and_Hydrology",
"supergpqa_International_Trade",
"supergpqa_Military_Chemistry_and_Pyrotechnics",
"supergpqa_Finance",
"supergpqa_Psychiatry_and_Mental_Health",
"supergpqa_Fundamentals_of_Dynamics_and_Control",
"supergpqa_Sports_Humanities_and_Sociology",
"supergpqa_Harmony",
"supergpqa_Control_Theory_and_Control_Engineering",
"supergpqa_Surgery",
"supergpqa_Analytical_Chemistry",
"supergpqa_Political_Economy",
"supergpqa_Theory_of_Curriculum_and_Instruction",
"supergpqa_High_Voltage_and_Insulation_Technology",
"supergpqa_Numerical_Analysis",
"supergpqa_Physical_Geography",
"supergpqa_Physical_Education_and_Training",
"supergpqa_Applied_Optics",
"supergpqa_Mathematical_Analysis",
"supergpqa_Advanced_Programming_Languages",
"supergpqa_Western_Economics",
"supergpqa_Organic_Chemistry",
"supergpqa_French_Language_and_Literature",
"supergpqa_Urban_Planning_and_Design",
"supergpqa_Polynomials_and_Series_Expansions",
"supergpqa_Functions_of_Complex_Variables",
"supergpqa_Advanced_Algebra",
"supergpqa_Operating_Systems",
"supergpqa_Internal_Combustion_Engineering",
"supergpqa_Food_Processing_and_Storage_Engineering",
"supergpqa_Educational_Technology_and_Principles",
"supergpqa_Acoustics",
"supergpqa_Quantum_Mechanics",
"supergpqa_Iron_and_Steel_Metallurgy",
"supergpqa_Land_Resource_Management_and_Administrative_Management",
"supergpqa_Fuzzy_Mathematics",
"supergpqa_Special_Education",
"supergpqa_Solid_Mechanics",
"supergpqa_Zoology",
"supergpqa_Demography_and_Anthropology",
"supergpqa_Tourism_Management_and_Technological_Economics_Management",
"supergpqa_Theoretical_Optics",
"supergpqa_Genetics",
"supergpqa_Constitutional_and_Administrative_Law",
"supergpqa_Structural_Engineering",
"supergpqa_Principles_of_Metallurgy",
"supergpqa_Medicinal_Chemistry",
"supergpqa_Electromagnetic_Field_and_Microwave_Technology",
"supergpqa_Clinical_Laboratory_Diagnostics",
"supergpqa_Theoretical_Fluid_Mechanics",
"supergpqa_Pitch_and_Scales",
"supergpqa_Stochastic_Processes",
"supergpqa_Ethics",
"supergpqa_Circuits_and_Systems",
"supergpqa_Engineering_Thermophysics",
"supergpqa_Landscape_Plants_and_Ornamental_Horticulture",
"supergpqa_Polymer_Physics",
"supergpqa_Wood_Science_and_Technology",
"supergpqa_Biochemistry_and_Molecular_Biology",
"supergpqa_Preschool_Education",
"supergpqa_Psychology",
"supergpqa_Traditional_Chinese_Health_Preservation",
"supergpqa_Modern_and_Contemporary_Chinese_Literature",
"supergpqa_Religious_Studies",
"supergpqa_Subatomic_and_Atomic_Physics",
"supergpqa_Human_Geography",
"supergpqa_Water_conservancy_and_Hydropower_Engineering",
"supergpqa_Thermal_Energy_Engineering",
"supergpqa_Immunology",
"supergpqa_Communication_and_Information_Systems",
"supergpqa_Meteorology",
"supergpqa_Bridge_and_Tunnel_Engineering",
"supergpqa_Military_Management",
"supergpqa_Russian_Language_and_Literature",
"supergpqa_Particle_and_Nuclear_Physics",
"supergpqa_Rigid_Body_Mechanics",
"supergpqa_Nuclear_Energy_and_Reactor_Technology",
"supergpqa_Oncology",
"supergpqa_Public_Finance",
"supergpqa_Classical_Chinese_Literature",
"supergpqa_Ecology",
"supergpqa_Principles_of_Computer_Organization",
"supergpqa_Pattern_Recognition",
"supergpqa_Databases",
"supergpqa_Ordinary_Differential_Equations",
"supergpqa_Electrochemistry",
"supergpqa_Traditional_Chinese_Pharmacy",
"supergpqa_Dance_Studies",
"supergpqa_Pharmaceutical_Analysis",
"supergpqa_Otorhinolaryngology",
"supergpqa_Principles_of_Seismic_Exploration",
"supergpqa_Physical_Chemistry",
"supergpqa_Special_Number_Theory",
"supergpqa_Astrophysics",
"supergpqa_Physical_Oceanography",
"supergpqa_Instrumentation_and_Performance",
"supergpqa_Military_Law",
"supergpqa_Signal_and_Information_Processing",
"supergpqa_Thermodynamics",
"supergpqa_Architectural_Design_and_Theory",
"supergpqa_Non-ferrous_Metallurgy",
"supergpqa_Internal_Medicine",
"supergpqa_Film_Studies",
"supergpqa_Fluid_Physics",
"supergpqa_Refrigeration_and_Cryogenic_Engineering",
"supergpqa_Broadcasting_and_Television_Art",
"supergpqa_Social_Medicine_and_Health_Management",
"supergpqa_Military_Logistics_and_Equipment",
"supergpqa_Criminal_Law",
"supergpqa_Electrical_Theory_and_New_Technologies",
"supergpqa_Nutrition_and_Food_Hygiene",
"supergpqa_Literary_Theory",
"supergpqa_Instrument_Science_and_Technology",
"supergpqa_Legal_Theory_and_Legal_History",
"supergpqa_Computer_Architecture",
"supergpqa_Chemical_Transport_Engineering",
"supergpqa_Military_Thought_and_History",
"supergpqa_Archaeology_and_Museology",
"supergpqa_Architectural_History",
"supergpqa_Microbiology_and_Biochemical_Pharmacy",
"supergpqa_Philosophy_of_Science_and_Technology",
"supergpqa_Labor_Economics",
"supergpqa_Dermatology_and_Venereology",
"supergpqa_Materials_Processing_Engineering",
"supergpqa_Human_Anatomy_and_Histology-Embryology",
"supergpqa_Optical_Fiber_Communication",
"supergpqa_Journalism_and_News_Practice",
"supergpqa_Emergency_Medicine",
"supergpqa_Veterinary_Medicine",
"supergpqa_Heat_Transfer",
"supergpqa_Information_Management_Science",
"supergpqa_Physical_Chemistry_of_Metallurgical_Process",
"supergpqa_Radiochemistry",
"supergpqa_Guidance,_Navigation_and_Control",
"supergpqa_Solid_Earth_Geophysics",
"supergpqa_Systems_Science",
"supergpqa_Weapon_Systems_Science_and_Engineering",
"supergpqa_Manufacturing_Automation",
"supergpqa_Engineering_Fluid_Mechanics",
"supergpqa_Mineral_Processing_Engineering",
"supergpqa_Animal_Rearing_and_Breeding",
"supergpqa_Philosophical_Aesthetics",
"supergpqa_Solar_System_Science",
"supergpqa_Antenna_and_Radio_Communication",
"supergpqa_Computational_Mathematics",
"supergpqa_Health_Toxicology_and_Environmental_Health",
"supergpqa_Design_Arts",
"supergpqa_Computer_Software_and_Theory",
"supergpqa_Aquaculture",
"supergpqa_Nursing_and_Rehabilitation_Medicine",
"supergpqa_Inorganic_Chemistry",
"supergpqa_Traffic_Information_Engineering_and_Control",
"supergpqa_Botany",
"supergpqa_Number_Theory",
"supergpqa_Hydrogeology",
"supergpqa_Marine_Biology",
"supergpqa_Law_and_Social_Governance",
"supergpqa_Contract_Law",
"supergpqa_Vehicle_Operation_Engineering",
"supergpqa_Aeronautical_and_Astronautical_Science_and_Technology",
"supergpqa_Poromechanics_and_Reservoir_Physics",
"supergpqa_Pathogen_Biology",
"supergpqa_Power_Systems_and_Automation",
"supergpqa_Epidemiology_and_Health_Statistics",
"supergpqa_Drama_and_Opera_Studies",
"supergpqa_Environmental_Engineering",
"supergpqa_Polymer_Chemistry_and_Physics",
"supergpqa_Digital_Surveying_and_Remote_Sensing_Applications",
"supergpqa_Atmospheric_Physics_and_Atmospheric_Environment",
"supergpqa_Education_Economics,_Management_and_Social_Security",
"supergpqa_Probability_and_Statistics",
"supergpqa_Geometry_and_Topology",
"supergpqa_Linguistics_and_Applied_Linguistics",
"supergpqa_Astronomical_Observation_and_Technology",
"supergpqa_Forensic_Medicine",
"supergpqa_Fine_Arts",
"supergpqa_Paleontology_and_Stratigraphy",
"supergpqa_Management_Science_and_Engineering",
"supergpqa_Logic",
"supergpqa_Agricultural_Mechanization_Engineering",
"supergpqa_Traditional_Chinese_Medicine_Theory",
"supergpqa_Obstetrics_and_Gynecology",
"supergpqa_Ship_Mechanics_and_Design_Principles",
"supergpqa_Statistical_Mechanics",
"supergpqa_Combinatorial_Mathematics",
"supergpqa_Mass_Transport_and_Separation_Process_in_Chemical_Engineering",
"supergpqa_Economic_Statistics",
"supergpqa_Operations_Research_and_Cybernetics",
"supergpqa_Formal_Languages",
"supergpqa_Oil_and_Gas_Field_Development_and_Storage_&_Transportation_Engineering",
"supergpqa_Environmental_and_Resource_Protection",
"supergpqa_Structural_Geology",
"supergpqa_Semiconductor_Physics",
"supergpqa_Social_and_Folklore_Studies",
"supergpqa_Music_History,_Education,_and_Technology",
"supergpqa_Radiation_Protection_and_Nuclear_Technology_Applications",
"supergpqa_Pathology_and_Pathophysiology",
"supergpqa_Textile_Chemistry_and_Dyeing_Engineering",
"supergpqa_Military_Command_and_Information_Systems",
"supergpqa_Forest_Engineering",
"supergpqa_Graph_Theory",
"supergpqa_Radiation_Medicine",
"supergpqa_Geological_Resources_and_Geological_Engineering",
"supergpqa_Historical_Geography",
"supergpqa_Cosmology",
"supergpqa_Pharmaceutics",
"supergpqa_National_and_Defense_Economics",
"supergpqa_Imaging_and_Nuclear_Medicine",
"supergpqa_Stellar_and_Interstellar_Evolution"
],
summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
)

View File

@ -146,3 +146,4 @@ from .xcopa import * # noqa: F401, F403
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403
from .xlsum import * # noqa: F401, F403
from .xsum import * # noqa: F401, F403
from .supergpqa import *

View File

@ -0,0 +1,152 @@
import csv
import json
import os.path as osp
from os import environ
from datasets import load_dataset
import os
from datasets import Dataset, DatasetDict
from opencompass.datasets.supergpqa.supergpqa_utils import (
evaluate_responses, find_file, load_json_or_jsonl,
load_json_or_jsonl_with_idx, load_yaml)
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS,LOAD_DATASET
import unittest
from opencompass.utils import get_data_path
from opencompass.datasets.supergpqa.supergpqa_eval import (extract_option_labels,extract_option_content)
from ..base import BaseDataset
def _parse(item, template, prompt_mode):
prompt_format = [item['question']+'\n'+'\n'.join([f'{chr(65+i)}) {option}' for i, option in enumerate(item['options'])])]
item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
item['prompt_mode'] = prompt_mode
return item
@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):
@staticmethod
def load(path: str, prompt_mode: str,category:str, **kwargs):
path = get_data_path(path)
dataset = load_dataset(path, split='train')
dataset = dataset.filter(lambda x: x['subfield'] == category)
#get prompt template
template_path = None
if prompt_mode == 'zero-shot':
template_path = os.path.join(
os.path.dirname(__file__),
'supergpqa_dataset_config/prompt/zero-shot.yaml')
elif prompt_mode == 'five-shot':
template_path = os.path.join(
os.path.dirname(__file__),
'supergpqa_dataset_config/prompt/five-shot.yaml')
try:
template = load_yaml(template_path)
except FileNotFoundError:
print(f'[ERROR] Missing prompt template: {template_path}')
return Dataset.from_list([])
dataset =dataset.map(lambda item: _parse(item, template, prompt_mode))
return dataset
@ICL_EVALUATORS.register_module()
class SuperGPQAEvaluator(BaseEvaluator):
def __init__(self):
super().__init__()
def score(self,predictions, references,test_set):
mode= test_set[0]['prompt_mode']
acc = 0
count = 0
err = 0
miss = 0
acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
count_difficulty = {"hard": 0, "middle": 0, "easy": 0}
stats = {
'discipline': {},
'field': {},
'subfield': {}
}
for i,sample in enumerate(test_set):
prediction=predictions[i]
gold=references[i]
if mode == 'zero-shot':
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(prediction, sample["options"])
predict = chr(sample["options"].index(predict) + 65) if predict else None
sample["extracted_answer"] = predict
elif mode == 'five-shot':
response = prediction.split('Question:')[0]
predict = extract_option_labels(response, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(response, sample["options"])
predict = chr(sample["options"].index(predict) + 65) if predict else None
if predict == None:
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(prediction, sample["options"])
predict = chr(sample["options"].index(predict) + 65) if predict else None
sample["extracted_answer"] = predict
discipline = sample.get("discipline", "unknown")
field = sample.get("field", "unknown")
subfield = sample.get("subfield", "unknown")
difficulty = sample.get("difficulty", "unknown")
for level, key in [
('discipline', discipline),
('field', f"{discipline}/{field}"),
('subfield', f"{discipline}/{field}/{subfield}")
]:
if key not in stats[level]:
stats[level][key] = {
"correct": 0,
"total": 0,
"miss": 0,
"error": 0,
"discipline": discipline,
"field": field,
"subfield": subfield,
"difficulty": {
"easy": {"correct": 0, "total": 0},
"middle": {"correct": 0, "total": 0},
"hard": {"correct": 0, "total": 0}
}
}
stats[level][key]["total"] += 1
stats[level][key]["difficulty"][difficulty]["total"] += 1
answer_letter = sample["answer_letter"]
assert answer_letter==gold
if predict and answer_letter == predict:
acc += 1
acc_difficulty[difficulty] += 1
sample["status"] = "correct"
stats[level][key]["correct"] += 1
stats[level][key]["difficulty"][difficulty]["correct"] += 1
elif predict == None or predict == "":
miss += 1
sample["status"] = "miss"
stats[level][key]["miss"] += 1
elif predict == 'error':
err += 1
sample["status"] = "error"
stats[level][key]["error"] += 1
else:
sample["status"] = "incorrect"
count += 1
count_difficulty[difficulty] += 1
return {
'accuracy': acc / count if count > 0 else 0,
# 'error_rate': err / count if count > 0 else 0,
# 'miss_rate': miss / count if count > 0 else 0,
# 'hard_accuracy': acc_difficulty["hard"] / count_difficulty["hard"] if count_difficulty["hard"] > 0 else 0,
# 'middle_accuracy': acc_difficulty["middle"] / count_difficulty["middle"] if count_difficulty["middle"] > 0 else 0,
# 'easy_accuracy': acc_difficulty["easy"] / count_difficulty["easy"] if count_difficulty["easy"] > 0 else 0
}

View File

@ -0,0 +1,17 @@
response_key: 'response'
error_key: 'error'
id_key:
- 'uuid'
prompt_key: 'prompt'
history_key: 'history'
status_key: 'status'
save_prompt: True
max_tokens: 4096
temperatrue: 0.0
max_rounds: 30
BoN: 32

View File

@ -0,0 +1,17 @@
response_key: 'response'
error_key: 'error'
id_key:
- 'uuid'
prompt_key: 'prompt'
history_key: 'history'
status_key: 'status'
save_prompt: True
max_tokens: 32768
temperatrue: 0.0
max_rounds: 30
BoN: 32

View File

@ -0,0 +1,51 @@
import yaml
import uuid
class ConfigWrapper:
def __init__(self, config_path):
self._config = {}
with open(config_path, 'r') as file:
self._config = yaml.safe_load(file)
for key, value in self._config.items():
setattr(self, key, value)
def __setattr__(self, key, value):
if key.startswith('_'):
super().__setattr__(key, value)
else:
self._config[key] = value
super().__setattr__(key, value)
def __getattr__(self, key):
if key in self._config:
return self._config[key]
raise AttributeError(f"'ConfigWrapper' object has no attribute '{key}'")
def get_id(self, data):
if isinstance(self._config.get('id_key'), str):
return data.get(self._config.get('id_key'), None)
elif isinstance(self._config.get('id_key'), list):
return '_'.join([str(data[key]) for key in self._config.get('id_key') if key in data])
def print_all_keys(self):
print("config keys:")
for key, value in self._config.items():
print(f" - {key}: {value}")
config_wrapper = None
def initialize_config(config_path):
global config_wrapper
config_wrapper = ConfigWrapper(config_path)
def get_config_wrapper():
global config_wrapper
if config_wrapper is None:
raise RuntimeError("ConfigWrapper not initialized. Call initialize_config first.")
return config_wrapper
if __name__ == '__main__':
config_path = 'config/config.yaml'
initialize_config(config_path)
data = {'idx': '50', 'step':21, 'question': 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\nPlease provide the decrypted answer, encapsulated in double square brackets. For example, the format should be: [[decrypted answer]].', 'answer': '[[P]]', 'category': 'Decryption', 'rule_id': '23', 'input': 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"', 'steps_num': 23, 'description': 'For a number c=228 in the ciphertext:\nCalculate z = c^e mod n. Here ^ means multiplication.\nz is 80.\nBased on the decimal number represented by z, use the ascii code to find the corresponding letter as the plaintext letter p.\nPlease give the letter p in [[...]] format.\n', 'atom': 80}
print(config_wrapper.get_id(data))

View File

@ -0,0 +1,91 @@
prompt_format:
- |
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
Question:
A refracting telescope consists of two converging lenses separated by 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification of the telescope is
A) 10
B) 40
C) 6
D) 25
E) 15
F) 50
G) 30
H) 4
I) 5
J) 20
Answer: Let's think step by step. In a refracting telescope, if both lenses are converging, the focus of both lenses must be between the two lenses, and thus the focal lengths of the two lenses must add up to their separation. Since the focal length of one lens is 20 cm, the focal length of the other must be 80 cm. The magnification is the ratio of these two focal lengths, or 4.
Answer: H.
Question:
Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?
A) 1000 times more
B) 50 times more
C) 5000 times more
D) 500 times more
E) 10000 times more
F) 20000 times more
G) 2000 times more
H) 100 times more
I) 10 times more
J) N/A
Answer: Let's think step by step. The amount of light a telescope can gather compared to the human eye is proportional to the area of its apertures. The area of a circle is given by the formula $A = \pi \left(\frac{{D}}{{2}}\right)^2$, where $D$ is the diameter. Therefore, the relative light-gathering power is calculated as:
\[
\frac{{\left(\frac{{50 \text{{ cm}}}}{{2}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{2}}\right)^2}} = \frac{{\left(\frac{{50 \text{{ cm}}}}{{0.1 \text{{ cm}}}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{0.1 \text{{ cm}}}}\right)^2}} = \frac{{500^2}}{{5^2}} = 10000.
\]
Answer: E.
Question:
Where do most short-period comets come from and how do we know?
A) The Kuiper belt; short period comets tend to be in the plane of the solar system like the Kuiper belt.
B) The asteroid belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the asteroid belt.
C) The asteroid belt; short period comets tend to be in the plane of the solar system just like the asteroid belt.
D) The Oort cloud; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the Oort cloud.
E) The Oort Cloud; short period comets tend to come from random directions indicating a spherical distribution of comets called the Oort Cloud.
F) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.
G) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt.
Answer: Let's think step by step. Most short-period comets originate from the Kuiper belt. This is deduced from the observation that these comets tend to follow orbits that lie in the plane of the solar system, similar to the distribution of objects in the Kuiper belt itself. Thus, the alignment of these cometary orbits with the ecliptic plane points to their Kuiper belt origin.
Answer: A.
Question:
Colors in a soap bubble result from light
A) dispersion
B) deflection
C) refraction
D) reflection
E) interference
F) converted to a different frequency
G) polarization
H) absorption
I) diffraction
J) transmission
Answer: Let's think step by step. The colorful patterns observed in a soap bubble are caused by the phenomenon of light interference. This occurs when light waves bounce between the two surfaces of the soap film, combining constructively or destructively based on their phase differences and the varying thickness of the film. These interactions result in vibrant color patterns due to variations in the intensity of different wavelengths of light.
Answer: E.
Question:
A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
A) 240 W
B) 120 W
C) 10 W
D) 480 W
E) 360 W
F) 200 W
G) 30 W
H) 150 W
I) 60 W
J) 300 W
Answer: Let's think step by step. The rate of energy usage, known as power, in an electrical circuit is calculated by the product of voltage and current. For a microwave oven connected to a 120 V outlet and drawing a current of 2 amps, the power consumption can be calculated as follows:
\[
\text{{Power}} = \text{{Voltage}} \times \text{{Current}} = 120 \, \text{{V}} \times 2 \, \text{{A}} = 240 \, \text{{W}}.
\]
Therefore, the microwave oven uses energy at a rate of 240 watts.
Answer: A.
Question:
{}
Answer: Let's think step by step.

View File

@ -0,0 +1,23 @@
initial_prompt_0:
- |
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}
initial_prompt_1:
- |
You are a helpful assistant. Answer the given multiple-choice question. Only one option is correct. The last line of your response should be in the format 'The correct answer is: $LETTER', where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}
initial_prompt_2:
- |
Select the correct answer for the following multiple-choice question. There is only one valid choice. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}
initial_prompt_3:
- |
Review the following multiple-choice question and choose the one correct answer. Ensure that your response concludes with a line exactly formatted as 'The correct answer is: $LETTER', where LETTER represents one of A, B, C, D, E, F, G, H, I, or J.
{}

View File

@ -0,0 +1,5 @@
prompt_format:
- |
Answer the following multiple choice question about {}. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}

View File

@ -0,0 +1,5 @@
prompt_format:
- |
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}

View File

@ -0,0 +1,633 @@
import json
import re
import argparse
import os
from prettytable import PrettyTable
import pandas as pd
# from openpyxl.styles import PatternFill, Font, Alignment
from tqdm import tqdm
import timeout_decorator
import multiprocessing
import time
from functools import partial
@timeout_decorator.timeout(5) # 5 seconds timeout
def safe_regex_search(pattern, text, flags=0):
try:
return re.search(pattern, text, flags)
except timeout_decorator.TimeoutError:
print(f"Regex match timeout: pattern={pattern}, text={text[:100]}...")
return None
except Exception as e:
print(f"Regex match error: {str(e)}")
return None
def extract_option_labels(text, options='ABCDEFGHIJ'):
if not isinstance(text, str) or not isinstance(options, str):
return 'error'
text = text.rstrip()
last_line = text.split('\n')[-1]
option_str = ''.join([chr(65 + i) for i in range(len(options))]) if options else 'ABCDEFGHIJ'
patterns = [
# e.g. "The final answer to this question is: A."
# "The best option is $\boxed{B}:"
# "The correct answer is (C)."
f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is?:?\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
# e.g. "ANSWER: A"
# "Answer: $\boxed{B}."
# "ANSWER: (C):"
f'(?i:Answer)[\*\s]*:\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
# e.g. "A"
# "$\boxed{B}$"
# "(C)."
# "[D]:"
f'^[^\w\r\n]*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
]
for pattern in patterns:
match = safe_regex_search(pattern, last_line, re.IGNORECASE)
if match:
return match.group(1)
for pattern in patterns:
match = safe_regex_search(pattern, text, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_option_content(text, options_content=None):
if not isinstance(text, str) or not isinstance(options_content, list):
return 'error'
escaped_options_content = [re.escape(option_content) for option_content in options_content]
escaped_options_content_str = '|'.join(escaped_options_content)
text = text.rstrip()
last_line = text.split('\n')[-1]
patterns = [
f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is:?\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
f'(?i:Answer)\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
f'^[^\w\r\n]*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
]
for pattern in patterns:
match = safe_regex_search(pattern, last_line)
if match:
if match.group(1) in escaped_options_content:
return options_content[escaped_options_content.index(match.group(1))]
else:
return match.group(1)
for pattern in patterns:
match = safe_regex_search(pattern, text)
if match:
if match.group(1) in escaped_options_content:
return options_content[escaped_options_content.index(match.group(1))]
else:
return match.group(1)
return None
def calculate_accuracy(file_path, save_dir, mode):
data = []
acc = 0
count = 0
err = 0
miss = 0
acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
count_difficulty = {"hard": 0, "middle": 0, "easy": 0}
stats = {
'discipline': {},
'field': {},
'subfield': {}
}
with open(file_path, "r") as file:
for line in tqdm(file, desc=f"Reading {os.path.basename(file_path)} data", leave=False):
data.append(json.loads(line))
if not data:
print(f"Warning: No data found in {file_path}")
return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats
for sample in tqdm(data, desc=f"Processing {os.path.basename(file_path)} samples", leave=False):
if mode == 'zero-shot':
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(prediction, sample["options"])
predict = chr(sample["options"].index(predict) + 65) if predict else None
sample["extracted_answer"] = predict
elif mode == 'five-shot':
response = prediction.split('Question:')[0]
predict = extract_option_labels(response, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(response, sample["options"])
predict = chr(sample["options"].index(predict) + 65) if predict else None
if predict == None:
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(prediction, sample["options"])
predict = chr(sample["options"].index(predict) + 65) if predict else None
sample["extracted_answer"] = predict
discipline = sample.get("discipline", "unknown")
field = sample.get("field", "unknown")
subfield = sample.get("subfield", "unknown")
difficulty = sample.get("difficulty", "unknown")
for level, key in [
('discipline', discipline),
('field', f"{discipline}/{field}"),
('subfield', f"{discipline}/{field}/{subfield}")
]:
if key not in stats[level]:
stats[level][key] = {
"correct": 0,
"total": 0,
"miss": 0,
"error": 0,
"discipline": discipline,
"field": field,
"subfield": subfield,
"difficulty": {
"easy": {"correct": 0, "total": 0},
"middle": {"correct": 0, "total": 0},
"hard": {"correct": 0, "total": 0}
}
}
stats[level][key]["total"] += 1
stats[level][key]["difficulty"][difficulty]["total"] += 1
answer_letter = sample["answer_letter"]
if predict and answer_letter == predict:
acc += 1
acc_difficulty[difficulty] += 1
sample["status"] = "correct"
stats[level][key]["correct"] += 1
stats[level][key]["difficulty"][difficulty]["correct"] += 1
elif predict == None or predict == "":
miss += 1
sample["status"] = "miss"
stats[level][key]["miss"] += 1
elif predict == 'error':
err += 1
sample["status"] = "error"
stats[level][key]["error"] += 1
else:
sample["status"] = "incorrect"
count += 1
count_difficulty[difficulty] += 1
if count == 0:
print(f"Warning: No valid samples found in {file_path}")
return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats
accuracy = acc / count
error_rate = err / count
miss_rate = miss / count
hard_accuracy = acc_difficulty["hard"] / count_difficulty["hard"]
middle_accuracy = acc_difficulty["middle"] / count_difficulty["middle"]
easy_accuracy = acc_difficulty["easy"] / count_difficulty["easy"]
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, os.path.basename(file_path))
with open(save_path, "w") as file:
for sample in data:
json.dump(sample, file)
file.write("\n")
return accuracy, error_rate, miss_rate, hard_accuracy, middle_accuracy, easy_accuracy, stats
def calculate_total_row(hierarchy_stats, model_results, metric_name):
"""Calculate overall summary row, including sample-wise and weighted average across dimensions"""
total_rows = []
# Calculate total counts across dimensions
total_samples = 0
if metric_name in ['Hard', 'Middle', 'Easy']:
total_subfields = sum(subfield['difficulty'][metric_name.lower()]['total'] > 0 for subfield in hierarchy_stats['subfield'].values())
total_fields = sum(field['difficulty'][metric_name.lower()]['total'] > 0 for field in hierarchy_stats['field'].values())
total_disciplines = sum(discipline['difficulty'][metric_name.lower()]['total'] > 0 for discipline in hierarchy_stats['discipline'].values())
else:
total_subfields = len(hierarchy_stats['subfield'])
total_fields = len(hierarchy_stats['field'])
total_disciplines = len(hierarchy_stats['discipline'])
# Calculate total sample count
for discipline_stats in hierarchy_stats['discipline'].values():
if metric_name in ['Hard', 'Middle', 'Easy']:
total_samples += discipline_stats['difficulty'][metric_name.lower()]['total']
else:
total_samples += discipline_stats['total']
if metric_name == 'Accuracy':
row_types = [
(f'Overall (sample-wise) (Total samples: {total_samples})', 'sample'),
(f'Overall (subfield-wise) (Total subfields: {total_subfields})', 'subfield'),
(f'Overall (field-wise) (Total fields: {total_fields})', 'field'),
(f'Overall (discipline-wise) (Total disciplines: {total_disciplines})', 'discipline')
]
elif metric_name in ['Hard', 'Middle', 'Easy']:
row_types = [
(f'Overall (sample-wise) (Total {metric_name.lower()} samples: {total_samples})', 'sample'),
(f'Overall (subfield-wise) (Total {metric_name.lower()} subfields: {total_subfields})', 'subfield'),
(f'Overall (field-wise) (Total {metric_name.lower()} fields: {total_fields})', 'field'),
(f'Overall (discipline-wise) (Total {metric_name.lower()} disciplines: {total_disciplines})', 'discipline')
]
else: # Error Rate and Miss Rate
row_types = [(f'Overall (Total samples: {total_samples})', 'sample')]
for row_name, stat_type in row_types:
total_row = {
'Discipline': row_name,
'Field': '',
'Subfield': ''
}
for model_name in model_results.keys():
for mode in model_results[model_name].keys():
if stat_type == 'sample':
# sample-wise statistics (weighted by sample count)
stats = {'total': 0, 'correct': 0, 'error': 0, 'miss': 0}
for discipline_stats in hierarchy_stats['discipline'].values():
if 'model_stats' in discipline_stats and model_name in discipline_stats['model_stats']:
curr_stats = discipline_stats['model_stats'][model_name].get(mode, {})
if metric_name in ['Hard', 'Middle', 'Easy']:
difficulty_stats = curr_stats.get('difficulty', {}).get(metric_name.lower(), {})
stats['total'] += difficulty_stats.get('total', 0)
stats['correct'] += difficulty_stats.get('correct', 0)
else:
for key in ['total', 'correct', 'error', 'miss']:
stats[key] += curr_stats.get(key, 0)
if stats['total'] > 0:
if metric_name in ['Hard', 'Middle', 'Easy'] or metric_name == 'Accuracy':
value = stats['correct'] / stats['total']
elif metric_name == 'Error Rate':
value = stats['error'] / stats['total']
else: # Miss Rate
value = stats['miss'] / stats['total']
else:
value = 0
else:
# Other dimension statistics (direct average of correct rates across categories)
scores = []
if stat_type == 'discipline':
categories = hierarchy_stats['discipline']
elif stat_type == 'field':
categories = hierarchy_stats['field']
else: # subfield
categories = hierarchy_stats['subfield']
for cat_stats in categories.values():
if 'model_stats' in cat_stats and model_name in cat_stats['model_stats']:
curr_stats = cat_stats['model_stats'][model_name].get(mode, {})
if metric_name in ['Hard', 'Middle', 'Easy']:
difficulty_stats = curr_stats.get('difficulty', {}).get(metric_name.lower(), {})
if difficulty_stats.get('total', 0) > 0:
score = difficulty_stats['correct'] / difficulty_stats['total']
scores.append(score)
else:
if curr_stats.get('total', 0) > 0:
if metric_name == 'Accuracy':
score = curr_stats['correct'] / curr_stats['total']
scores.append(score)
value = sum(scores) / len(scores) if scores else 0
total_row[f'{model_name}_{mode}'] = f"{value:.2%}"
total_rows.append(total_row)
return total_rows
def create_excel_report_from_stats(model_results, hierarchy_stats, save_path):
print("Starting Excel report generation...")
# Create six different DataFrames for storing different metrics and difficulties
metrics = {
'Accuracy': {'rows': [], 'color': '000000'}, # black
'Error Rate': {'rows': [], 'color': '000000'}, # black
'Miss Rate': {'rows': [], 'color': '000000'}, # black
'Hard': {'rows': [], 'color': '000000'}, # black
'Middle': {'rows': [], 'color': '000000'}, # black
'Easy': {'rows': [], 'color': '000000'} # black
}
# Organize data by hierarchy
for discipline in tqdm(sorted(hierarchy_stats['discipline'].keys()), desc="Processing discipline level"):
discipline_stats = hierarchy_stats['discipline'][discipline]
discipline_total = discipline_stats['total']
# Get all fields under this discipline
categories = [k for k in hierarchy_stats['field'].keys()
if k.startswith(f"{discipline}/")]
for field_key in sorted(categories):
field_stats = hierarchy_stats['field'][field_key]
field = field_stats['field']
field_total = field_stats['total']
# Get all subfields under this field
subcategories = [k for k in hierarchy_stats['subfield'].keys()
if k.startswith(f"{discipline}/{field}/")]
# Add subfield row
for subfield_key in sorted(subcategories):
subfield_stats = hierarchy_stats['subfield'][subfield_key]
# Create base row data for each metric
for metric_name in metrics:
if metric_name in ['Hard', 'Middle', 'Easy']:
base_row = {
'Discipline': discipline,
'Field': field,
'Subfield': f"{subfield_stats['subfield']} ({subfield_stats['difficulty'][metric_name.lower()]['total']})"
}
else:
base_row = {
'Discipline': discipline,
'Field': field,
'Subfield': f"{subfield_stats['subfield']} ({subfield_stats['total']})"
}
row_data = base_row.copy()
# Add score for each model
for model_name in model_results.keys():
for mode in model_results[model_name].keys():
stats = subfield_stats['model_stats'].get(model_name, {}).get(mode, {})
if metric_name in ['Hard', 'Middle', 'Easy']:
difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
if difficulty_stats.get('total', 0) > 0:
value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
else:
value = '0.00%'
else:
if stats.get('total', 0) > 0:
if metric_name == 'Accuracy':
value = f"{stats['correct'] / stats['total']:.2%}"
elif metric_name == 'Error Rate':
value = f"{stats['error'] / stats['total']:.2%}"
else: # Miss Rate
value = f"{stats['miss'] / stats['total']:.2%}"
else:
value = '0.00%'
row_data[f'{model_name}_{mode}'] = value
metrics[metric_name]['rows'].append(row_data)
# Add field summary row
for metric_name in metrics:
if metric_name in ['Hard', 'Middle', 'Easy']:
field_row = {
'Discipline': discipline,
'Field': f"{field} (Total: {field_stats['difficulty'][metric_name.lower()]['total']})",
'Subfield': ''
}
else:
field_row = {
'Discipline': discipline,
'Field': f"{field} (Total: {field_total})",
'Subfield': ''
}
for model_name in model_results.keys():
for mode in model_results[model_name].keys():
stats = field_stats['model_stats'].get(model_name, {}).get(mode, {})
if metric_name in ['Hard', 'Middle', 'Easy']:
difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
if difficulty_stats.get('total', 0) > 0:
value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
else:
value = '0.00%'
else:
if stats.get('total', 0) > 0:
if metric_name == 'Accuracy':
value = f"{stats['correct'] / stats['total']:.2%}"
elif metric_name == 'Error Rate':
value = f"{stats['error'] / stats['total']:.2%}"
else: # Miss Rate
value = f"{stats['miss'] / stats['total']:.2%}"
else:
value = '0.00%'
field_row[f'{model_name}_{mode}'] = value
metrics[metric_name]['rows'].append(field_row)
# Add discipline summary row
for metric_name in metrics:
if metric_name in ['Hard', 'Middle', 'Easy']:
discipline_row = {
'Discipline': f"{discipline} (Total: {discipline_stats['difficulty'][metric_name.lower()]['total']})",
'Field': '',
'Subfield': ''
}
else:
discipline_row = {
'Discipline': f"{discipline} (Total: {discipline_total})",
'Field': '',
'Subfield': ''
}
for model_name in model_results.keys():
for mode in model_results[model_name].keys():
stats = discipline_stats['model_stats'].get(model_name, {}).get(mode, {})
if metric_name in ['Hard', 'Middle', 'Easy']:
difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
if difficulty_stats.get('total', 0) > 0:
value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
else:
value = '0.00%'
else:
if stats.get('total', 0) > 0:
if metric_name == 'Accuracy':
value = f"{stats['correct'] / stats['total']:.2%}"
elif metric_name == 'Error Rate':
value = f"{stats['error'] / stats['total']:.2%}"
else: # Miss Rate
value = f"{stats['miss'] / stats['total']:.2%}"
else:
value = '0.00%'
discipline_row[f'{model_name}_{mode}'] = value
metrics[metric_name]['rows'].append(discipline_row)
# Create DataFrames
dfs = {metric: pd.DataFrame(data['rows']) for metric, data in metrics.items()}
# Add overall summary row to each DataFrame
for metric_name, df in dfs.items():
total_rows = calculate_total_row(hierarchy_stats, model_results, metric_name)
dfs[metric_name] = pd.concat([df, pd.DataFrame(total_rows)], ignore_index=True)
# Save to Excel, one sheet per metric
with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
for metric_name, df in dfs.items():
df.to_excel(writer, sheet_name=metric_name, index=False)
format_worksheet(writer.sheets[metric_name], df, metrics[metric_name]['color'])
print(f"Report generation completed, Excel file saved: {save_path}")
def format_worksheet(worksheet, df, color):
"""Format worksheet"""
# Set default font
for row in worksheet.rows:
for cell in row:
cell.font = Font(name='Arial', color='000000') # Use black font uniformly
# Set background color
discipline_fill = PatternFill(start_color='FFFF00', end_color='FFFF00', fill_type='solid')
field_fill = PatternFill(start_color='FFFFD4', end_color='FFFFD4', fill_type='solid')
# Overall row background color
sample_wise_fill = PatternFill(start_color='B8CCE4', end_color='B8CCE4', fill_type='solid') # Bright but not bright blue
subfield_wise_fill = PatternFill(start_color='DCE6F1', end_color='DCE6F1', fill_type='solid') # Light blue
field_wise_fill = PatternFill(start_color='E9EEF5', end_color='E9EEF5', fill_type='solid') # Lighter blue
discipline_wise_fill = PatternFill(start_color='F2F5F9', end_color='F2F5F9', fill_type='solid') # Lightest blue
error_rate_fill = PatternFill(start_color='FFB6C1', end_color='FFB6C1', fill_type='solid') # Red
miss_rate_fill = PatternFill(start_color='D3D3D3', end_color='D3D3D3', fill_type='solid') # Gray
# Set column width
for column in worksheet.columns:
max_length = 0
column = list(column)
for cell in column:
try:
if len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except:
pass
adjusted_width = (max_length + 2)
worksheet.column_dimensions[column[0].column_letter].width = adjusted_width
# Merge cells and apply background color
current_discipline = None
discipline_start = None
current_field = None
field_start = None
for row_idx, row in enumerate(worksheet.iter_rows(min_row=2), start=2):
discipline = row[0].value
field = row[1].value
# Process discipline (Discipline) merge
if discipline and "Total:" in str(discipline):
# If there was an unmerged discipline row before
if discipline_start and current_discipline:
worksheet.merge_cells(f'A{discipline_start}:A{row_idx-1}')
# Apply background color to current total row
for cell in row:
cell.fill = discipline_fill
# Reset tracking variables
current_discipline = None
discipline_start = None
elif discipline and discipline != current_discipline:
# If there was an unmerged discipline row before
if discipline_start and current_discipline:
worksheet.merge_cells(f'A{discipline_start}:A{row_idx-1}')
current_discipline = discipline
discipline_start = row_idx
# Process field (Field) merge
if field and "Total:" in str(field):
# If there was an unmerged field row before
if field_start and current_field:
worksheet.merge_cells(f'B{field_start}:B{row_idx-1}')
# Apply background color to current total row
for cell in row:
cell.fill = field_fill
# Reset tracking variables
current_field = None
field_start = None
elif field and field != current_field:
# If there was an unmerged field row before
if field_start and current_field:
worksheet.merge_cells(f'B{field_start}:B{row_idx-1}')
current_field = field
field_start = row_idx
# Process last unmerged cells
last_row = worksheet.max_row
if discipline_start and current_discipline:
worksheet.merge_cells(f'A{discipline_start}:A{last_row}')
if field_start and current_field:
worksheet.merge_cells(f'B{field_start}:B{last_row}')
# Apply special background color to Overall row
for row_idx, row in enumerate(worksheet.iter_rows(min_row=2), start=2):
cell_value = row[0].value
if cell_value:
if 'Overall (sample-wise)' in str(cell_value):
for cell in row:
cell.fill = sample_wise_fill
elif 'Overall (subfield-wise)' in str(cell_value):
for cell in row:
cell.fill = subfield_wise_fill
elif 'Overall (field-wise)' in str(cell_value):
for cell in row:
cell.fill = field_wise_fill
elif 'Overall (discipline-wise)' in str(cell_value):
for cell in row:
cell.fill = discipline_wise_fill
elif worksheet.title == 'Error Rate' and 'Overall' in str(cell_value):
for cell in row:
cell.fill = error_rate_fill
elif worksheet.title == 'Miss Rate' and 'Overall' in str(cell_value):
for cell in row:
cell.fill = miss_rate_fill
# Set value format to keep two decimal places
for row in worksheet.iter_rows(min_row=2):
for cell in row[3:]: # Start from 4th column (skip Discipline, Field, Subfield columns)
if isinstance(cell.value, str) and '%' in cell.value:
try:
value = float(cell.value.strip('%')) / 100
cell.value = f"{value:.2%}"
except ValueError:
pass
# Set all cells to center alignment
for row in worksheet.rows:
for cell in row:
cell.alignment = Alignment(horizontal='center', vertical='center')
def format_cell_value(stats):
"""Format cell content, return string with acc/error/miss"""
total = stats['total']
if total == 0:
return '0%/0%/0%'
acc = stats['correct'] / total
error = stats['error'] / total
miss = stats['miss'] / total
return f"{acc:.1%}/{error:.1%}/{miss:.1%}"

View File

@ -0,0 +1,691 @@
import json
import os
import re
import sympy as sp
import yaml
from sympy.parsing.latex import parse_latex
def load_yaml(yaml_path):
"""Load a YAML file."""
if not os.path.exists(yaml_path):
raise FileNotFoundError(f'YAML file not found: {yaml_path}')
with open(yaml_path, 'r', encoding='utf-8') as file:
return yaml.safe_load(file)
def load_json_or_jsonl(file_path):
"""Load data from a JSON or JSONL file."""
if not os.path.exists(file_path):
return None
with open(file_path, 'r', encoding='utf-8') as file:
if file_path.endswith('.json'):
return json.load(file)
elif file_path.endswith('.jsonl'):
return [json.loads(line) for line in file]
return None
def find_file(base_path, sub_path, extensions=('json', 'jsonl')):
"""Find the first available file with given extensions."""
for ext in extensions:
file_path = os.path.join(base_path, f'{sub_path}.{ext}')
if os.path.exists(file_path):
return file_path
return None
def load_json_or_jsonl_with_idx(data_path, split='', idx=None):
base_path = os.path.join(data_path, split)
if os.path.exists(f'{base_path}.json'):
file_path = f'{base_path}.json'
elif os.path.exists(f'{base_path}.jsonl'):
file_path = f'{base_path}.jsonl'
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
file_path = base_path
else:
raise FileNotFoundError('No JSON or JSONL file found.')
with open(file_path, 'r', encoding='utf-8') as file:
if file_path.endswith('.json'):
data = json.load(file)
elif file_path.endswith('.jsonl'):
data = [json.loads(line) for line in file]
if idx is not None:
try:
return next(item for item in data if item.get('idx') == idx)
except StopIteration:
raise ValueError(f'No entry found for idx {idx}')
else:
return data
def load_split_data(base_path, split_name):
"""Load the rule and sample data for a specific split."""
split_path = os.path.join(base_path, split_name)
rule_path = find_file(split_path, 'rule')
sample_path = find_file(split_path, 'sample')
rules = load_json_or_jsonl(rule_path) if rule_path else []
samples = load_json_or_jsonl(sample_path) if sample_path else []
return {'rules': rules, 'samples': samples}
def process_mixed_data(base_path, mode):
"""Load and process data for the 'mixed' split and specific mode."""
mixed_path = os.path.join(base_path, 'mixed')
file_path = find_file(mixed_path, mode)
if not file_path:
print(f'[WARNING] Missing file for mixed mode: {mode}')
return []
data = load_json_or_jsonl(file_path)
template_path = os.path.join(base_path, 'config/prompt/mixed.yaml')
template = load_yaml(template_path)
processed = []
for item in data:
rules = '\n'.join(item.get('rule_list', []))
questions = '\n'.join(item.get('question_list', []))
item['prompt'] = template['prompt_format'][0].format(rules, questions)
processed.append(item)
return processed
class ConfigWrapper:
def __init__(self, config_path):
self._config = {}
with open(config_path, 'r') as file:
self._config = yaml.safe_load(file)
for key, value in self._config.items():
setattr(self, key, value)
def __setattr__(self, key, value):
if key.startswith('_'):
super().__setattr__(key, value)
else:
self._config[key] = value
super().__setattr__(key, value)
def __getattr__(self, key):
if key in self._config:
return self._config[key]
raise AttributeError(
f"'ConfigWrapper' object has no attribute '{key}'")
def get_id(self, data):
if isinstance(self._config.get('id_key'), str):
return data.get(self._config.get('id_key'), None)
elif isinstance(self._config.get('id_key'), list):
return '_'.join([
str(data[key]) for key in self._config.get('id_key')
if key in data
])
def print_all_keys(self):
print('config keys:')
for key, value in self._config.items():
print(f' - {key}: {value}')
config_wrapper = None
def initialize_config(config_path):
global config_wrapper
config_wrapper = ConfigWrapper(config_path)
def get_config_wrapper():
global config_wrapper
if config_wrapper is None:
raise RuntimeError(
'ConfigWrapper not initialized. Call initialize_config first.')
return config_wrapper
if __name__ == '__main__':
config_path = 'config/config.yaml'
initialize_config(config_path)
data = {
'idx':
'50',
'step':
21,
'question':
('Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
'Please provide the decrypted answer, encapsulated in double '
'square brackets. '
'For example, the format should be: [[decrypted answer]].'),
'answer':
'[[P]]',
'category':
'Decryption',
'rule_id':
'23',
'input':
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
'steps_num':
23,
'description':
('For a number c=228 in the ciphertext:\n'
'Calculate z = c^e mod n. Here ^ means multiplication.\n'
'z is 80.\nBased on the decimal number represented by z, '
'use the ascii code to find the corresponding letter '
'as the plaintext letter p.\n'
'Please give the letter p in [[...]] format.\n'),
'atom':
80
}
print(config_wrapper.get_id(data))
def read_yaml(config='default'):
if os.path.exists(f'config/prompt/{config}.yaml'):
yaml_file = f'config/prompt/{config}.yaml'
else:
yaml_file = config
with open(yaml_file, 'r') as yaml_file:
return yaml.safe_load(yaml_file)
def write_jsonl_lines(file, data):
config_wrapper = get_config_wrapper()
if config_wrapper.save_prompt:
json.dump(data, file, ensure_ascii=False)
else:
data.pop(config_wrapper.prompt_key)
json.dump(data, file, ensure_ascii=False)
file.write('\n')
file.flush()
def print_info(info):
print('-' * 100)
print('[INFO] model_name:', info['model_name'])
print('[INFO] splits:', info['splits'])
print('[INFO] modes:', info['modes'])
print('[INFO] output_dir:', info['output_dir'])
print('[INFO] Infer Limit:',
'No limit' if info['infer_limit'] is None else info['infer_limit'])
print('[INFO] Number of Workers:', info['num_workers'])
print('[INFO] Batch Size:', info['batch_size'])
print('[INFO] Use Accel:', info['use_accel'])
print('-' * 100)
def read_json_or_jsonl(data_path, split='', mapping_key=None):
base_path = os.path.join(data_path, split)
if os.path.exists(f'{base_path}.json'):
file_path = f'{base_path}.json'
elif os.path.exists(f'{base_path}.jsonl'):
file_path = f'{base_path}.jsonl'
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
file_path = base_path
else:
raise FileNotFoundError('No JSON or JSONL file found.')
with open(file_path, 'r') as file:
if file_path.endswith('.json'):
data = json.load(file)
elif file_path.endswith('.jsonl'):
data = [json.loads(line) for line in file]
if mapping_key:
return {
item[mapping_key]: item
for item in data if mapping_key in item
}
else:
return data
def read_json_or_jsonl_with_idx(data_path, split='', idx=None):
base_path = os.path.join(data_path, split)
if os.path.exists(f'{base_path}.json'):
file_path = f'{base_path}.json'
elif os.path.exists(f'{base_path}.jsonl'):
file_path = f'{base_path}.jsonl'
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
file_path = base_path
else:
raise FileNotFoundError('No JSON or JSONL file found.')
with open(file_path, 'r', encoding='utf-8') as file:
if file_path.endswith('.json'):
data = json.load(file)
elif file_path.endswith('.jsonl'):
data = [json.loads(line) for line in file]
if idx is not None:
try:
return next(item for item in data if item.get('idx') == idx)
except StopIteration:
raise ValueError(f'No entry found for idx {idx}')
else:
return data
idx_ranges = [
[18],
[73, 74, 77],
[94],
[115, 116, 117],
[121, 122, 123, 125],
[131, 132, 134, 135, 136],
[141, 143, 149],
list(range(145, 148)),
list(range(151, 157)),
[160, 161, 162],
[164, 165, 166],
[170],
[206, 209],
list(range(211, 216)),
[217, 218],
]
def clean_json_string(json_str):
json_str = re.sub(r'[\x00-\x1F\x7F]', '', json_str)
return json_str
def is_in_idx_ranges(idx, idx_ranges):
for range_list in idx_ranges:
if int(idx) in range_list:
return True
return False
def extract_json(text):
matches = re.findall(r'{.*}', text, re.DOTALL)
if matches:
json_str = matches[-1]
json_str = clean_json_string(json_str)
try:
data = json.loads(json_str)
return data
except json.JSONDecodeError as e:
print(f'Error decoding JSON: {e}')
return 'NULL'
return 'NULL'
def extract_all_responses_from_json(response_json):
results = []
for key, value in response_json.items():
results.append(str(value))
return results
def clean_latex(latex_expr):
if '=' in latex_expr:
latex_expr = latex_expr.rsplit('=', 1)[1]
latex_expr = re.sub(r'\\[()\[\]]', '', latex_expr)
latex_expr = re.sub(r'\\text\{.*?\}', '', latex_expr)
latex_expr = re.sub(r'\\(left|right|displaystyle)', '', latex_expr)
latex_expr = latex_expr.replace('\\\\', '\\')
return latex_expr
def extract_text_from_brackets(text, clean_level='basic'):
matches = re.findall(r'\[\[\s*(.*?)\s*\]\]', text, re.DOTALL)
if not matches:
matches = re.findall(r'\$\\boxed\{(.*?)\}\$', text, re.DOTALL)
if not matches:
matches = re.findall(r'\[\s*(.*?)\s*\]', text, re.DOTALL)
if matches:
match_str = matches[0].strip()
if clean_level == 'clean':
match_str = match_str.replace('"', '').replace('\n', '').replace(
' ', '').replace('[', '').replace(']', '')
elif clean_level == 'logic':
match_str = match_str.replace('"', '').replace('\n', '').replace(
' ', '').replace('.', '')
elif clean_level == 'math':
match_str = match_str.replace('"', '').replace('\n', '').replace(
'[', '').replace(']', '').replace('$', '')
return f'{clean_latex(match_str)}'
return f'[[{match_str}]]'
return 'NULL'
def extract_inner_text_from_brackets(text):
if not isinstance(text, str):
print(f'text type: {type(text)}, text value: {text}')
return 'NULL'
match = re.search(r'\[\[(.*?)\]\]', text, re.DOTALL)
return match.group(1) if match else 'NULL'
def extract_numbers(str):
numbers = re.findall(r'\d+', str)
numbers = list(map(int, numbers))
return numbers
def extract_and_sort_inequalities(latex_expr):
pattern = r'(≥|≤)\s*([-]?\d+\.?\d*)'
matches = re.findall(pattern, latex_expr)
extracted_inequalities = [''.join(match) for match in matches]
sorted_inequalities = sorted(extracted_inequalities)
return sorted_inequalities
def rule5_normalize_content(content):
parts = [part for part in content.split(';')]
sorted_parts = sorted(parts)
return sorted_parts
def normalize_string(s):
s = re.sub(r'[^0-9]', '', s)
pairs = s.split(',')
pairs.sort()
return pairs
def remove_commas_and_spaces(s):
return re.sub(r'[,\s\[\]]+', '', s)
def remove_non_alphanumeric(s):
return re.sub(r'\W+', '', s)
def contains_or(answer):
return 'or' in answer
def compare_multi_results(response, answer):
try:
response_text = extract_text_from_brackets(response, 'clean')
response_text = re.sub(r'\\text\{or\}', 'or', response_text)
if response_text == 'NULL':
return False
answer = extract_text_from_brackets(answer, 'clean')
response_split = response_text.strip('[[]]').split('or')
answer_split = answer.strip('[[]]').split('or')
response_sorted = sorted([x.strip() for x in response_split])
answer_sorted = sorted([x.strip() for x in answer_split])
return response_sorted == answer_sorted
except Exception as e:
print(f'Error during comparison: {e}')
return False
def split_or_expression(expression):
return [part.strip() for part in expression.split('or')]
def compare_math_expressions(response, answer):
response_text = extract_text_from_brackets(response, 'math')
answer_text = extract_text_from_brackets(answer, 'math')
if response_text == 'NULL':
return False
if contains_or(answer_text):
response_parts = split_or_expression(response_text)
answer_parts = split_or_expression(answer_text)
try:
response_exprs = {
sp.simplify(parse_latex(part))
for part in response_parts
}
answer_exprs = {
sp.simplify(parse_latex(part))
for part in answer_parts
}
return response_exprs == answer_exprs
except Exception as e:
print(f'Error during simplification or parsing: {e}')
return response_text == answer_text
else:
try:
response_expr = sp.simplify(parse_latex(response_text))
answer_expr = sp.simplify(parse_latex(answer_text))
return response_expr == answer_expr
except Exception as e:
print(f'Error during simplification or parsing: {e}')
return response_text == answer_text
def method_equal(response_text, answer):
return response_text == answer
def method_1(response_text, answer):
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
cleaned_string = cleaned_string.lower()
answer = re.sub(r'[^A-Za-z]', '', answer)
answer = answer.lower()
return cleaned_string == answer
def method_2(response_text, answer):
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
cleaned_string = cleaned_string.lower()
answer = answer.split(',')
return cleaned_string in answer
def method_3(response_text, answer):
response_text = response_text.lower()
pairs1 = re.split(r'\W+', response_text)
pairs2 = answer.split(' ')
pairs1 = [word for word in pairs1 if word]
pairs1.sort()
pairs2.sort()
return pairs1 == pairs2
def method_4(response_text, answer):
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
cleaned_string = cleaned_string.lower()
return cleaned_string in answer
def method_5(response_text, answer):
response_text = re.sub(r'\s+', '', response_text)
response_text = response_text.split(',')
answer = answer.split(',')
response_text.sort()
answer.sort()
return response_text == answer
def method_9(response_text, answer):
response_text = response_text.replace('×', '*').replace('', '-')
answer = answer.replace('×', '*').replace('', '-')
def extract_operators(s):
return re.findall(r'[+\-*/]', s)
response_ops = extract_operators(response_text.split('=')[0])
answer_ops = extract_operators(answer.split('=')[0])
if response_ops != answer_ops:
return False
match = re.search(r'=\s*(-?\d+)', answer)
expected_result = int(match.group(1))
try:
left_side = response_text.split('=')[0]
result = eval(left_side)
except Exception as e:
print(f'Error during evaluation: {e}')
return False
return result == expected_result
def method_10(response_text, answer):
response_text = response_text.replace('×', '*').replace('', '-')
response_text = response_text.split('=')[0]
answer = answer.split('\n')[0].split('=')[0]
response_ops = sorted(remove_non_alphanumeric(response_text))
answer_ops = sorted(remove_non_alphanumeric(answer))
if response_ops != answer_ops:
return False
try:
result = eval(response_text)
except Exception as e:
print(f'Error during evaluation: {e}')
return False
return result == 24
def method_18(response_text, answer):
cleaned_s1 = remove_commas_and_spaces(response_text)
cleaned_s2 = remove_commas_and_spaces(answer)
return cleaned_s1 == cleaned_s2
def method_general(response_text, answer):
cleaned_s1 = remove_non_alphanumeric(response_text)
cleaned_s2 = remove_non_alphanumeric(answer)
return cleaned_s1 == cleaned_s2
question_methods = {
'1': method_1,
'2': method_2,
'3': method_3,
'4': method_4,
'5': method_5,
'9': method_9,
'10': method_10,
'18': method_18,
}
def evaluate_response_vs_answer(response, answer, question_type, rule_id, idx):
if question_type == 'logic' and rule_id == '5':
response_text = extract_text_from_brackets(response, 'logic')
answer_text = extract_text_from_brackets(answer, 'logic')
if response_text is None:
return False
normalized_response = rule5_normalize_content(response_text)
normalized_answer = rule5_normalize_content(answer)
return normalized_response == normalized_answer
elif question_type == 'logic':
response_text = extract_text_from_brackets(response, 'logic')
answer_text = extract_text_from_brackets(answer, 'logic')
return response_text == answer_text
elif question_type == 'operation' and (idx == '178' or idx == '179'):
response_text = extract_text_from_brackets(response, 'clean')
response_text = extract_and_sort_inequalities(response_text)
answer_text = extract_and_sort_inequalities(answer)
# print(response_text, answer_text)
return response_text == answer_text
elif question_type == 'operation' and rule_id == '18':
response_text = extract_text_from_brackets(response, 'clean')
answer = extract_inner_text_from_brackets(answer)
response_text = ''.join(sorted(re.sub(r'\W+', '', response_text)))
answer = ''.join(sorted(re.sub(r'\W+', '', answer)))
return response_text == answer
elif question_type == 'operation' and rule_id in {'23', '24', '25'}:
response_text = extract_text_from_brackets(response, 'clean')
if response_text is None:
return False
response_text = extract_numbers(response_text)
answer_text = extract_numbers(answer)
return response_text == answer_text
elif question_type == 'operation' and is_in_idx_ranges(idx, idx_ranges):
return compare_math_expressions(response, answer)
elif question_type == 'operation' and contains_or(answer):
return compare_multi_results(response, answer)
elif question_type == 'puzzle':
response_text = extract_inner_text_from_brackets(response)
answer = extract_inner_text_from_brackets(answer)
method = question_methods.get(rule_id)
if method:
return method(response_text, answer)
return method_general(response_text, answer)
else:
response_text = extract_text_from_brackets(response, 'clean')
return response_text == answer
def compute_one_mixed_question_pass_rate(idx,
question_list,
response_json,
base_path=None):
if response_json == 'NULL':
result_dict = {
'idx': idx,
'response': response_json,
'details': None,
'pass_rate': 0,
'is_correct': False
}
return result_dict
response_list = extract_all_responses_from_json(response_json)
correct_num = 0
results = []
for q_idx, question in enumerate(question_list):
category, question_idx = question.rsplit('_', 1)
question_content = load_json_or_jsonl_with_idx(base_path,
os.path.join(
category, 'sample'),
idx=question_idx)
answer = question_content['answer']
if q_idx >= len(response_list):
break
response = response_list[q_idx]
response_text = extract_text_from_brackets(response)
rule_id = question_content['rule_id']
is_correct = evaluate_response_vs_answer(response, answer, category,
rule_id, q_idx)
if is_correct:
correct_num += 1
results.append({
'question': question,
'response_text': response_text,
'answer': answer,
'is_correct': is_correct
})
pass_rate = correct_num / len(question_list)
question_correct = pass_rate == 1.0
result_dict = {
'idx': idx,
'response': response_json,
'details': results,
'pass_rate': pass_rate,
'is_correct': question_correct
}
return result_dict
def evaluate_responses(data, mode, base_path=None):
results = []
# Iterate over the values of the dictionary (numerical keys)
for key, record in data.items():
idx = key # Use the dictionary key as the "idx"
response = record.get('prediction', '')
question_type = record.get('category', '')
response_text = extract_text_from_brackets(response)
answer = record.get('gold', '')
rule_id = record.get('rule_id', '')
is_correct = evaluate_response_vs_answer(response, answer,
question_type, rule_id,
idx)
result_dict = {
'idx': idx,
'response': response,
'response_text': response_text,
'answer': answer,
'is_correct': is_correct
}
if question_type == 'counterfactual':
real_life_answer = record.get('real_life_answer', '')
is_real_life = evaluate_response_vs_answer(
response, real_life_answer, question_type, rule_id, idx)
result_dict['real_life_answer'] = real_life_answer
result_dict['is_real_life'] = is_real_life
if question_type == 'cipher' and mode == 'subquestions':
result_dict['type'] = record.get('type', '')
results.append(result_dict)
return results

View File

@ -403,6 +403,11 @@ DATASETS_MAPPING = {
"hf_id": "",
"local": "./data/OlympiadBench",
},
"opencompass/supergpqa": {
"ms_id": "",
"hf_id": "m-a-p/SuperGPQA",
"local": "./data/supergpqa",
},
}
DATASETS_URL = {