Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
support supergpqa
commit 4e40563462 (parent 73c80953c6)
examples/eval_supergpqa.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from mmengine import read_base

with read_base():
    # from opencompass.configs.datasets.supergpqa.supergpqa_mixed_gen_d00bdd import \
    #     supergpqa_mixed_datasets as mixed_datasets
    from opencompass.configs.datasets.supergpqa.supergpqa_single_0_shot_gen import \
        supergpqa_0shot_single_datasets as zero_shot_datasets
    # from opencompass.configs.datasets.supergpqa.supergpqa_single_3_shot_gen import \
    #     supergpqa_3shot_single_datasets as three_shot_datasets
    from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b import \
        models as hf_internlm2_5_7b

datasets = zero_shot_datasets
models = hf_internlm2_5_7b
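A quick way to sanity-check what this example resolves to, as a minimal sketch (assuming it is run from an OpenCompass checkout; Config.fromfile is the standard mmengine loader, and the printed fields are assumptions about the imported configs):

from mmengine.config import Config

cfg = Config.fromfile('examples/eval_supergpqa.py')
print(len(cfg.datasets))   # dataset entries pulled in by the import above
print(cfg.models[0].abbr)  # abbr of the HF InternLM2.5-7B entry, if defined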
@@ -0,0 +1,55 @@
from opencompass.datasets.supergpqa.supergpqa import SuperGPQADataset, SuperGPQAEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever

supergpqa_0shot_single_datasets = []

prompt_template = dict(
    type=PromptTemplate,
    template=dict(
        begin=[dict(role='HUMAN', prompt='')],
        round=[
            # Placeholder filled from the dataset's 'infer_prompt' column.
            dict(role='HUMAN', prompt='{infer_prompt}'),
        ],
    ),
)

# Reader configuration
reader_cfg = dict(
    input_columns=['infer_prompt'],
    output_column='answer_letter',
)

# Inference configuration
infer_cfg = dict(
    prompt_template=prompt_template,
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=1024),
)

# Evaluation configuration
eval_cfg = dict(
    evaluator=dict(type=SuperGPQAEvaluator),
    pred_role='BOT',
)

supergpqa_dataset = dict(
    type=SuperGPQADataset,
    abbr='supergpqa',
    path='opencompass/supergpqa',
    prompt_mode='zero-shot',
    reader_cfg=reader_cfg,
    infer_cfg=infer_cfg,
    eval_cfg=eval_cfg,
)

supergpqa_0shot_single_datasets.append(supergpqa_dataset)
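For orientation, the contract between reader_cfg and the template above, as a minimal sketch (field names mirror the config; the row dict is hypothetical):

row = {'infer_prompt': 'Question ...\nA) ...\nB) ...', 'answer_letter': 'A'}
# ZeroRetriever supplies no in-context examples; the single round simply
# substitutes the row's input column into the '{infer_prompt}' placeholder,
# and 'answer_letter' is the gold reference passed to the evaluator.
prompt = '{infer_prompt}'.format(infer_prompt=row['infer_prompt'])
assert prompt == row['infer_prompt']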
@@ -0,0 +1,347 @@
from opencompass.datasets.supergpqa.supergpqa import SuperGPQADataset, SuperGPQAEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever

categories = [
    "Power Systems and Automation", "Anesthesiology", "Oncology", "Group Theory",
    "Thermal Energy Engineering", "Emergency Medicine", "Systems Science",
    "Geometry and Topology", "Advanced Algebra", "Electrical Theory and New Technologies",
    "Engineering Thermophysics", "Operating Systems", "Guidance, Navigation and Control",
    "Harmony", "Marine Biology", "Pediatrics", "Road and Railway Engineering",
    "Information Management and Communication", "Water conservancy and Hydropower Engineering",
    "Veterinary Medicine", "Astronomical Observation and Technology", "Special Number Theory",
    "Philology and Bibliography", "Textile Materials Science", "Legal Theory and Legal History",
    "Education Economics, Management and Social Security", "Traditional Chinese Health Preservation",
    "Epidemiology and Health Statistics", "Pitch and Scales", "Economic History",
    "Marine Engineering", "Labor Economics", "Materials Processing Engineering",
    "Demography and Anthropology", "Preschool Education", "Music History, Education, and Technology",
    "Instrumentation and Performance", "Cryptography", "Mineralogy, Petrology, and Economic Geology",
    "Microbiology and Biochemical Pharmacy", "Poromechanics and Reservoir Physics",
    "Imaging and Nuclear Medicine", "Solid State Physics", "Microelectronics and Solid-State Electronics",
    "Zoology", "Food Biochemistry", "Traditional Chinese Pharmacy", "Neurology", "Hydrogeology",
    "Criminal Law", "Radiation Medicine", "Relativity", "Analytical Chemistry",
    "Signal and Information Processing", "Military Command and Information Systems",
    "Literary Theory", "Textile Chemistry and Dyeing Engineering", "Urban Infrastructure Engineering",
    "Stellar and Interstellar Evolution", "Geological Resources and Geological Engineering",
    "Pattern Recognition", "Engineering Fluid Mechanics", "Communication and Information Systems",
    "Architectural History", "Stochastic Processes", "Microbiology", "French Language and Literature",
    "Principles of Computer Organization", "Architectural Design and Theory",
    "Animal Rearing and Breeding", "Physical Oceanography", "Acoustics", "Organic Chemistry",
    "Refrigeration and Cryogenic Engineering", "Public Finance", "Dermatology and Venereology",
    "Religious Studies", "Discrete Mathematics", "Forest Cultivation and Genetic Breeding",
    "Vehicle Operation Engineering", "Physical Chemistry", "Nutrition and Food Hygiene",
    "Ship Mechanics and Design Principles", "Power Electronics and Electrical Drives",
    "Finance", "Pharmacology", "Environmental Engineering", "Ecology",
    "Aeronautical and Astronautical Science and Technology", "Agricultural Mechanization Engineering",
    "Computer Architecture", "Political Economy", "Principles of Seismic Exploration",
    "Elements of Chemical Reaction Engineering", "Digital Surveying and Remote Sensing Applications",
    "History and Theory of Journalism and Media Management", "Instrument Science and Technology",
    "Structural Engineering", "Computer Networks", "Power Machinery and Engineering",
    "Constitutional and Administrative Law", "Law and Social Governance", "Psychology",
    "Urban Planning and Design", "Thermodynamics and Statistical Physics",
    "Chemical Transport Engineering", "Environmental and Resource Protection",
    "Fluid Machinery and Engineering", "Cartography and Geographic Information Engineering",
    "Computational Mathematics", "Pathogen Biology", "Human Geography", "Theoretical Optics",
    "Solid Mechanics", "Electrochemistry", "Aquaculture", "Logic", "Mechatronic Engineering",
    "Modern and Contemporary Chinese Literature", "Operations Research and Cybernetics",
    "Circuits and Systems", "Internal Combustion Engineering", "Atomic and Molecular Physics",
    "Marine Chemistry", "Electromagnetic Field and Microwave Technology", "Rigid Body Mechanics",
    "Physiology", "Military Chemistry and Pyrotechnics", "Fundamentals of Dynamics and Control",
    "Control Theory and Control Engineering", "Historical Geography", "Physical Geography",
    "National and Defense Economics", "Polymer Physics",
    "Landscape Plants and Ornamental Horticulture", "Solar System Science",
    "Library and Archival Science", "Internal Medicine",
    "Physical Chemistry of Metallurgical Process", "Antenna and Radio Communication",
    "Genetics", "Graph Theory", "Principles of Metallurgy", "Bridge and Tunnel Engineering",
    "Combinatorial Mathematics", "Otorhinolaryngology", "Political Science", "Medicinal Chemistry",
    "Health Toxicology and Environmental Health", "Archaeology and Museology",
    "Geotechnical Engineering", "Land Resource Management and Administrative Management",
    "Thermodynamics", "Atmospheric Physics and Atmospheric Environment",
    "Broadcasting and Television Art", "Numerical Analysis", "Statistical Mechanics",
    "Mineral Processing Engineering", "Mathematical Analysis",
    "Philosophy of Science and Technology", "Western Economics", "Data Structures", "Fine Arts",
    "Economic Statistics", "Environmental Science", "Military Thought and History",
    "Drama and Opera Studies", "Film Studies", "High Voltage and Insulation Technology",
    "Military Law", "Wood Science and Technology", "Obstetrics and Gynecology",
    "Hydraulics and Hydrology", "Cell Biology", "Biochemistry and Molecular Biology",
    "Fluid Flow and Heat Transfer in Chemical Engineering", "Formal Languages",
    "Optoelectronic Technology", "Crop Science", "Fundamental Mathematics", "Immunology",
    "Surgery", "Ophthalmology", "Social Medicine and Health Management", "Industrial Economics",
    "Traffic Information Engineering and Control", "Traditional Chinese Medicine Theory",
    "Polymer Chemistry and Physics", "Maternal, Child and Adolescent Health",
    "Radiation Protection and Nuclear Technology Applications",
    "Food Processing and Storage Engineering", "Fluid Physics", "Materials Physics and Chemistry",
    "Pharmaceutical Analysis", "Semiconductor Physics", "Optical Fiber Communication", "Ethics",
    "Psychiatry and Mental Health", "Management Science and Engineering", "Number Theory",
    "Contract Law", "Inorganic Chemistry", "Design Arts", "Human Anatomy and Histology-Embryology",
    "Iron and Steel Metallurgy", "Dance Studies", "Structural Geology", "Special Education",
    "Musical Forms and Analysis", "Philosophical Aesthetics", "Astrophysics",
    "Manufacturing Automation", "Quantum Mechanics", "Probability and Statistics",
    "Military Logistics and Equipment", "Heat Transfer", "Classical Chinese Literature",
    "Information Management Science", "Cosmology", "Educational Technology and Principles",
    "Ordinary Differential Equations", "Underwater Acoustics",
    "Business and Accounting Management", "Dynamic Meteorology", "Military Management",
    "Journalism and News Practice", "Animal Nutrition and Feed Science", "Applied Optics",
    "Theoretical Fluid Mechanics", "Communication Principles", "Physical Education and Training",
    "Geodesy and Surveying Engineering", "Meteorology", "Sports Science and Medicine",
    "Solid Earth Geophysics", "Particle and Nuclear Physics", "International Law",
    "Oil and Gas Field Development and Storage & Transportation Engineering", "Basic Stomatology",
    "Agricultural Environment and Soil-Water Engineering", "Geochemistry", "Procedural Law",
    "Botany", "Fuzzy Mathematics", "Paleontology and Stratigraphy",
    "Sports Humanities and Sociology", "Civil and Commercial Law", "Electrodynamics",
    "Mining and Safety Engineering", "Mass Transport and Separation Process in Chemical Engineering",
    "Advanced Programming Languages", "Laser Technology", "Weapon Systems Science and Engineering",
    "Quantitative Economics", "Theoretical Mechanics", "Nursing and Rehabilitation Medicine",
    "Databases", "Pharmaceutics", "Space physics", "Functions of Real Variables",
    "Non-ferrous Metallurgy", "Theory of Curriculum and Instruction",
    "Clinical Laboratory Diagnostics", "Clinical Stomatology", "Literary History",
    "Tourism Management and Technological Economics Management", "Communication and Broadcasting",
    "Pathology and Pathophysiology", "Functions of Complex Variables", "World History",
    "Forest Engineering", "Forensic Medicine", "Linguistics and Applied Linguistics",
    "Social and Folklore Studies", "Computer Software and Theory", "Subatomic and Atomic Physics",
    "Biophysics", "Radiochemistry", "Russian Language and Literature", "International Trade",
    "Geriatric Medicine", "Composition", "Transportation Planning and Management",
    "Polynomials and Series Expansions", "Nuclear Energy and Reactor Technology",
]

supergpqa_0shot_single_datasets = []

for category in categories:
    prompt_template = dict(
        type=PromptTemplate,
        template=dict(
            begin=[dict(role='HUMAN', prompt='')],
            round=[
                # Placeholder filled from the dataset's 'infer_prompt' column.
                dict(role='HUMAN', prompt='{infer_prompt}'),
            ],
        ),
    )

    # Reader configuration
    reader_cfg = dict(
        input_columns=['infer_prompt'],
        output_column='answer_letter',
    )

    # Inference configuration
    infer_cfg = dict(
        prompt_template=prompt_template,
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer, max_out_len=1024),
    )

    # Evaluation configuration
    eval_cfg = dict(
        evaluator=dict(type=SuperGPQAEvaluator),
        pred_role='BOT',
    )

    supergpqa_dataset = dict(
        type=SuperGPQADataset,
        abbr=f'supergpqa_{category.replace(" ", "_")}',
        path='opencompass/supergpqa',
        prompt_mode='zero-shot',
        category=category,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
    supergpqa_0shot_single_datasets.append(supergpqa_dataset)
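The per-category abbr values built in this loop are exactly what the summarizer group below keys on; a minimal consistency sketch (not part of the commit):

expected = {'supergpqa_' + c.replace(' ', '_') for c in categories}
assert {d['abbr'] for d in supergpqa_0shot_single_datasets} == expected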
opencompass/configs/summarizers/groups/supergpqa.py (new file, 291 lines)
@@ -0,0 +1,291 @@
categories = [
    "Power Systems and Automation", "Anesthesiology", "Oncology", "Group Theory",
    "Thermal Energy Engineering", "Emergency Medicine", "Systems Science",
    "Geometry and Topology", "Advanced Algebra", "Electrical Theory and New Technologies",
    "Engineering Thermophysics", "Operating Systems", "Guidance, Navigation and Control",
    "Harmony", "Marine Biology", "Pediatrics", "Road and Railway Engineering",
    "Information Management and Communication", "Water conservancy and Hydropower Engineering",
    "Veterinary Medicine", "Astronomical Observation and Technology", "Special Number Theory",
    "Philology and Bibliography", "Textile Materials Science", "Legal Theory and Legal History",
    "Education Economics, Management and Social Security", "Traditional Chinese Health Preservation",
    "Epidemiology and Health Statistics", "Pitch and Scales", "Economic History",
    "Marine Engineering", "Labor Economics", "Materials Processing Engineering",
    "Demography and Anthropology", "Preschool Education", "Music History, Education, and Technology",
    "Instrumentation and Performance", "Cryptography", "Mineralogy, Petrology, and Economic Geology",
    "Microbiology and Biochemical Pharmacy", "Poromechanics and Reservoir Physics",
    "Imaging and Nuclear Medicine", "Solid State Physics", "Microelectronics and Solid-State Electronics",
    "Zoology", "Food Biochemistry", "Traditional Chinese Pharmacy", "Neurology", "Hydrogeology",
    "Criminal Law", "Radiation Medicine", "Relativity", "Analytical Chemistry",
    "Signal and Information Processing", "Military Command and Information Systems",
    "Literary Theory", "Textile Chemistry and Dyeing Engineering", "Urban Infrastructure Engineering",
    "Stellar and Interstellar Evolution", "Geological Resources and Geological Engineering",
    "Pattern Recognition", "Engineering Fluid Mechanics", "Communication and Information Systems",
    "Architectural History", "Stochastic Processes", "Microbiology", "French Language and Literature",
    "Principles of Computer Organization", "Architectural Design and Theory",
    "Animal Rearing and Breeding", "Physical Oceanography", "Acoustics", "Organic Chemistry",
    "Refrigeration and Cryogenic Engineering", "Public Finance", "Dermatology and Venereology",
    "Religious Studies", "Discrete Mathematics", "Forest Cultivation and Genetic Breeding",
    "Vehicle Operation Engineering", "Physical Chemistry", "Nutrition and Food Hygiene",
    "Ship Mechanics and Design Principles", "Power Electronics and Electrical Drives",
    "Finance", "Pharmacology", "Environmental Engineering", "Ecology",
    "Aeronautical and Astronautical Science and Technology", "Agricultural Mechanization Engineering",
    "Computer Architecture", "Political Economy", "Principles of Seismic Exploration",
    "Elements of Chemical Reaction Engineering", "Digital Surveying and Remote Sensing Applications",
    "History and Theory of Journalism and Media Management", "Instrument Science and Technology",
    "Structural Engineering", "Computer Networks", "Power Machinery and Engineering",
    "Constitutional and Administrative Law", "Law and Social Governance", "Psychology",
    "Urban Planning and Design", "Thermodynamics and Statistical Physics",
    "Chemical Transport Engineering", "Environmental and Resource Protection",
    "Fluid Machinery and Engineering", "Cartography and Geographic Information Engineering",
    "Computational Mathematics", "Pathogen Biology", "Human Geography", "Theoretical Optics",
    "Solid Mechanics", "Electrochemistry", "Aquaculture", "Logic", "Mechatronic Engineering",
    "Modern and Contemporary Chinese Literature", "Operations Research and Cybernetics",
    "Circuits and Systems", "Internal Combustion Engineering", "Atomic and Molecular Physics",
    "Marine Chemistry", "Electromagnetic Field and Microwave Technology", "Rigid Body Mechanics",
    "Physiology", "Military Chemistry and Pyrotechnics", "Fundamentals of Dynamics and Control",
    "Control Theory and Control Engineering", "Historical Geography", "Physical Geography",
    "National and Defense Economics", "Polymer Physics",
    "Landscape Plants and Ornamental Horticulture", "Solar System Science",
    "Library and Archival Science", "Internal Medicine",
    "Physical Chemistry of Metallurgical Process", "Antenna and Radio Communication",
    "Genetics", "Graph Theory", "Principles of Metallurgy", "Bridge and Tunnel Engineering",
    "Combinatorial Mathematics", "Otorhinolaryngology", "Political Science", "Medicinal Chemistry",
    "Health Toxicology and Environmental Health", "Archaeology and Museology",
    "Geotechnical Engineering", "Land Resource Management and Administrative Management",
    "Thermodynamics", "Atmospheric Physics and Atmospheric Environment",
    "Broadcasting and Television Art", "Numerical Analysis", "Statistical Mechanics",
    "Mineral Processing Engineering", "Mathematical Analysis",
    "Philosophy of Science and Technology", "Western Economics", "Data Structures", "Fine Arts",
    "Economic Statistics", "Environmental Science", "Military Thought and History",
    "Drama and Opera Studies", "Film Studies", "High Voltage and Insulation Technology",
    "Military Law", "Wood Science and Technology", "Obstetrics and Gynecology",
    "Hydraulics and Hydrology", "Cell Biology", "Biochemistry and Molecular Biology",
    "Fluid Flow and Heat Transfer in Chemical Engineering", "Formal Languages",
    "Optoelectronic Technology", "Crop Science", "Fundamental Mathematics", "Immunology",
    "Surgery", "Ophthalmology", "Social Medicine and Health Management", "Industrial Economics",
    "Traffic Information Engineering and Control", "Traditional Chinese Medicine Theory",
    "Polymer Chemistry and Physics", "Maternal, Child and Adolescent Health",
    "Radiation Protection and Nuclear Technology Applications",
    "Food Processing and Storage Engineering", "Fluid Physics", "Materials Physics and Chemistry",
    "Pharmaceutical Analysis", "Semiconductor Physics", "Optical Fiber Communication", "Ethics",
    "Psychiatry and Mental Health", "Management Science and Engineering", "Number Theory",
    "Contract Law", "Inorganic Chemistry", "Design Arts", "Human Anatomy and Histology-Embryology",
    "Iron and Steel Metallurgy", "Dance Studies", "Structural Geology", "Special Education",
    "Musical Forms and Analysis", "Philosophical Aesthetics", "Astrophysics",
    "Manufacturing Automation", "Quantum Mechanics", "Probability and Statistics",
    "Military Logistics and Equipment", "Heat Transfer", "Classical Chinese Literature",
    "Information Management Science", "Cosmology", "Educational Technology and Principles",
    "Ordinary Differential Equations", "Underwater Acoustics",
    "Business and Accounting Management", "Dynamic Meteorology", "Military Management",
    "Journalism and News Practice", "Animal Nutrition and Feed Science", "Applied Optics",
    "Theoretical Fluid Mechanics", "Communication Principles", "Physical Education and Training",
    "Geodesy and Surveying Engineering", "Meteorology", "Sports Science and Medicine",
    "Solid Earth Geophysics", "Particle and Nuclear Physics", "International Law",
    "Oil and Gas Field Development and Storage & Transportation Engineering", "Basic Stomatology",
    "Agricultural Environment and Soil-Water Engineering", "Geochemistry", "Procedural Law",
    "Botany", "Fuzzy Mathematics", "Paleontology and Stratigraphy",
    "Sports Humanities and Sociology", "Civil and Commercial Law", "Electrodynamics",
    "Mining and Safety Engineering", "Mass Transport and Separation Process in Chemical Engineering",
    "Advanced Programming Languages", "Laser Technology", "Weapon Systems Science and Engineering",
    "Quantitative Economics", "Theoretical Mechanics", "Nursing and Rehabilitation Medicine",
    "Databases", "Pharmaceutics", "Space physics", "Functions of Real Variables",
    "Non-ferrous Metallurgy", "Theory of Curriculum and Instruction",
    "Clinical Laboratory Diagnostics", "Clinical Stomatology", "Literary History",
    "Tourism Management and Technological Economics Management", "Communication and Broadcasting",
    "Pathology and Pathophysiology", "Functions of Complex Variables", "World History",
    "Forest Engineering", "Forensic Medicine", "Linguistics and Applied Linguistics",
    "Social and Folklore Studies", "Computer Software and Theory", "Subatomic and Atomic Physics",
    "Biophysics", "Radiochemistry", "Russian Language and Literature", "International Trade",
    "Geriatric Medicine", "Composition", "Transportation Planning and Management",
    "Polynomials and Series Expansions", "Nuclear Energy and Reactor Technology",
]

supergpqa_summary_groups = [
    {'name': 'supergpqa',
     'subsets': ['supergpqa_' + c.replace(' ', '_') for c in categories]},
]
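Downstream, the summarizer reports one aggregate row named 'supergpqa' over these subsets; by default OpenCompass takes a plain mean over the listed subsets, roughly as in this sketch (the score table is hypothetical):

scores = {s: 0.0 for s in supergpqa_summary_groups[0]['subsets']}  # abbr -> accuracy
supergpqa_mean = sum(scores.values()) / len(scores)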
opencompass/configs/summarizers/supergpqa.py (new file, 296 lines)
@@ -0,0 +1,296 @@
from mmengine.config import read_base

with read_base():
    from .groups.supergpqa import supergpqa_summary_groups

summarizer = dict(
    dataset_abbrs=[
        "supergpqa",
        "supergpqa_Procedural_Law", "supergpqa_Microbiology", "supergpqa_World_History",
        "supergpqa_Civil_and_Commercial_Law", "supergpqa_Relativity",
        "supergpqa_Discrete_Mathematics", "supergpqa_Laser_Technology",
        "supergpqa_Power_Machinery_and_Engineering", "supergpqa_Geotechnical_Engineering",
        "supergpqa_Mineralogy,_Petrology,_and_Economic_Geology",
        "supergpqa_Fluid_Flow_and_Heat_Transfer_in_Chemical_Engineering",
        "supergpqa_Composition", "supergpqa_Biophysics", "supergpqa_Geriatric_Medicine",
        "supergpqa_Cell_Biology", "supergpqa_Underwater_Acoustics", "supergpqa_Political_Science",
        "supergpqa_Atomic_and_Molecular_Physics", "supergpqa_Industrial_Economics",
        "supergpqa_Marine_Chemistry", "supergpqa_Ophthalmology", "supergpqa_Geochemistry",
        "supergpqa_Anesthesiology", "supergpqa_Sports_Science_and_Medicine",
        "supergpqa_Forest_Cultivation_and_Genetic_Breeding",
        "supergpqa_Philology_and_Bibliography", "supergpqa_Cryptography",
        "supergpqa_Road_and_Railway_Engineering", "supergpqa_Literary_History",
        "supergpqa_Mining_and_Safety_Engineering", "supergpqa_Group_Theory",
        "supergpqa_Crop_Science", "supergpqa_Food_Biochemistry",
        "supergpqa_Textile_Materials_Science", "supergpqa_Fundamental_Mathematics",
        "supergpqa_Microelectronics_and_Solid-State_Electronics", "supergpqa_International_Law",
        "supergpqa_Agricultural_Environment_and_Soil-Water_Engineering",
        "supergpqa_Environmental_Science", "supergpqa_Urban_Infrastructure_Engineering",
        "supergpqa_Solid_State_Physics", "supergpqa_Mechatronic_Engineering",
        "supergpqa_Economic_History", "supergpqa_Power_Electronics_and_Electrical_Drives",
        "supergpqa_History_and_Theory_of_Journalism_and_Media_Management",
        "supergpqa_Neurology", "supergpqa_Computer_Networks",
        "supergpqa_Animal_Nutrition_and_Feed_Science", "supergpqa_Marine_Engineering",
        "supergpqa_Materials_Physics_and_Chemistry",
        "supergpqa_Business_and_Accounting_Management", "supergpqa_Basic_Stomatology",
        "supergpqa_Space_physics", "supergpqa_Transportation_Planning_and_Management",
        "supergpqa_Information_Management_and_Communication", "supergpqa_Quantitative_Economics",
        "supergpqa_Elements_of_Chemical_Reaction_Engineering",
        "supergpqa_Library_and_Archival_Science", "supergpqa_Electrodynamics",
        "supergpqa_Fluid_Machinery_and_Engineering", "supergpqa_Dynamic_Meteorology",
        "supergpqa_Functions_of_Real_Variables", "supergpqa_Pharmacology",
        "supergpqa_Communication_Principles", "supergpqa_Communication_and_Broadcasting",
        "supergpqa_Musical_Forms_and_Analysis",
        "supergpqa_Cartography_and_Geographic_Information_Engineering",
        "supergpqa_Maternal,_Child_and_Adolescent_Health", "supergpqa_Clinical_Stomatology",
        "supergpqa_Data_Structures", "supergpqa_Optoelectronic_Technology",
        "supergpqa_Physiology", "supergpqa_Thermodynamics_and_Statistical_Physics",
        "supergpqa_Pediatrics", "supergpqa_Geodesy_and_Surveying_Engineering",
        "supergpqa_Theoretical_Mechanics", "supergpqa_Hydraulics_and_Hydrology",
        "supergpqa_International_Trade", "supergpqa_Military_Chemistry_and_Pyrotechnics",
        "supergpqa_Finance", "supergpqa_Psychiatry_and_Mental_Health",
        "supergpqa_Fundamentals_of_Dynamics_and_Control",
        "supergpqa_Sports_Humanities_and_Sociology", "supergpqa_Harmony",
        "supergpqa_Control_Theory_and_Control_Engineering", "supergpqa_Surgery",
        "supergpqa_Analytical_Chemistry", "supergpqa_Political_Economy",
        "supergpqa_Theory_of_Curriculum_and_Instruction",
        "supergpqa_High_Voltage_and_Insulation_Technology", "supergpqa_Numerical_Analysis",
        "supergpqa_Physical_Geography", "supergpqa_Physical_Education_and_Training",
        "supergpqa_Applied_Optics", "supergpqa_Mathematical_Analysis",
        "supergpqa_Advanced_Programming_Languages", "supergpqa_Western_Economics",
        "supergpqa_Organic_Chemistry", "supergpqa_French_Language_and_Literature",
        "supergpqa_Urban_Planning_and_Design", "supergpqa_Polynomials_and_Series_Expansions",
        "supergpqa_Functions_of_Complex_Variables", "supergpqa_Advanced_Algebra",
        "supergpqa_Operating_Systems", "supergpqa_Internal_Combustion_Engineering",
        "supergpqa_Food_Processing_and_Storage_Engineering",
        "supergpqa_Educational_Technology_and_Principles", "supergpqa_Acoustics",
        "supergpqa_Quantum_Mechanics", "supergpqa_Iron_and_Steel_Metallurgy",
        "supergpqa_Land_Resource_Management_and_Administrative_Management",
        "supergpqa_Fuzzy_Mathematics", "supergpqa_Special_Education",
        "supergpqa_Solid_Mechanics", "supergpqa_Zoology",
        "supergpqa_Demography_and_Anthropology",
        "supergpqa_Tourism_Management_and_Technological_Economics_Management",
        "supergpqa_Theoretical_Optics", "supergpqa_Genetics",
        "supergpqa_Constitutional_and_Administrative_Law", "supergpqa_Structural_Engineering",
        "supergpqa_Principles_of_Metallurgy", "supergpqa_Medicinal_Chemistry",
        "supergpqa_Electromagnetic_Field_and_Microwave_Technology",
        "supergpqa_Clinical_Laboratory_Diagnostics", "supergpqa_Theoretical_Fluid_Mechanics",
        "supergpqa_Pitch_and_Scales", "supergpqa_Stochastic_Processes", "supergpqa_Ethics",
        "supergpqa_Circuits_and_Systems", "supergpqa_Engineering_Thermophysics",
        "supergpqa_Landscape_Plants_and_Ornamental_Horticulture", "supergpqa_Polymer_Physics",
        "supergpqa_Wood_Science_and_Technology",
        "supergpqa_Biochemistry_and_Molecular_Biology", "supergpqa_Preschool_Education",
        "supergpqa_Psychology", "supergpqa_Traditional_Chinese_Health_Preservation",
        "supergpqa_Modern_and_Contemporary_Chinese_Literature", "supergpqa_Religious_Studies",
        "supergpqa_Subatomic_and_Atomic_Physics", "supergpqa_Human_Geography",
        "supergpqa_Water_conservancy_and_Hydropower_Engineering",
        "supergpqa_Thermal_Energy_Engineering", "supergpqa_Immunology",
        "supergpqa_Communication_and_Information_Systems", "supergpqa_Meteorology",
        "supergpqa_Bridge_and_Tunnel_Engineering", "supergpqa_Military_Management",
        "supergpqa_Russian_Language_and_Literature", "supergpqa_Particle_and_Nuclear_Physics",
        "supergpqa_Rigid_Body_Mechanics", "supergpqa_Nuclear_Energy_and_Reactor_Technology",
        "supergpqa_Oncology", "supergpqa_Public_Finance",
        "supergpqa_Classical_Chinese_Literature", "supergpqa_Ecology",
        "supergpqa_Principles_of_Computer_Organization", "supergpqa_Pattern_Recognition",
        "supergpqa_Databases", "supergpqa_Ordinary_Differential_Equations",
        "supergpqa_Electrochemistry", "supergpqa_Traditional_Chinese_Pharmacy",
        "supergpqa_Dance_Studies", "supergpqa_Pharmaceutical_Analysis",
        "supergpqa_Otorhinolaryngology", "supergpqa_Principles_of_Seismic_Exploration",
        "supergpqa_Physical_Chemistry", "supergpqa_Special_Number_Theory",
        "supergpqa_Astrophysics", "supergpqa_Physical_Oceanography",
        "supergpqa_Instrumentation_and_Performance", "supergpqa_Military_Law",
        "supergpqa_Signal_and_Information_Processing", "supergpqa_Thermodynamics",
        "supergpqa_Architectural_Design_and_Theory", "supergpqa_Non-ferrous_Metallurgy",
        "supergpqa_Internal_Medicine", "supergpqa_Film_Studies", "supergpqa_Fluid_Physics",
        "supergpqa_Refrigeration_and_Cryogenic_Engineering",
        "supergpqa_Broadcasting_and_Television_Art",
        "supergpqa_Social_Medicine_and_Health_Management",
        "supergpqa_Military_Logistics_and_Equipment", "supergpqa_Criminal_Law",
        "supergpqa_Electrical_Theory_and_New_Technologies",
        "supergpqa_Nutrition_and_Food_Hygiene", "supergpqa_Literary_Theory",
        "supergpqa_Instrument_Science_and_Technology",
        "supergpqa_Legal_Theory_and_Legal_History", "supergpqa_Computer_Architecture",
        "supergpqa_Chemical_Transport_Engineering", "supergpqa_Military_Thought_and_History",
        "supergpqa_Archaeology_and_Museology", "supergpqa_Architectural_History",
        "supergpqa_Microbiology_and_Biochemical_Pharmacy",
        "supergpqa_Philosophy_of_Science_and_Technology", "supergpqa_Labor_Economics",
        "supergpqa_Dermatology_and_Venereology", "supergpqa_Materials_Processing_Engineering",
        "supergpqa_Human_Anatomy_and_Histology-Embryology",
        "supergpqa_Optical_Fiber_Communication", "supergpqa_Journalism_and_News_Practice",
        "supergpqa_Emergency_Medicine", "supergpqa_Veterinary_Medicine",
        "supergpqa_Heat_Transfer", "supergpqa_Information_Management_Science",
        "supergpqa_Physical_Chemistry_of_Metallurgical_Process", "supergpqa_Radiochemistry",
        "supergpqa_Guidance,_Navigation_and_Control", "supergpqa_Solid_Earth_Geophysics",
        "supergpqa_Systems_Science", "supergpqa_Weapon_Systems_Science_and_Engineering",
        "supergpqa_Manufacturing_Automation", "supergpqa_Engineering_Fluid_Mechanics",
        "supergpqa_Mineral_Processing_Engineering", "supergpqa_Animal_Rearing_and_Breeding",
        "supergpqa_Philosophical_Aesthetics", "supergpqa_Solar_System_Science",
        "supergpqa_Antenna_and_Radio_Communication", "supergpqa_Computational_Mathematics",
        "supergpqa_Health_Toxicology_and_Environmental_Health", "supergpqa_Design_Arts",
        "supergpqa_Computer_Software_and_Theory", "supergpqa_Aquaculture",
        "supergpqa_Nursing_and_Rehabilitation_Medicine", "supergpqa_Inorganic_Chemistry",
        "supergpqa_Traffic_Information_Engineering_and_Control", "supergpqa_Botany",
        "supergpqa_Number_Theory", "supergpqa_Hydrogeology", "supergpqa_Marine_Biology",
        "supergpqa_Law_and_Social_Governance", "supergpqa_Contract_Law",
        "supergpqa_Vehicle_Operation_Engineering",
        "supergpqa_Aeronautical_and_Astronautical_Science_and_Technology",
        "supergpqa_Poromechanics_and_Reservoir_Physics", "supergpqa_Pathogen_Biology",
        "supergpqa_Power_Systems_and_Automation",
        "supergpqa_Epidemiology_and_Health_Statistics", "supergpqa_Drama_and_Opera_Studies",
        "supergpqa_Environmental_Engineering", "supergpqa_Polymer_Chemistry_and_Physics",
        "supergpqa_Digital_Surveying_and_Remote_Sensing_Applications",
        "supergpqa_Atmospheric_Physics_and_Atmospheric_Environment",
        "supergpqa_Education_Economics,_Management_and_Social_Security",
        "supergpqa_Probability_and_Statistics", "supergpqa_Geometry_and_Topology",
        "supergpqa_Linguistics_and_Applied_Linguistics",
        "supergpqa_Astronomical_Observation_and_Technology", "supergpqa_Forensic_Medicine",
        "supergpqa_Fine_Arts", "supergpqa_Paleontology_and_Stratigraphy",
        "supergpqa_Management_Science_and_Engineering", "supergpqa_Logic",
        "supergpqa_Agricultural_Mechanization_Engineering",
        "supergpqa_Traditional_Chinese_Medicine_Theory", "supergpqa_Obstetrics_and_Gynecology",
        "supergpqa_Ship_Mechanics_and_Design_Principles", "supergpqa_Statistical_Mechanics",
        "supergpqa_Combinatorial_Mathematics",
        "supergpqa_Mass_Transport_and_Separation_Process_in_Chemical_Engineering",
        "supergpqa_Economic_Statistics", "supergpqa_Operations_Research_and_Cybernetics",
        "supergpqa_Formal_Languages",
        "supergpqa_Oil_and_Gas_Field_Development_and_Storage_&_Transportation_Engineering",
        "supergpqa_Environmental_and_Resource_Protection", "supergpqa_Structural_Geology",
        "supergpqa_Semiconductor_Physics", "supergpqa_Social_and_Folklore_Studies",
        "supergpqa_Music_History,_Education,_and_Technology",
        "supergpqa_Radiation_Protection_and_Nuclear_Technology_Applications",
        "supergpqa_Pathology_and_Pathophysiology",
        "supergpqa_Textile_Chemistry_and_Dyeing_Engineering",
        "supergpqa_Military_Command_and_Information_Systems", "supergpqa_Forest_Engineering",
        "supergpqa_Graph_Theory", "supergpqa_Radiation_Medicine",
        "supergpqa_Geological_Resources_and_Geological_Engineering",
        "supergpqa_Historical_Geography", "supergpqa_Cosmology", "supergpqa_Pharmaceutics",
        "supergpqa_National_and_Defense_Economics", "supergpqa_Imaging_and_Nuclear_Medicine",
        "supergpqa_Stellar_and_Interstellar_Evolution",
    ],
    summary_groups=sum(
        [v for k, v in locals().items() if k.endswith('_summary_groups')], []),
)
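The summary_groups line is a config idiom: it concatenates every list in the module namespace whose name ends with '_summary_groups'. A standalone sketch of the same trick (variable names are illustrative):

foo_summary_groups = [{'name': 'foo', 'subsets': ['a', 'b']}]
bar_summary_groups = [{'name': 'bar', 'subsets': ['c']}]
merged = sum([v for k, v in locals().items() if k.endswith('_summary_groups')], [])
assert len(merged) == 2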
@@ -146,3 +146,4 @@ from .xcopa import *  # noqa: F401, F403
from .xiezhi import XiezhiDataset, XiezhiRetriever  # noqa: F401, F403
from .xlsum import *  # noqa: F401, F403
from .xsum import *  # noqa: F401, F403
from .supergpqa import *  # noqa: F401, F403
opencompass/datasets/supergpqa/__init__.py (new file, empty)
opencompass/datasets/supergpqa/supergpqa.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import os

from datasets import Dataset, load_dataset

from opencompass.datasets.supergpqa.supergpqa_eval import (
    extract_option_content, extract_option_labels)
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path

from ..base import BaseDataset


def _parse(item, template, prompt_mode):
    # Render the question plus lettered options ("A) ...", "B) ...") and
    # substitute the result into the YAML prompt template.
    prompt_format = [
        item['question'] + '\n' + '\n'.join(
            f'{chr(65 + i)}) {option}'
            for i, option in enumerate(item['options']))
    ]
    item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
    item['prompt_mode'] = prompt_mode
    return item


@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):

    @staticmethod
    def load(path: str, prompt_mode: str, category: str, **kwargs):
        path = get_data_path(path)
        dataset = load_dataset(path, split='train')
        dataset = dataset.filter(lambda x: x['subfield'] == category)

        # Pick the prompt template for the requested mode.
        template_path = None
        if prompt_mode == 'zero-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/zero-shot.yaml')
        elif prompt_mode == 'five-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/five-shot.yaml')
        try:
            template = load_yaml(template_path)
        except FileNotFoundError:
            print(f'[ERROR] Missing prompt template: {template_path}')
            return Dataset.from_list([])

        dataset = dataset.map(lambda item: _parse(item, template, prompt_mode))
        return dataset


@ICL_EVALUATORS.register_module()
class SuperGPQAEvaluator(BaseEvaluator):

    def __init__(self):
        super().__init__()

    def score(self, predictions, references, test_set):
        mode = test_set[0]['prompt_mode']
        acc = 0
        count = 0
        err = 0
        miss = 0
        acc_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        count_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        stats = {'discipline': {}, 'field': {}, 'subfield': {}}

        for i, sample in enumerate(test_set):
            prediction = predictions[i]
            gold = references[i]
            if mode == 'zero-shot':
                predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                if predict is None:
                    # Fall back to matching the option text, then map it back
                    # to its letter.
                    predict = extract_option_content(prediction, sample['options'])
                    predict = chr(sample['options'].index(predict) + 65) if predict else None
                sample['extracted_answer'] = predict
            elif mode == 'five-shot':
                # Anything after the first 'Question:' is the model continuing
                # the few-shot pattern, so parse only the head of the output.
                response = prediction.split('Question:')[0]
                predict = extract_option_labels(response, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(response, sample['options'])
                    predict = chr(sample['options'].index(predict) + 65) if predict else None
                if predict is None:
                    # Last resort: scan the full, unsplit prediction.
                    predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                    if predict is None:
                        predict = extract_option_content(prediction, sample['options'])
                        predict = chr(sample['options'].index(predict) + 65) if predict else None
                sample['extracted_answer'] = predict

            discipline = sample.get('discipline', 'unknown')
            field = sample.get('field', 'unknown')
            subfield = sample.get('subfield', 'unknown')
            difficulty = sample.get('difficulty', 'unknown')

            for level, key in [
                ('discipline', discipline),
                ('field', f'{discipline}/{field}'),
                ('subfield', f'{discipline}/{field}/{subfield}'),
            ]:
                if key not in stats[level]:
                    stats[level][key] = {
                        'correct': 0,
                        'total': 0,
                        'miss': 0,
                        'error': 0,
                        'discipline': discipline,
                        'field': field,
                        'subfield': subfield,
                        'difficulty': {
                            'easy': {'correct': 0, 'total': 0},
                            'middle': {'correct': 0, 'total': 0},
                            'hard': {'correct': 0, 'total': 0},
                        },
                    }
                stats[level][key]['total'] += 1
                stats[level][key]['difficulty'][difficulty]['total'] += 1

            answer_letter = sample['answer_letter']
            assert answer_letter == gold
            if predict and answer_letter == predict:
                acc += 1
                acc_difficulty[difficulty] += 1
                sample['status'] = 'correct'
                stats[level][key]['correct'] += 1
                stats[level][key]['difficulty'][difficulty]['correct'] += 1
            elif predict is None or predict == '':
                miss += 1
                sample['status'] = 'miss'
                stats[level][key]['miss'] += 1
            elif predict == 'error':
                err += 1
                sample['status'] = 'error'
                stats[level][key]['error'] += 1
            else:
                sample['status'] = 'incorrect'
            count += 1
            count_difficulty[difficulty] += 1

        return {
            'accuracy': acc / count if count > 0 else 0,
            # 'error_rate': err / count if count > 0 else 0,
            # 'miss_rate': miss / count if count > 0 else 0,
            # 'hard_accuracy': acc_difficulty['hard'] / count_difficulty['hard'] if count_difficulty['hard'] > 0 else 0,
            # 'middle_accuracy': acc_difficulty['middle'] / count_difficulty['middle'] if count_difficulty['middle'] > 0 else 0,
            # 'easy_accuracy': acc_difficulty['easy'] / count_difficulty['easy'] if count_difficulty['easy'] > 0 else 0,
        }
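Worked example of the prompt body _parse builds before the YAML template wraps it (the item dict is hypothetical):

item = {'question': 'What is 2 + 2?', 'options': ['3', '4', '5']}
body = item['question'] + '\n' + '\n'.join(
    f'{chr(65 + i)}) {option}' for i, option in enumerate(item['options']))
# body == 'What is 2 + 2?\nA) 3\nB) 4\nC) 5'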
@@ -0,0 +1,17 @@
response_key: 'response'
error_key: 'error'
id_key:
  - 'uuid'
prompt_key: 'prompt'

history_key: 'history'
status_key: 'status'

save_prompt: True
max_tokens: 4096
temperature: 0.0

max_rounds: 30
BoN: 32
@@ -0,0 +1,17 @@
response_key: 'response'
error_key: 'error'
id_key:
  - 'uuid'
prompt_key: 'prompt'

history_key: 'history'
status_key: 'status'

save_prompt: True
max_tokens: 32768
temperature: 0.0

max_rounds: 30
BoN: 32
@@ -0,0 +1,51 @@
import yaml


class ConfigWrapper:

    def __init__(self, config_path):
        self._config = {}
        with open(config_path, 'r') as file:
            self._config = yaml.safe_load(file)
        # Mirror every config key as an attribute for convenient access.
        for key, value in self._config.items():
            setattr(self, key, value)

    def __setattr__(self, key, value):
        if key.startswith('_'):
            super().__setattr__(key, value)
        else:
            self._config[key] = value
            super().__setattr__(key, value)

    def __getattr__(self, key):
        if key in self._config:
            return self._config[key]
        raise AttributeError(f"'ConfigWrapper' object has no attribute '{key}'")

    def get_id(self, data):
        # id_key may be a single field name or a list of fields to join.
        if isinstance(self._config.get('id_key'), str):
            return data.get(self._config.get('id_key'), None)
        elif isinstance(self._config.get('id_key'), list):
            return '_'.join(
                str(data[key]) for key in self._config.get('id_key') if key in data)

    def print_all_keys(self):
        print('config keys:')
        for key, value in self._config.items():
            print(f'  - {key}: {value}')


config_wrapper = None


def initialize_config(config_path):
    global config_wrapper
    config_wrapper = ConfigWrapper(config_path)


def get_config_wrapper():
    global config_wrapper
    if config_wrapper is None:
        raise RuntimeError('ConfigWrapper not initialized. Call initialize_config first.')
    return config_wrapper


if __name__ == '__main__':
    config_path = 'config/config.yaml'
    initialize_config(config_path)
    data = {
        'idx': '50',
        'step': 21,
        'question': 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\nPlease provide the decrypted answer, encapsulated in double square brackets. For example, the format should be: [[decrypted answer]].',
        'answer': '[[P]]',
        'category': 'Decryption',
        'rule_id': '23',
        'input': 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
        'steps_num': 23,
        'description': 'For a number c=228 in the ciphertext:\nCalculate z = c^e mod n. Here ^ means multiplication.\nz is 80.\nBased on the decimal number represented by z, use the ascii code to find the corresponding letter as the plaintext letter p.\nPlease give the letter p in [[...]] format.\n',
        'atom': 80,
    }
    print(config_wrapper.get_id(data))
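Usage sketch for ConfigWrapper with the YAML configs above (the path is hypothetical; with a list-valued id_key, the id is the '_'-joined field values):

initialize_config('config/config.yaml')
cw = get_config_wrapper()
print(cw.response_key)             # 'response'
print(cw.get_id({'uuid': 'abc'}))  # 'abc'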
@@ -0,0 +1,91 @@
prompt_format:
  - |
    Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.

    Question:
    A refracting telescope consists of two converging lenses separated by 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification of the telescope is
    A) 10
    B) 40
    C) 6
    D) 25
    E) 15
    F) 50
    G) 30
    H) 4
    I) 5
    J) 20

    Answer: Let's think step by step. In a refracting telescope, if both lenses are converging, the focus of both lenses must be between the two lenses, and thus the focal lengths of the two lenses must add up to their separation. Since the focal length of one lens is 20 cm, the focal length of the other must be 80 cm. The magnification is the ratio of these two focal lengths, or 4.
    Answer: H.

    Question:
    Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?
    A) 1000 times more
    B) 50 times more
    C) 5000 times more
    D) 500 times more
    E) 10000 times more
    F) 20000 times more
    G) 2000 times more
    H) 100 times more
    I) 10 times more
    J) N/A

    Answer: Let's think step by step. The amount of light a telescope can gather compared to the human eye is proportional to the area of its apertures. The area of a circle is given by the formula $A = \pi \left(\frac{{D}}{{2}}\right)^2$, where $D$ is the diameter. Therefore, the relative light-gathering power is calculated as:
    \[
    \frac{{\left(\frac{{50 \text{{ cm}}}}{{2}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{2}}\right)^2}} = \frac{{\left(\frac{{50 \text{{ cm}}}}{{0.1 \text{{ cm}}}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{0.1 \text{{ cm}}}}\right)^2}} = \frac{{500^2}}{{5^2}} = 10000.
    \]
    Answer: E.

    Question:
    Where do most short-period comets come from and how do we know?
    A) The Kuiper belt; short period comets tend to be in the plane of the solar system like the Kuiper belt.
    B) The asteroid belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the asteroid belt.
    C) The asteroid belt; short period comets tend to be in the plane of the solar system just like the asteroid belt.
    D) The Oort cloud; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the Oort cloud.
    E) The Oort Cloud; short period comets tend to come from random directions indicating a spherical distribution of comets called the Oort Cloud.
    F) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.
    G) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt.

    Answer: Let's think step by step. Most short-period comets originate from the Kuiper belt. This is deduced from the observation that these comets tend to follow orbits that lie in the plane of the solar system, similar to the distribution of objects in the Kuiper belt itself. Thus, the alignment of these cometary orbits with the ecliptic plane points to their Kuiper belt origin.
    Answer: A.

    Question:
    Colors in a soap bubble result from light
    A) dispersion
    B) deflection
    C) refraction
    D) reflection
    E) interference
    F) converted to a different frequency
    G) polarization
    H) absorption
    I) diffraction
    J) transmission

    Answer: Let's think step by step. The colorful patterns observed in a soap bubble are caused by the phenomenon of light interference. This occurs when light waves bounce between the two surfaces of the soap film, combining constructively or destructively based on their phase differences and the varying thickness of the film. These interactions result in vibrant color patterns due to variations in the intensity of different wavelengths of light.
    Answer: E.

    Question:
    A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
    A) 240 W
    B) 120 W
    C) 10 W
    D) 480 W
    E) 360 W
    F) 200 W
    G) 30 W
    H) 150 W
    I) 60 W
    J) 300 W

    Answer: Let's think step by step. The rate of energy usage, known as power, in an electrical circuit is calculated by the product of voltage and current. For a microwave oven connected to a 120 V outlet and drawing a current of 2 amps, the power consumption can be calculated as follows:
    \[
    \text{{Power}} = \text{{Voltage}} \times \text{{Current}} = 120 \, \text{{V}} \times 2 \, \text{{A}} = 240 \, \text{{W}}.
    \]
    Therefore, the microwave oven uses energy at a rate of 240 watts.
    Answer: A.

    Question:
    {}

    Answer: Let's think step by step.
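Note the doubled braces in the LaTeX passages above: _parse fills this template with str.format, so literal '{' and '}' must be written '{{' and '}}', and only the bare trailing '{}' is a substitution slot. A minimal sketch:

template = '$A = \\pi \\left(\\frac{{D}}{{2}}\\right)^2$\n\nQuestion:\n{}'
print(template.format('Colors in a soap bubble result from light ...'))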
@@ -0,0 +1,23 @@
initial_prompt_0:
  - |
    Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.

    {}

initial_prompt_1:
  - |
    You are a helpful assistant. Answer the given multiple-choice question. Only one option is correct. The last line of your response should be in the format 'The correct answer is: $LETTER', where LETTER is one of A, B, C, D, E, F, G, H, I, or J.

    {}

initial_prompt_2:
  - |
    Select the correct answer for the following multiple-choice question. There is only one valid choice. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.

    {}

initial_prompt_3:
  - |
    Review the following multiple-choice question and choose the one correct answer. Ensure that your response concludes with a line exactly formatted as 'The correct answer is: $LETTER', where LETTER represents one of A, B, C, D, E, F, G, H, I, or J.

    {}
@@ -0,0 +1,5 @@
prompt_format:
  - |
    Answer the following multiple choice question about {}. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.

    {}
@@ -0,0 +1,5 @@
prompt_format:
  - |
    Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.

    {}
633
opencompass/datasets/supergpqa/supergpqa_eval.py
Normal file
633
opencompass/datasets/supergpqa/supergpqa_eval.py
Normal file
@ -0,0 +1,633 @@
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
import os
|
||||
from prettytable import PrettyTable
|
||||
import pandas as pd
|
||||
# from openpyxl.styles import PatternFill, Font, Alignment
|
||||
from tqdm import tqdm
|
||||
import timeout_decorator
|
||||
import multiprocessing
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
@timeout_decorator.timeout(5) # 5 seconds timeout
|
||||
def safe_regex_search(pattern, text, flags=0):
|
||||
try:
|
||||
return re.search(pattern, text, flags)
|
||||
except timeout_decorator.TimeoutError:
|
||||
print(f"Regex match timeout: pattern={pattern}, text={text[:100]}...")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Regex match error: {str(e)}")
|
||||
return None
|
||||
|
||||
def extract_option_labels(text, options='ABCDEFGHIJ'):
    if not isinstance(text, str) or not isinstance(options, str):
        return 'error'

    text = text.rstrip()
    last_line = text.split('\n')[-1]

    option_str = ''.join([chr(65 + i) for i in range(len(options))]) if options else 'ABCDEFGHIJ'

    patterns = [
        # e.g. "The final answer to this question is: A."
        #      "The best option is $\boxed{B}:"
        #      "The correct answer is (C)."
        f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is?:?\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',

        # e.g. "ANSWER: A"
        #      "Answer: $\boxed{B}."
        #      "ANSWER: (C):"
        f'(?i:Answer)[\*\s]*:\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',

        # e.g. "A"
        #      "$\boxed{B}$"
        #      "(C)."
        #      "[D]:"
        f'^[^\w\r\n]*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
    ]

    for pattern in patterns:
        match = safe_regex_search(pattern, last_line, re.IGNORECASE)
        if match:
            return match.group(1)

    for pattern in patterns:
        match = safe_regex_search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1)

    return None

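# Usage sketch (illustrative, not part of the original file): the extractor
# returns the bare option letter, trying the final line first and then the
# whole response.
#   extract_option_labels('...reasoning...\nAnswer: B.')   -> 'B'
#   extract_option_labels('The correct answer is (C).')    -> 'C'
#   extract_option_labels(12345)                           -> 'error'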
def extract_option_content(text, options_content=None):
    if not isinstance(text, str) or not isinstance(options_content, list):
        return 'error'

    escaped_options_content = [re.escape(option_content) for option_content in options_content]
    escaped_options_content_str = '|'.join(escaped_options_content)

    text = text.rstrip()
    last_line = text.split('\n')[-1]

    patterns = [
        f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is:?\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',

        f'(?i:Answer)\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',

        f'^[^\w\r\n]*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
    ]

    for pattern in patterns:
        match = safe_regex_search(pattern, last_line)
        if match:
            if match.group(1) in escaped_options_content:
                return options_content[escaped_options_content.index(match.group(1))]
            else:
                return match.group(1)

    for pattern in patterns:
        match = safe_regex_search(pattern, text)
        if match:
            if match.group(1) in escaped_options_content:
                return options_content[escaped_options_content.index(match.group(1))]
            else:
                return match.group(1)

    return None

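# Fallback sketch (illustrative): when no letter is found, match the literal
# option text instead; the caller then maps it back to a letter.
#   extract_option_content('The answer is 240 W.', ['120 W', '240 W'])  -> '240 W'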
def calculate_accuracy(file_path, save_dir, mode):
    """Score one JSONL result file and return overall and per-difficulty
    accuracies plus per-discipline/field/subfield statistics."""
    data = []
    acc = 0
    count = 0
    err = 0
    miss = 0
    acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
    count_difficulty = {"hard": 0, "middle": 0, "easy": 0}

    stats = {
        'discipline': {},
        'field': {},
        'subfield': {}
    }

    with open(file_path, "r") as file:
        for line in tqdm(file, desc=f"Reading {os.path.basename(file_path)} data", leave=False):
            data.append(json.loads(line))

    if not data:
        print(f"Warning: No data found in {file_path}")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats

    for sample in tqdm(data, desc=f"Processing {os.path.basename(file_path)} samples", leave=False):
        # The raw model output; the field name 'prediction' is assumed here.
        prediction = sample.get('prediction', '')
        if mode == 'zero-shot':
            predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
            if predict is None:
                predict = extract_option_content(prediction, sample["options"])
                predict = chr(sample["options"].index(predict) + 65) if predict else None
            sample["extracted_answer"] = predict
        elif mode == 'five-shot':
            # Only look at the text before the model starts a new question.
            response = prediction.split('Question:')[0]
            predict = extract_option_labels(response, 'ABCDEFGHIJ')
            if predict is None:
                predict = extract_option_content(response, sample["options"])
                predict = chr(sample["options"].index(predict) + 65) if predict else None
            if predict is None:
                predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(prediction, sample["options"])
                    predict = chr(sample["options"].index(predict) + 65) if predict else None
            sample["extracted_answer"] = predict

        discipline = sample.get("discipline", "unknown")
        field = sample.get("field", "unknown")
        subfield = sample.get("subfield", "unknown")
        difficulty = sample.get("difficulty", "unknown")

        # Decide the per-sample status once, then mirror it into every
        # hierarchy level below.
        answer_letter = sample["answer_letter"]
        if predict and answer_letter == predict:
            acc += 1
            acc_difficulty[difficulty] += 1
            sample["status"] = "correct"
        elif predict is None or predict == "":
            miss += 1
            sample["status"] = "miss"
        elif predict == 'error':
            err += 1
            sample["status"] = "error"
        else:
            sample["status"] = "incorrect"

        for level, key in [
            ('discipline', discipline),
            ('field', f"{discipline}/{field}"),
            ('subfield', f"{discipline}/{field}/{subfield}"),
        ]:
            if key not in stats[level]:
                stats[level][key] = {
                    "correct": 0,
                    "total": 0,
                    "miss": 0,
                    "error": 0,
                    "discipline": discipline,
                    "field": field,
                    "subfield": subfield,
                    "difficulty": {
                        "easy": {"correct": 0, "total": 0},
                        "middle": {"correct": 0, "total": 0},
                        "hard": {"correct": 0, "total": 0}
                    }
                }

            stats[level][key]["total"] += 1
            stats[level][key]["difficulty"][difficulty]["total"] += 1

            if sample["status"] == "correct":
                stats[level][key]["correct"] += 1
                stats[level][key]["difficulty"][difficulty]["correct"] += 1
            elif sample["status"] == "miss":
                stats[level][key]["miss"] += 1
            elif sample["status"] == "error":
                stats[level][key]["error"] += 1

        count += 1
        count_difficulty[difficulty] += 1

    if count == 0:
        print(f"Warning: No valid samples found in {file_path}")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats

    accuracy = acc / count
    error_rate = err / count
    miss_rate = miss / count
    # Guard against empty difficulty buckets to avoid division by zero.
    hard_accuracy = acc_difficulty["hard"] / count_difficulty["hard"] if count_difficulty["hard"] else 0.0
    middle_accuracy = acc_difficulty["middle"] / count_difficulty["middle"] if count_difficulty["middle"] else 0.0
    easy_accuracy = acc_difficulty["easy"] / count_difficulty["easy"] if count_difficulty["easy"] else 0.0

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, os.path.basename(file_path))
    with open(save_path, "w") as file:
        for sample in data:
            json.dump(sample, file)
            file.write("\n")

    return accuracy, error_rate, miss_rate, hard_accuracy, middle_accuracy, easy_accuracy, stats

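# Typical call (illustrative paths): score one result file and write the
# annotated copy into save_dir.
#   acc, err, miss, hard, middle, easy, stats = calculate_accuracy(
#       'results/model_zero-shot.jsonl', 'results/annotated', 'zero-shot')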
def calculate_total_row(hierarchy_stats, model_results, metric_name):
    """Calculate overall summary rows: a sample-wise weighted average plus
    unweighted averages across subfields, fields, and disciplines."""
    total_rows = []

    # Count how many categories contribute in each dimension
    total_samples = 0
    if metric_name in ['Hard', 'Middle', 'Easy']:
        total_subfields = sum(subfield['difficulty'][metric_name.lower()]['total'] > 0 for subfield in hierarchy_stats['subfield'].values())
        total_fields = sum(field['difficulty'][metric_name.lower()]['total'] > 0 for field in hierarchy_stats['field'].values())
        total_disciplines = sum(discipline['difficulty'][metric_name.lower()]['total'] > 0 for discipline in hierarchy_stats['discipline'].values())
    else:
        total_subfields = len(hierarchy_stats['subfield'])
        total_fields = len(hierarchy_stats['field'])
        total_disciplines = len(hierarchy_stats['discipline'])

    # Calculate total sample count
    for discipline_stats in hierarchy_stats['discipline'].values():
        if metric_name in ['Hard', 'Middle', 'Easy']:
            total_samples += discipline_stats['difficulty'][metric_name.lower()]['total']
        else:
            total_samples += discipline_stats['total']

    if metric_name == 'Accuracy':
        row_types = [
            (f'Overall (sample-wise) (Total samples: {total_samples})', 'sample'),
            (f'Overall (subfield-wise) (Total subfields: {total_subfields})', 'subfield'),
            (f'Overall (field-wise) (Total fields: {total_fields})', 'field'),
            (f'Overall (discipline-wise) (Total disciplines: {total_disciplines})', 'discipline')
        ]
    elif metric_name in ['Hard', 'Middle', 'Easy']:
        row_types = [
            (f'Overall (sample-wise) (Total {metric_name.lower()} samples: {total_samples})', 'sample'),
            (f'Overall (subfield-wise) (Total {metric_name.lower()} subfields: {total_subfields})', 'subfield'),
            (f'Overall (field-wise) (Total {metric_name.lower()} fields: {total_fields})', 'field'),
            (f'Overall (discipline-wise) (Total {metric_name.lower()} disciplines: {total_disciplines})', 'discipline')
        ]
    else:  # Error Rate and Miss Rate
        row_types = [(f'Overall (Total samples: {total_samples})', 'sample')]

    for row_name, stat_type in row_types:
        total_row = {
            'Discipline': row_name,
            'Field': '',
            'Subfield': ''
        }

        for model_name in model_results.keys():
            for mode in model_results[model_name].keys():
                if stat_type == 'sample':
                    # Sample-wise statistics (weighted by sample count)
                    stats = {'total': 0, 'correct': 0, 'error': 0, 'miss': 0}

                    for discipline_stats in hierarchy_stats['discipline'].values():
                        if 'model_stats' in discipline_stats and model_name in discipline_stats['model_stats']:
                            curr_stats = discipline_stats['model_stats'][model_name].get(mode, {})

                            if metric_name in ['Hard', 'Middle', 'Easy']:
                                difficulty_stats = curr_stats.get('difficulty', {}).get(metric_name.lower(), {})
                                stats['total'] += difficulty_stats.get('total', 0)
                                stats['correct'] += difficulty_stats.get('correct', 0)
                            else:
                                for key in ['total', 'correct', 'error', 'miss']:
                                    stats[key] += curr_stats.get(key, 0)

                    if stats['total'] > 0:
                        if metric_name in ['Hard', 'Middle', 'Easy'] or metric_name == 'Accuracy':
                            value = stats['correct'] / stats['total']
                        elif metric_name == 'Error Rate':
                            value = stats['error'] / stats['total']
                        else:  # Miss Rate
                            value = stats['miss'] / stats['total']
                    else:
                        value = 0

                else:
                    # Category-wise statistics (unweighted average of
                    # per-category correct rates)
                    scores = []

                    if stat_type == 'discipline':
                        categories = hierarchy_stats['discipline']
                    elif stat_type == 'field':
                        categories = hierarchy_stats['field']
                    else:  # subfield
                        categories = hierarchy_stats['subfield']

                    for cat_stats in categories.values():
                        if 'model_stats' in cat_stats and model_name in cat_stats['model_stats']:
                            curr_stats = cat_stats['model_stats'][model_name].get(mode, {})

                            if metric_name in ['Hard', 'Middle', 'Easy']:
                                difficulty_stats = curr_stats.get('difficulty', {}).get(metric_name.lower(), {})
                                if difficulty_stats.get('total', 0) > 0:
                                    score = difficulty_stats['correct'] / difficulty_stats['total']
                                    scores.append(score)
                            else:
                                if curr_stats.get('total', 0) > 0:
                                    if metric_name == 'Accuracy':
                                        score = curr_stats['correct'] / curr_stats['total']
                                        scores.append(score)

                    value = sum(scores) / len(scores) if scores else 0

                total_row[f'{model_name}_{mode}'] = f"{value:.2%}"

        total_rows.append(total_row)

    return total_rows

def create_excel_report_from_stats(model_results, hierarchy_stats, save_path):
    print("Starting Excel report generation...")

    # Create six different DataFrames for the different metrics and difficulties
    metrics = {
        'Accuracy': {'rows': [], 'color': '000000'},  # black
        'Error Rate': {'rows': [], 'color': '000000'},  # black
        'Miss Rate': {'rows': [], 'color': '000000'},  # black
        'Hard': {'rows': [], 'color': '000000'},  # black
        'Middle': {'rows': [], 'color': '000000'},  # black
        'Easy': {'rows': [], 'color': '000000'}  # black
    }

    # Organize data by hierarchy
    for discipline in tqdm(sorted(hierarchy_stats['discipline'].keys()), desc="Processing discipline level"):
        discipline_stats = hierarchy_stats['discipline'][discipline]
        discipline_total = discipline_stats['total']

        # Get all fields under this discipline
        categories = [k for k in hierarchy_stats['field'].keys()
                      if k.startswith(f"{discipline}/")]

        for field_key in sorted(categories):
            field_stats = hierarchy_stats['field'][field_key]
            field = field_stats['field']
            field_total = field_stats['total']

            # Get all subfields under this field
            subcategories = [k for k in hierarchy_stats['subfield'].keys()
                             if k.startswith(f"{discipline}/{field}/")]

            # Add subfield rows
            for subfield_key in sorted(subcategories):
                subfield_stats = hierarchy_stats['subfield'][subfield_key]

                # Create base row data for each metric
                for metric_name in metrics:
                    if metric_name in ['Hard', 'Middle', 'Easy']:
                        base_row = {
                            'Discipline': discipline,
                            'Field': field,
                            'Subfield': f"{subfield_stats['subfield']} ({subfield_stats['difficulty'][metric_name.lower()]['total']})"
                        }
                    else:
                        base_row = {
                            'Discipline': discipline,
                            'Field': field,
                            'Subfield': f"{subfield_stats['subfield']} ({subfield_stats['total']})"
                        }

                    row_data = base_row.copy()

                    # Add a score column for each model/mode
                    for model_name in model_results.keys():
                        for mode in model_results[model_name].keys():
                            stats = subfield_stats['model_stats'].get(model_name, {}).get(mode, {})

                            if metric_name in ['Hard', 'Middle', 'Easy']:
                                difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
                                if difficulty_stats.get('total', 0) > 0:
                                    value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
                                else:
                                    value = '0.00%'
                            else:
                                if stats.get('total', 0) > 0:
                                    if metric_name == 'Accuracy':
                                        value = f"{stats['correct'] / stats['total']:.2%}"
                                    elif metric_name == 'Error Rate':
                                        value = f"{stats['error'] / stats['total']:.2%}"
                                    else:  # Miss Rate
                                        value = f"{stats['miss'] / stats['total']:.2%}"
                                else:
                                    value = '0.00%'

                            row_data[f'{model_name}_{mode}'] = value

                    metrics[metric_name]['rows'].append(row_data)

            # Add field summary row
            for metric_name in metrics:
                if metric_name in ['Hard', 'Middle', 'Easy']:
                    field_row = {
                        'Discipline': discipline,
                        'Field': f"{field} (Total: {field_stats['difficulty'][metric_name.lower()]['total']})",
                        'Subfield': ''
                    }
                else:
                    field_row = {
                        'Discipline': discipline,
                        'Field': f"{field} (Total: {field_total})",
                        'Subfield': ''
                    }

                for model_name in model_results.keys():
                    for mode in model_results[model_name].keys():
                        stats = field_stats['model_stats'].get(model_name, {}).get(mode, {})

                        if metric_name in ['Hard', 'Middle', 'Easy']:
                            difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
                            if difficulty_stats.get('total', 0) > 0:
                                value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
                            else:
                                value = '0.00%'
                        else:
                            if stats.get('total', 0) > 0:
                                if metric_name == 'Accuracy':
                                    value = f"{stats['correct'] / stats['total']:.2%}"
                                elif metric_name == 'Error Rate':
                                    value = f"{stats['error'] / stats['total']:.2%}"
                                else:  # Miss Rate
                                    value = f"{stats['miss'] / stats['total']:.2%}"
                            else:
                                value = '0.00%'

                        field_row[f'{model_name}_{mode}'] = value

                metrics[metric_name]['rows'].append(field_row)

        # Add discipline summary row
        for metric_name in metrics:
            if metric_name in ['Hard', 'Middle', 'Easy']:
                discipline_row = {
                    'Discipline': f"{discipline} (Total: {discipline_stats['difficulty'][metric_name.lower()]['total']})",
                    'Field': '',
                    'Subfield': ''
                }
            else:
                discipline_row = {
                    'Discipline': f"{discipline} (Total: {discipline_total})",
                    'Field': '',
                    'Subfield': ''
                }

            for model_name in model_results.keys():
                for mode in model_results[model_name].keys():
                    stats = discipline_stats['model_stats'].get(model_name, {}).get(mode, {})

                    if metric_name in ['Hard', 'Middle', 'Easy']:
                        difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
                        if difficulty_stats.get('total', 0) > 0:
                            value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
                        else:
                            value = '0.00%'
                    else:
                        if stats.get('total', 0) > 0:
                            if metric_name == 'Accuracy':
                                value = f"{stats['correct'] / stats['total']:.2%}"
                            elif metric_name == 'Error Rate':
                                value = f"{stats['error'] / stats['total']:.2%}"
                            else:  # Miss Rate
                                value = f"{stats['miss'] / stats['total']:.2%}"
                        else:
                            value = '0.00%'

                    discipline_row[f'{model_name}_{mode}'] = value

            metrics[metric_name]['rows'].append(discipline_row)

    # Create DataFrames
    dfs = {metric: pd.DataFrame(data['rows']) for metric, data in metrics.items()}

    # Append overall summary rows to each DataFrame
    for metric_name, df in dfs.items():
        total_rows = calculate_total_row(hierarchy_stats, model_results, metric_name)
        dfs[metric_name] = pd.concat([df, pd.DataFrame(total_rows)], ignore_index=True)

    # Save to Excel, one sheet per metric
    with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
        for metric_name, df in dfs.items():
            df.to_excel(writer, sheet_name=metric_name, index=False)
            format_worksheet(writer.sheets[metric_name], df, metrics[metric_name]['color'])

    print(f"Report generation completed, Excel file saved: {save_path}")

def format_worksheet(worksheet, df, color):
    """Format worksheet"""
    # Set default font
    for row in worksheet.rows:
        for cell in row:
            cell.font = Font(name='Arial', color='000000')  # Uniform black font

    # Background colors for summary rows
    discipline_fill = PatternFill(start_color='FFFF00', end_color='FFFF00', fill_type='solid')
    field_fill = PatternFill(start_color='FFFFD4', end_color='FFFFD4', fill_type='solid')

    # Overall row background colors
    sample_wise_fill = PatternFill(start_color='B8CCE4', end_color='B8CCE4', fill_type='solid')  # Medium blue
    subfield_wise_fill = PatternFill(start_color='DCE6F1', end_color='DCE6F1', fill_type='solid')  # Light blue
    field_wise_fill = PatternFill(start_color='E9EEF5', end_color='E9EEF5', fill_type='solid')  # Lighter blue
    discipline_wise_fill = PatternFill(start_color='F2F5F9', end_color='F2F5F9', fill_type='solid')  # Lightest blue
    error_rate_fill = PatternFill(start_color='FFB6C1', end_color='FFB6C1', fill_type='solid')  # Red
    miss_rate_fill = PatternFill(start_color='D3D3D3', end_color='D3D3D3', fill_type='solid')  # Gray

    # Set column widths to fit the longest value
    for column in worksheet.columns:
        max_length = 0
        column = list(column)
        for cell in column:
            try:
                if len(str(cell.value)) > max_length:
                    max_length = len(str(cell.value))
            except Exception:
                pass
        adjusted_width = (max_length + 2)
        worksheet.column_dimensions[column[0].column_letter].width = adjusted_width

    # Merge cells and apply background colors
    current_discipline = None
    discipline_start = None
    current_field = None
    field_start = None

    for row_idx, row in enumerate(worksheet.iter_rows(min_row=2), start=2):
        discipline = row[0].value
        field = row[1].value

        # Handle Discipline-column merging
        if discipline and "Total:" in str(discipline):
            # Merge any pending discipline rows
            if discipline_start and current_discipline:
                worksheet.merge_cells(f'A{discipline_start}:A{row_idx-1}')

            # Apply background color to the current total row
            for cell in row:
                cell.fill = discipline_fill

            # Reset tracking variables
            current_discipline = None
            discipline_start = None
        elif discipline and discipline != current_discipline:
            # Merge any pending discipline rows
            if discipline_start and current_discipline:
                worksheet.merge_cells(f'A{discipline_start}:A{row_idx-1}')

            current_discipline = discipline
            discipline_start = row_idx

        # Handle Field-column merging
        if field and "Total:" in str(field):
            # Merge any pending field rows
            if field_start and current_field:
                worksheet.merge_cells(f'B{field_start}:B{row_idx-1}')

            # Apply background color to the current total row
            for cell in row:
                cell.fill = field_fill

            # Reset tracking variables
            current_field = None
            field_start = None
        elif field and field != current_field:
            # Merge any pending field rows
            if field_start and current_field:
                worksheet.merge_cells(f'B{field_start}:B{row_idx-1}')

            current_field = field
            field_start = row_idx

    # Merge any cells left pending at the end
    last_row = worksheet.max_row
    if discipline_start and current_discipline:
        worksheet.merge_cells(f'A{discipline_start}:A{last_row}')
    if field_start and current_field:
        worksheet.merge_cells(f'B{field_start}:B{last_row}')

    # Apply special background colors to Overall rows
    for row_idx, row in enumerate(worksheet.iter_rows(min_row=2), start=2):
        cell_value = row[0].value
        if cell_value:
            if 'Overall (sample-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = sample_wise_fill
            elif 'Overall (subfield-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = subfield_wise_fill
            elif 'Overall (field-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = field_wise_fill
            elif 'Overall (discipline-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = discipline_wise_fill
            elif worksheet.title == 'Error Rate' and 'Overall' in str(cell_value):
                for cell in row:
                    cell.fill = error_rate_fill
            elif worksheet.title == 'Miss Rate' and 'Overall' in str(cell_value):
                for cell in row:
                    cell.fill = miss_rate_fill

    # Normalize percentage strings to two decimal places
    for row in worksheet.iter_rows(min_row=2):
        for cell in row[3:]:  # Start from the 4th column (skip Discipline, Field, Subfield)
            if isinstance(cell.value, str) and '%' in cell.value:
                try:
                    value = float(cell.value.strip('%')) / 100
                    cell.value = f"{value:.2%}"
                except ValueError:
                    pass

    # Center-align all cells
    for row in worksheet.rows:
        for cell in row:
            cell.alignment = Alignment(horizontal='center', vertical='center')

def format_cell_value(stats):
    """Format cell content as an 'acc/error/miss' string."""
    total = stats['total']
    if total == 0:
        return '0%/0%/0%'

    acc = stats['correct'] / total
    error = stats['error'] / total
    miss = stats['miss'] / total

    return f"{acc:.1%}/{error:.1%}/{miss:.1%}"
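# e.g. format_cell_value({'total': 10, 'correct': 7, 'error': 1, 'miss': 2})
# returns '70.0%/10.0%/20.0%' (illustrative input).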
691
opencompass/datasets/supergpqa/supergpqa_utils.py
Normal file
@ -0,0 +1,691 @@
import json
import os
import re

import sympy as sp
import yaml
from sympy.parsing.latex import parse_latex


def load_yaml(yaml_path):
    """Load a YAML file."""
    if not os.path.exists(yaml_path):
        raise FileNotFoundError(f'YAML file not found: {yaml_path}')
    with open(yaml_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def load_json_or_jsonl(file_path):
    """Load data from a JSON or JSONL file."""
    if not os.path.exists(file_path):
        return None
    with open(file_path, 'r', encoding='utf-8') as file:
        if file_path.endswith('.json'):
            return json.load(file)
        elif file_path.endswith('.jsonl'):
            return [json.loads(line) for line in file]
    return None


def find_file(base_path, sub_path, extensions=('json', 'jsonl')):
    """Find the first available file with the given extensions."""
    for ext in extensions:
        file_path = os.path.join(base_path, f'{sub_path}.{ext}')
        if os.path.exists(file_path):
            return file_path
    return None


def load_json_or_jsonl_with_idx(data_path, split='', idx=None):
    base_path = os.path.join(data_path, split)
    if os.path.exists(f'{base_path}.json'):
        file_path = f'{base_path}.json'
    elif os.path.exists(f'{base_path}.jsonl'):
        file_path = f'{base_path}.jsonl'
    elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
        file_path = base_path
    else:
        raise FileNotFoundError('No JSON or JSONL file found.')

    with open(file_path, 'r', encoding='utf-8') as file:
        if file_path.endswith('.json'):
            data = json.load(file)
        elif file_path.endswith('.jsonl'):
            data = [json.loads(line) for line in file]

    if idx is not None:
        try:
            return next(item for item in data if item.get('idx') == idx)
        except StopIteration:
            raise ValueError(f'No entry found for idx {idx}')
    else:
        return data


def load_split_data(base_path, split_name):
    """Load the rule and sample data for a specific split."""
    split_path = os.path.join(base_path, split_name)
    rule_path = find_file(split_path, 'rule')
    sample_path = find_file(split_path, 'sample')

    rules = load_json_or_jsonl(rule_path) if rule_path else []
    samples = load_json_or_jsonl(sample_path) if sample_path else []

    return {'rules': rules, 'samples': samples}


def process_mixed_data(base_path, mode):
    """Load and process data for the 'mixed' split and a specific mode."""
    mixed_path = os.path.join(base_path, 'mixed')
    file_path = find_file(mixed_path, mode)
    if not file_path:
        print(f'[WARNING] Missing file for mixed mode: {mode}')
        return []

    data = load_json_or_jsonl(file_path)
    template_path = os.path.join(base_path, 'config/prompt/mixed.yaml')
    template = load_yaml(template_path)

    processed = []
    for item in data:
        rules = '\n'.join(item.get('rule_list', []))
        questions = '\n'.join(item.get('question_list', []))
        item['prompt'] = template['prompt_format'][0].format(rules, questions)
        processed.append(item)

    return processed

class ConfigWrapper:

    def __init__(self, config_path):
        self._config = {}
        with open(config_path, 'r') as file:
            self._config = yaml.safe_load(file)
        for key, value in self._config.items():
            setattr(self, key, value)

    def __setattr__(self, key, value):
        if key.startswith('_'):
            super().__setattr__(key, value)
        else:
            self._config[key] = value
            super().__setattr__(key, value)

    def __getattr__(self, key):
        if key in self._config:
            return self._config[key]
        raise AttributeError(
            f"'ConfigWrapper' object has no attribute '{key}'")

    def get_id(self, data):
        if isinstance(self._config.get('id_key'), str):
            return data.get(self._config.get('id_key'), None)
        elif isinstance(self._config.get('id_key'), list):
            return '_'.join([
                str(data[key]) for key in self._config.get('id_key')
                if key in data
            ])

    def print_all_keys(self):
        print('config keys:')
        for key, value in self._config.items():
            print(f'  - {key}: {value}')


config_wrapper = None


def initialize_config(config_path):
    global config_wrapper
    config_wrapper = ConfigWrapper(config_path)


def get_config_wrapper():
    global config_wrapper
    if config_wrapper is None:
        raise RuntimeError(
            'ConfigWrapper not initialized. Call initialize_config first.')
    return config_wrapper

if __name__ == '__main__':
    config_path = 'config/config.yaml'
    initialize_config(config_path)
    data = {
        'idx': '50',
        'step': 21,
        'question': ('Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
                     'Please provide the decrypted answer, encapsulated in double '
                     'square brackets. '
                     'For example, the format should be: [[decrypted answer]].'),
        'answer': '[[P]]',
        'category': 'Decryption',
        'rule_id': '23',
        'input': 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
        'steps_num': 23,
        'description': ('For a number c=228 in the ciphertext:\n'
                        'Calculate z = c^e mod n. Here ^ means multiplication.\n'
                        'z is 80.\nBased on the decimal number represented by z, '
                        'use the ascii code to find the corresponding letter '
                        'as the plaintext letter p.\n'
                        'Please give the letter p in [[...]] format.\n'),
        'atom': 80
    }
    print(config_wrapper.get_id(data))

def read_yaml(config='default'):
    if os.path.exists(f'config/prompt/{config}.yaml'):
        yaml_file = f'config/prompt/{config}.yaml'
    else:
        yaml_file = config
    with open(yaml_file, 'r') as yaml_file:
        return yaml.safe_load(yaml_file)


def write_jsonl_lines(file, data):
    config_wrapper = get_config_wrapper()
    if config_wrapper.save_prompt:
        json.dump(data, file, ensure_ascii=False)
    else:
        data.pop(config_wrapper.prompt_key)
        json.dump(data, file, ensure_ascii=False)
    file.write('\n')
    file.flush()


def print_info(info):
    print('-' * 100)
    print('[INFO] model_name:', info['model_name'])
    print('[INFO] splits:', info['splits'])
    print('[INFO] modes:', info['modes'])
    print('[INFO] output_dir:', info['output_dir'])
    print('[INFO] Infer Limit:',
          'No limit' if info['infer_limit'] is None else info['infer_limit'])
    print('[INFO] Number of Workers:', info['num_workers'])
    print('[INFO] Batch Size:', info['batch_size'])
    print('[INFO] Use Accel:', info['use_accel'])
    print('-' * 100)


def read_json_or_jsonl(data_path, split='', mapping_key=None):
    base_path = os.path.join(data_path, split)
    if os.path.exists(f'{base_path}.json'):
        file_path = f'{base_path}.json'
    elif os.path.exists(f'{base_path}.jsonl'):
        file_path = f'{base_path}.jsonl'
    elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
        file_path = base_path
    else:
        raise FileNotFoundError('No JSON or JSONL file found.')

    with open(file_path, 'r') as file:
        if file_path.endswith('.json'):
            data = json.load(file)
        elif file_path.endswith('.jsonl'):
            data = [json.loads(line) for line in file]

    if mapping_key:
        return {
            item[mapping_key]: item
            for item in data if mapping_key in item
        }
    else:
        return data


def read_json_or_jsonl_with_idx(data_path, split='', idx=None):
    base_path = os.path.join(data_path, split)
    if os.path.exists(f'{base_path}.json'):
        file_path = f'{base_path}.json'
    elif os.path.exists(f'{base_path}.jsonl'):
        file_path = f'{base_path}.jsonl'
    elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
        file_path = base_path
    else:
        raise FileNotFoundError('No JSON or JSONL file found.')

    with open(file_path, 'r', encoding='utf-8') as file:
        if file_path.endswith('.json'):
            data = json.load(file)
        elif file_path.endswith('.jsonl'):
            data = [json.loads(line) for line in file]

    if idx is not None:
        try:
            return next(item for item in data if item.get('idx') == idx)
        except StopIteration:
            raise ValueError(f'No entry found for idx {idx}')
    else:
        return data

# Index ranges of 'operation' questions whose answers are free-form math
# expressions; these are compared symbolically via compare_math_expressions.
idx_ranges = [
    [18],
    [73, 74, 77],
    [94],
    [115, 116, 117],
    [121, 122, 123, 125],
    [131, 132, 134, 135, 136],
    [141, 143, 149],
    list(range(145, 148)),
    list(range(151, 157)),
    [160, 161, 162],
    [164, 165, 166],
    [170],
    [206, 209],
    list(range(211, 216)),
    [217, 218],
]

def clean_json_string(json_str):
    # Strip ASCII control characters that would break json.loads
    json_str = re.sub(r'[\x00-\x1F\x7F]', '', json_str)
    return json_str


def is_in_idx_ranges(idx, idx_ranges):
    for range_list in idx_ranges:
        if int(idx) in range_list:
            return True
    return False


def extract_json(text):
    matches = re.findall(r'{.*}', text, re.DOTALL)
    if matches:
        # Use the last {...} block in the response
        json_str = matches[-1]
        json_str = clean_json_string(json_str)
        try:
            data = json.loads(json_str)
            return data
        except json.JSONDecodeError as e:
            print(f'Error decoding JSON: {e}')
            return 'NULL'
    return 'NULL'


def extract_all_responses_from_json(response_json):
    results = []
    for key, value in response_json.items():
        results.append(str(value))
    return results

def clean_latex(latex_expr):
    if '=' in latex_expr:
        latex_expr = latex_expr.rsplit('=', 1)[1]
    latex_expr = re.sub(r'\\[()\[\]]', '', latex_expr)
    latex_expr = re.sub(r'\\text\{.*?\}', '', latex_expr)
    latex_expr = re.sub(r'\\(left|right|displaystyle)', '', latex_expr)
    latex_expr = latex_expr.replace('\\\\', '\\')
    return latex_expr


def extract_text_from_brackets(text, clean_level='basic'):
    matches = re.findall(r'\[\[\s*(.*?)\s*\]\]', text, re.DOTALL)
    if not matches:
        matches = re.findall(r'\$\\boxed\{(.*?)\}\$', text, re.DOTALL)
    if not matches:
        matches = re.findall(r'\[\s*(.*?)\s*\]', text, re.DOTALL)
    if matches:
        match_str = matches[0].strip()
        if clean_level == 'clean':
            match_str = match_str.replace('"', '').replace('\n', '').replace(
                ' ', '').replace('[', '').replace(']', '')
        elif clean_level == 'logic':
            match_str = match_str.replace('"', '').replace('\n', '').replace(
                ' ', '').replace('.', '')
        elif clean_level == 'math':
            match_str = match_str.replace('"', '').replace('\n', '').replace(
                '[', '').replace(']', '').replace('$', '')
            return f'{clean_latex(match_str)}'
        return f'[[{match_str}]]'
    return 'NULL'


def extract_inner_text_from_brackets(text):
    if not isinstance(text, str):
        print(f'text type: {type(text)}, text value: {text}')
        return 'NULL'
    match = re.search(r'\[\[(.*?)\]\]', text, re.DOTALL)
    return match.group(1) if match else 'NULL'

def extract_numbers(s):
    numbers = re.findall(r'\d+', s)
    numbers = list(map(int, numbers))
    return numbers


def extract_and_sort_inequalities(latex_expr):
    pattern = r'(≥|≤)\s*([-]?\d+\.?\d*)'
    matches = re.findall(pattern, latex_expr)
    extracted_inequalities = [''.join(match) for match in matches]
    sorted_inequalities = sorted(extracted_inequalities)
    return sorted_inequalities


def rule5_normalize_content(content):
    parts = content.split(';')
    sorted_parts = sorted(parts)
    return sorted_parts

def normalize_string(s):
    s = re.sub(r'[^0-9]', '', s)
    pairs = s.split(',')
    pairs.sort()
    return pairs


def remove_commas_and_spaces(s):
    return re.sub(r'[,\s\[\]]+', '', s)


def remove_non_alphanumeric(s):
    return re.sub(r'\W+', '', s)


def contains_or(answer):
    return 'or' in answer

def compare_multi_results(response, answer):
    try:
        response_text = extract_text_from_brackets(response, 'clean')
        response_text = re.sub(r'\\text\{or\}', 'or', response_text)
        if response_text == 'NULL':
            return False
        answer = extract_text_from_brackets(answer, 'clean')
        response_split = response_text.strip('[[]]').split('or')
        answer_split = answer.strip('[[]]').split('or')
        response_sorted = sorted([x.strip() for x in response_split])
        answer_sorted = sorted([x.strip() for x in answer_split])
        return response_sorted == answer_sorted
    except Exception as e:
        print(f'Error during comparison: {e}')
        return False


def split_or_expression(expression):
    return [part.strip() for part in expression.split('or')]

def compare_math_expressions(response, answer):
    response_text = extract_text_from_brackets(response, 'math')
    answer_text = extract_text_from_brackets(answer, 'math')
    if response_text == 'NULL':
        return False
    if contains_or(answer_text):
        response_parts = split_or_expression(response_text)
        answer_parts = split_or_expression(answer_text)
        try:
            response_exprs = {
                sp.simplify(parse_latex(part))
                for part in response_parts
            }
            answer_exprs = {
                sp.simplify(parse_latex(part))
                for part in answer_parts
            }
            return response_exprs == answer_exprs
        except Exception as e:
            print(f'Error during simplification or parsing: {e}')
            return response_text == answer_text
    else:
        try:
            response_expr = sp.simplify(parse_latex(response_text))
            answer_expr = sp.simplify(parse_latex(answer_text))
            return response_expr == answer_expr
        except Exception as e:
            print(f'Error during simplification or parsing: {e}')
            return response_text == answer_text

def method_equal(response_text, answer):
    return response_text == answer


def method_1(response_text, answer):
    # Case- and punctuation-insensitive string comparison
    cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
    cleaned_string = cleaned_string.lower()
    answer = re.sub(r'[^A-Za-z]', '', answer)
    answer = answer.lower()
    return cleaned_string == answer


def method_2(response_text, answer):
    # Response must match one of several comma-separated accepted answers
    cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
    cleaned_string = cleaned_string.lower()
    answer = answer.split(',')
    return cleaned_string in answer


def method_3(response_text, answer):
    # Bag-of-words comparison, ignoring order and punctuation
    response_text = response_text.lower()
    pairs1 = re.split(r'\W+', response_text)
    pairs2 = answer.split(' ')
    pairs1 = [word for word in pairs1 if word]
    pairs1.sort()
    pairs2.sort()
    return pairs1 == pairs2


def method_4(response_text, answer):
    cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
    cleaned_string = cleaned_string.lower()
    return cleaned_string in answer


def method_5(response_text, answer):
    # Order-insensitive comparison of comma-separated items
    response_text = re.sub(r'\s+', '', response_text)
    response_text = response_text.split(',')
    answer = answer.split(',')
    response_text.sort()
    answer.sort()
    return response_text == answer


def method_9(response_text, answer):
    # Check an arithmetic equation: same operators as the reference, and the
    # left-hand side must evaluate to the reference result.
    response_text = response_text.replace('×', '*').replace('−', '-')
    answer = answer.replace('×', '*').replace('−', '-')

    def extract_operators(s):
        return re.findall(r'[+\-*/]', s)

    response_ops = extract_operators(response_text.split('=')[0])
    answer_ops = extract_operators(answer.split('=')[0])
    if response_ops != answer_ops:
        return False
    match = re.search(r'=\s*(-?\d+)', answer)
    if not match:
        return False
    expected_result = int(match.group(1))
    try:
        left_side = response_text.split('=')[0]
        # Note: eval() runs arbitrary expressions; inputs here are
        # operator-filtered model outputs, as in the original script.
        result = eval(left_side)
    except Exception as e:
        print(f'Error during evaluation: {e}')
        return False
    return result == expected_result


def method_10(response_text, answer):
    # 24-game style check: same digits as the reference, and the expression
    # must evaluate to 24.
    response_text = response_text.replace('×', '*').replace('−', '-')
    response_text = response_text.split('=')[0]
    answer = answer.split('\n')[0].split('=')[0]
    response_ops = sorted(remove_non_alphanumeric(response_text))
    answer_ops = sorted(remove_non_alphanumeric(answer))
    if response_ops != answer_ops:
        return False
    try:
        result = eval(response_text)
    except Exception as e:
        print(f'Error during evaluation: {e}')
        return False
    return result == 24


def method_18(response_text, answer):
    cleaned_s1 = remove_commas_and_spaces(response_text)
    cleaned_s2 = remove_commas_and_spaces(answer)
    return cleaned_s1 == cleaned_s2


def method_general(response_text, answer):
    cleaned_s1 = remove_non_alphanumeric(response_text)
    cleaned_s2 = remove_non_alphanumeric(answer)
    return cleaned_s1 == cleaned_s2


# rule_id -> comparison strategy for 'puzzle' questions; anything unmapped
# falls back to method_general.
question_methods = {
    '1': method_1,
    '2': method_2,
    '3': method_3,
    '4': method_4,
    '5': method_5,
    '9': method_9,
    '10': method_10,
    '18': method_18,
}

def evaluate_response_vs_answer(response, answer, question_type, rule_id, idx):
    if question_type == 'logic' and rule_id == '5':
        response_text = extract_text_from_brackets(response, 'logic')
        answer_text = extract_text_from_brackets(answer, 'logic')
        if response_text is None:
            return False
        normalized_response = rule5_normalize_content(response_text)
        normalized_answer = rule5_normalize_content(answer)
        return normalized_response == normalized_answer
    elif question_type == 'logic':
        response_text = extract_text_from_brackets(response, 'logic')
        answer_text = extract_text_from_brackets(answer, 'logic')
        return response_text == answer_text
    elif question_type == 'operation' and (idx == '178' or idx == '179'):
        response_text = extract_text_from_brackets(response, 'clean')
        response_text = extract_and_sort_inequalities(response_text)
        answer_text = extract_and_sort_inequalities(answer)
        # print(response_text, answer_text)
        return response_text == answer_text
    elif question_type == 'operation' and rule_id == '18':
        response_text = extract_text_from_brackets(response, 'clean')
        answer = extract_inner_text_from_brackets(answer)
        response_text = ''.join(sorted(re.sub(r'\W+', '', response_text)))
        answer = ''.join(sorted(re.sub(r'\W+', '', answer)))
        return response_text == answer
    elif question_type == 'operation' and rule_id in {'23', '24', '25'}:
        response_text = extract_text_from_brackets(response, 'clean')
        if response_text is None:
            return False
        response_text = extract_numbers(response_text)
        answer_text = extract_numbers(answer)
        return response_text == answer_text
    elif question_type == 'operation' and is_in_idx_ranges(idx, idx_ranges):
        return compare_math_expressions(response, answer)
    elif question_type == 'operation' and contains_or(answer):
        return compare_multi_results(response, answer)
    elif question_type == 'puzzle':
        response_text = extract_inner_text_from_brackets(response)
        answer = extract_inner_text_from_brackets(answer)
        method = question_methods.get(rule_id)
        if method:
            return method(response_text, answer)
        return method_general(response_text, answer)
    else:
        response_text = extract_text_from_brackets(response, 'clean')
        return response_text == answer

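# Dispatch sketch (illustrative): a 'logic' question whose gold answer is
# wrapped in double square brackets.
#   evaluate_response_vs_answer('Reasoning... [[A]]', '[[A]]', 'logic', '1', '0')
#   -> True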
def compute_one_mixed_question_pass_rate(idx,
                                         question_list,
                                         response_json,
                                         base_path=None):
    if response_json == 'NULL':
        result_dict = {
            'idx': idx,
            'response': response_json,
            'details': None,
            'pass_rate': 0,
            'is_correct': False
        }
        return result_dict
    response_list = extract_all_responses_from_json(response_json)
    correct_num = 0
    results = []
    for q_idx, question in enumerate(question_list):
        category, question_idx = question.rsplit('_', 1)
        question_content = load_json_or_jsonl_with_idx(base_path,
                                                       os.path.join(
                                                           category, 'sample'),
                                                       idx=question_idx)
        answer = question_content['answer']
        if q_idx >= len(response_list):
            break
        response = response_list[q_idx]
        response_text = extract_text_from_brackets(response)
        rule_id = question_content['rule_id']
        is_correct = evaluate_response_vs_answer(response, answer, category,
                                                 rule_id, q_idx)
        if is_correct:
            correct_num += 1
        results.append({
            'question': question,
            'response_text': response_text,
            'answer': answer,
            'is_correct': is_correct
        })

    pass_rate = correct_num / len(question_list)
    question_correct = pass_rate == 1.0
    result_dict = {
        'idx': idx,
        'response': response_json,
        'details': results,
        'pass_rate': pass_rate,
        'is_correct': question_correct
    }
    return result_dict

def evaluate_responses(data, mode, base_path=None):
    results = []

    # Iterate over the records of the dictionary (keyed by index)
    for key, record in data.items():
        idx = key  # Use the dictionary key as the "idx"
        response = record.get('prediction', '')
        question_type = record.get('category', '')
        response_text = extract_text_from_brackets(response)
        answer = record.get('gold', '')
        rule_id = record.get('rule_id', '')
        is_correct = evaluate_response_vs_answer(response, answer,
                                                 question_type, rule_id,
                                                 idx)
        result_dict = {
            'idx': idx,
            'response': response,
            'response_text': response_text,
            'answer': answer,
            'is_correct': is_correct
        }
        if question_type == 'counterfactual':
            real_life_answer = record.get('real_life_answer', '')
            is_real_life = evaluate_response_vs_answer(
                response, real_life_answer, question_type, rule_id, idx)
            result_dict['real_life_answer'] = real_life_answer
            result_dict['is_real_life'] = is_real_life
        if question_type == 'cipher' and mode == 'subquestions':
            result_dict['type'] = record.get('type', '')
        results.append(result_dict)
    return results
@ -403,6 +403,11 @@ DATASETS_MAPPING = {
        "hf_id": "",
        "local": "./data/OlympiadBench",
    },
    "opencompass/supergpqa": {
        "ms_id": "",
        "hf_id": "m-a-p/SuperGPQA",
        "local": "./data/supergpqa",
    },
}

DATASETS_URL = {
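# Note (assumption about the surrounding OpenCompass machinery, which is not
# shown in this hunk): dataset configs can now pass path='opencompass/supergpqa',
# which this mapping resolves to the local directory ./data/supergpqa, with
# hf_id='m-a-p/SuperGPQA' identifying the Hugging Face source.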