mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
remove unnecessary code
This commit is contained in:
parent
89bbf13f5a
commit
14dcdaa0de
@ -1,14 +1,10 @@
|
||||
from mmengine import read_base
|
||||
|
||||
with read_base():
|
||||
# from opencompass.configs.datasets.supergpqa.supergpqa_mixed_gen_d00bdd import \
|
||||
# supergpqa_mixed_datasets as mixed_datasets
|
||||
from opencompass.configs.datasets.supergpqa.supergpqa_single_0_shot_gen import \
|
||||
supergpqa_0shot_single_datasets as zero_shot_datasets
|
||||
# from opencompass.configs.datasets.supergpqa.supergpqa_single_3_shot_gen import \
|
||||
# supergpqa_3shot_single_datasets as three_shot_datasets
|
||||
from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b import \
|
||||
models as hf_internlm2_5_7b
|
||||
from opencompass.configs.datasets.supergpqa.supergpqa_gen import \
|
||||
supergpqa_datasets
|
||||
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import \
|
||||
models
|
||||
|
||||
datasets = zero_shot_datasets
|
||||
models = hf_internlm2_5_7b
|
||||
datasets = supergpqa_datasets
|
||||
models = models
|
||||
|
57
opencompass/configs/datasets/supergpqa/supergpqa_gen.py
Normal file
57
opencompass/configs/datasets/supergpqa/supergpqa_gen.py
Normal file
@ -0,0 +1,57 @@
|
||||
from opencompass.datasets.supergpqa.supergpqa import (
|
||||
SuperGPQADataset,
|
||||
SuperGPQAEvaluator,
|
||||
)
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
|
||||
|
||||
# Reader configuration
|
||||
reader_cfg = dict(
|
||||
input_columns=[
|
||||
'question',
|
||||
"options",
|
||||
'discipline',
|
||||
'field',
|
||||
'subfield',
|
||||
'difficulty',
|
||||
"infer_prompt",
|
||||
"prompt_mode",
|
||||
],
|
||||
output_column='answer_letter',
|
||||
)
|
||||
|
||||
# Inference configuration
|
||||
infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt='{infer_prompt}',
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
# Evaluation configuration
|
||||
eval_cfg = dict(
|
||||
evaluator=dict(type=SuperGPQAEvaluator),
|
||||
pred_role='BOT',
|
||||
)
|
||||
supergpqa_dataset = dict(
|
||||
type=SuperGPQADataset,
|
||||
abbr='supergpqa',
|
||||
path="m-a-p/SuperGPQA",
|
||||
prompt_mode='zero-shot',
|
||||
reader_cfg=reader_cfg,
|
||||
infer_cfg=infer_cfg,
|
||||
eval_cfg=eval_cfg,
|
||||
)
|
||||
|
||||
supergpqa_datasets = [supergpqa_dataset]
|
@ -1,291 +0,0 @@
|
||||
categories =[
|
||||
"Power Systems and Automation",
|
||||
"Anesthesiology",
|
||||
"Oncology",
|
||||
"Group Theory",
|
||||
"Thermal Energy Engineering",
|
||||
"Emergency Medicine",
|
||||
"Systems Science",
|
||||
"Geometry and Topology",
|
||||
"Advanced Algebra",
|
||||
"Electrical Theory and New Technologies",
|
||||
"Engineering Thermophysics",
|
||||
"Operating Systems",
|
||||
"Guidance, Navigation and Control",
|
||||
"Harmony",
|
||||
"Marine Biology",
|
||||
"Pediatrics",
|
||||
"Road and Railway Engineering",
|
||||
"Information Management and Communication",
|
||||
"Water conservancy and Hydropower Engineering",
|
||||
"Veterinary Medicine",
|
||||
"Astronomical Observation and Technology",
|
||||
"Special Number Theory",
|
||||
"Philology and Bibliography",
|
||||
"Textile Materials Science",
|
||||
"Legal Theory and Legal History",
|
||||
"Education Economics, Management and Social Security",
|
||||
"Traditional Chinese Health Preservation",
|
||||
"Epidemiology and Health Statistics",
|
||||
"Pitch and Scales",
|
||||
"Economic History",
|
||||
"Marine Engineering",
|
||||
"Labor Economics",
|
||||
"Materials Processing Engineering",
|
||||
"Demography and Anthropology",
|
||||
"Preschool Education",
|
||||
"Music History, Education, and Technology",
|
||||
"Instrumentation and Performance",
|
||||
"Cryptography",
|
||||
"Mineralogy, Petrology, and Economic Geology",
|
||||
"Microbiology and Biochemical Pharmacy",
|
||||
"Poromechanics and Reservoir Physics",
|
||||
"Imaging and Nuclear Medicine",
|
||||
"Solid State Physics",
|
||||
"Microelectronics and Solid-State Electronics",
|
||||
"Zoology",
|
||||
"Food Biochemistry",
|
||||
"Traditional Chinese Pharmacy",
|
||||
"Neurology",
|
||||
"Hydrogeology",
|
||||
"Criminal Law",
|
||||
"Radiation Medicine",
|
||||
"Relativity",
|
||||
"Analytical Chemistry",
|
||||
"Signal and Information Processing",
|
||||
"Military Command and Information Systems",
|
||||
"Literary Theory",
|
||||
"Textile Chemistry and Dyeing Engineering",
|
||||
"Urban Infrastructure Engineering",
|
||||
"Stellar and Interstellar Evolution",
|
||||
"Geological Resources and Geological Engineering",
|
||||
"Pattern Recognition",
|
||||
"Engineering Fluid Mechanics",
|
||||
"Communication and Information Systems",
|
||||
"Architectural History",
|
||||
"Stochastic Processes",
|
||||
"Microbiology",
|
||||
"French Language and Literature",
|
||||
"Principles of Computer Organization",
|
||||
"Architectural Design and Theory",
|
||||
"Animal Rearing and Breeding",
|
||||
"Physical Oceanography",
|
||||
"Acoustics",
|
||||
"Organic Chemistry",
|
||||
"Refrigeration and Cryogenic Engineering",
|
||||
"Public Finance",
|
||||
"Dermatology and Venereology",
|
||||
"Religious Studies",
|
||||
"Discrete Mathematics",
|
||||
"Forest Cultivation and Genetic Breeding",
|
||||
"Vehicle Operation Engineering",
|
||||
"Physical Chemistry",
|
||||
"Nutrition and Food Hygiene",
|
||||
"Ship Mechanics and Design Principles",
|
||||
"Power Electronics and Electrical Drives",
|
||||
"Finance",
|
||||
"Pharmacology",
|
||||
"Environmental Engineering",
|
||||
"Ecology",
|
||||
"Aeronautical and Astronautical Science and Technology",
|
||||
"Agricultural Mechanization Engineering",
|
||||
"Computer Architecture",
|
||||
"Political Economy",
|
||||
"Principles of Seismic Exploration",
|
||||
"Elements of Chemical Reaction Engineering",
|
||||
"Digital Surveying and Remote Sensing Applications",
|
||||
"History and Theory of Journalism and Media Management",
|
||||
"Instrument Science and Technology",
|
||||
"Structural Engineering",
|
||||
"Computer Networks",
|
||||
"Power Machinery and Engineering",
|
||||
"Constitutional and Administrative Law",
|
||||
"Law and Social Governance",
|
||||
"Psychology",
|
||||
"Urban Planning and Design",
|
||||
"Thermodynamics and Statistical Physics",
|
||||
"Chemical Transport Engineering",
|
||||
"Environmental and Resource Protection",
|
||||
"Fluid Machinery and Engineering",
|
||||
"Cartography and Geographic Information Engineering",
|
||||
"Computational Mathematics",
|
||||
"Pathogen Biology",
|
||||
"Human Geography",
|
||||
"Theoretical Optics",
|
||||
"Solid Mechanics",
|
||||
"Electrochemistry",
|
||||
"Aquaculture",
|
||||
"Logic",
|
||||
"Mechatronic Engineering",
|
||||
"Modern and Contemporary Chinese Literature",
|
||||
"Operations Research and Cybernetics",
|
||||
"Circuits and Systems",
|
||||
"Internal Combustion Engineering",
|
||||
"Atomic and Molecular Physics",
|
||||
"Marine Chemistry",
|
||||
"Electromagnetic Field and Microwave Technology",
|
||||
"Rigid Body Mechanics",
|
||||
"Physiology",
|
||||
"Military Chemistry and Pyrotechnics",
|
||||
"Fundamentals of Dynamics and Control",
|
||||
"Control Theory and Control Engineering",
|
||||
"Historical Geography",
|
||||
"Physical Geography",
|
||||
"National and Defense Economics",
|
||||
"Polymer Physics",
|
||||
"Landscape Plants and Ornamental Horticulture",
|
||||
"Solar System Science",
|
||||
"Library and Archival Science",
|
||||
"Internal Medicine",
|
||||
"Physical Chemistry of Metallurgical Process",
|
||||
"Antenna and Radio Communication",
|
||||
"Genetics",
|
||||
"Graph Theory",
|
||||
"Principles of Metallurgy",
|
||||
"Bridge and Tunnel Engineering",
|
||||
"Combinatorial Mathematics",
|
||||
"Otorhinolaryngology",
|
||||
"Political Science",
|
||||
"Medicinal Chemistry",
|
||||
"Health Toxicology and Environmental Health",
|
||||
"Archaeology and Museology",
|
||||
"Geotechnical Engineering",
|
||||
"Land Resource Management and Administrative Management",
|
||||
"Thermodynamics",
|
||||
"Atmospheric Physics and Atmospheric Environment",
|
||||
"Broadcasting and Television Art",
|
||||
"Numerical Analysis",
|
||||
"Statistical Mechanics",
|
||||
"Mineral Processing Engineering",
|
||||
"Mathematical Analysis",
|
||||
"Philosophy of Science and Technology",
|
||||
"Western Economics",
|
||||
"Data Structures",
|
||||
"Fine Arts",
|
||||
"Economic Statistics",
|
||||
"Environmental Science",
|
||||
"Military Thought and History",
|
||||
"Drama and Opera Studies",
|
||||
"Film Studies",
|
||||
"High Voltage and Insulation Technology",
|
||||
"Military Law",
|
||||
"Wood Science and Technology",
|
||||
"Obstetrics and Gynecology",
|
||||
"Hydraulics and Hydrology",
|
||||
"Cell Biology",
|
||||
"Biochemistry and Molecular Biology",
|
||||
"Fluid Flow and Heat Transfer in Chemical Engineering",
|
||||
"Formal Languages",
|
||||
"Optoelectronic Technology",
|
||||
"Crop Science",
|
||||
"Fundamental Mathematics",
|
||||
"Immunology",
|
||||
"Surgery",
|
||||
"Ophthalmology",
|
||||
"Social Medicine and Health Management",
|
||||
"Industrial Economics",
|
||||
"Traffic Information Engineering and Control",
|
||||
"Traditional Chinese Medicine Theory",
|
||||
"Polymer Chemistry and Physics",
|
||||
"Maternal, Child and Adolescent Health",
|
||||
"Radiation Protection and Nuclear Technology Applications",
|
||||
"Food Processing and Storage Engineering",
|
||||
"Fluid Physics",
|
||||
"Materials Physics and Chemistry",
|
||||
"Pharmaceutical Analysis",
|
||||
"Semiconductor Physics",
|
||||
"Optical Fiber Communication",
|
||||
"Ethics",
|
||||
"Psychiatry and Mental Health",
|
||||
"Management Science and Engineering",
|
||||
"Number Theory",
|
||||
"Contract Law",
|
||||
"Inorganic Chemistry",
|
||||
"Design Arts",
|
||||
"Human Anatomy and Histology-Embryology",
|
||||
"Iron and Steel Metallurgy",
|
||||
"Dance Studies",
|
||||
"Structural Geology",
|
||||
"Special Education",
|
||||
"Musical Forms and Analysis",
|
||||
"Philosophical Aesthetics",
|
||||
"Astrophysics",
|
||||
"Manufacturing Automation",
|
||||
"Quantum Mechanics",
|
||||
"Probability and Statistics",
|
||||
"Military Logistics and Equipment",
|
||||
"Heat Transfer",
|
||||
"Classical Chinese Literature",
|
||||
"Information Management Science",
|
||||
"Cosmology",
|
||||
"Educational Technology and Principles",
|
||||
"Ordinary Differential Equations",
|
||||
"Underwater Acoustics",
|
||||
"Business and Accounting Management",
|
||||
"Dynamic Meteorology",
|
||||
"Military Management",
|
||||
"Journalism and News Practice",
|
||||
"Animal Nutrition and Feed Science",
|
||||
"Applied Optics",
|
||||
"Theoretical Fluid Mechanics",
|
||||
"Communication Principles",
|
||||
"Physical Education and Training",
|
||||
"Geodesy and Surveying Engineering",
|
||||
"Meteorology",
|
||||
"Sports Science and Medicine",
|
||||
"Solid Earth Geophysics",
|
||||
"Particle and Nuclear Physics",
|
||||
"International Law",
|
||||
"Oil and Gas Field Development and Storage & Transportation Engineering",
|
||||
"Basic Stomatology",
|
||||
"Agricultural Environment and Soil-Water Engineering",
|
||||
"Geochemistry",
|
||||
"Procedural Law",
|
||||
"Botany",
|
||||
"Fuzzy Mathematics",
|
||||
"Paleontology and Stratigraphy",
|
||||
"Sports Humanities and Sociology",
|
||||
"Civil and Commercial Law",
|
||||
"Electrodynamics",
|
||||
"Mining and Safety Engineering",
|
||||
"Mass Transport and Separation Process in Chemical Engineering",
|
||||
"Advanced Programming Languages",
|
||||
"Laser Technology",
|
||||
"Weapon Systems Science and Engineering",
|
||||
"Quantitative Economics",
|
||||
"Theoretical Mechanics",
|
||||
"Nursing and Rehabilitation Medicine",
|
||||
"Databases",
|
||||
"Pharmaceutics",
|
||||
"Space physics",
|
||||
"Functions of Real Variables",
|
||||
"Non-ferrous Metallurgy",
|
||||
"Theory of Curriculum and Instruction",
|
||||
"Clinical Laboratory Diagnostics",
|
||||
"Clinical Stomatology",
|
||||
"Literary History",
|
||||
"Tourism Management and Technological Economics Management",
|
||||
"Communication and Broadcasting",
|
||||
"Pathology and Pathophysiology",
|
||||
"Functions of Complex Variables",
|
||||
"World History",
|
||||
"Forest Engineering",
|
||||
"Forensic Medicine",
|
||||
"Linguistics and Applied Linguistics",
|
||||
"Social and Folklore Studies",
|
||||
"Computer Software and Theory",
|
||||
"Subatomic and Atomic Physics",
|
||||
"Biophysics",
|
||||
"Radiochemistry",
|
||||
"Russian Language and Literature",
|
||||
"International Trade",
|
||||
"Geriatric Medicine",
|
||||
"Composition",
|
||||
"Transportation Planning and Management",
|
||||
"Polynomials and Series Expansions",
|
||||
"Nuclear Energy and Reactor Technology"
|
||||
]
|
||||
|
||||
supergpqa_summary_groups = [
|
||||
{'name': 'supergpqa', 'subsets': ['supergpqa_' + c.replace(' ', '_') for c in categories]},
|
||||
]
|
@ -1,296 +0,0 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .groups.supergpqa import supergpqa_summary_groups
|
||||
|
||||
summarizer = dict(
|
||||
dataset_abbrs=[
|
||||
"supergpqa",
|
||||
"supergpqa_Procedural_Law",
|
||||
"supergpqa_Microbiology",
|
||||
"supergpqa_World_History",
|
||||
"supergpqa_Civil_and_Commercial_Law",
|
||||
"supergpqa_Relativity",
|
||||
"supergpqa_Discrete_Mathematics",
|
||||
"supergpqa_Laser_Technology",
|
||||
"supergpqa_Power_Machinery_and_Engineering",
|
||||
"supergpqa_Geotechnical_Engineering",
|
||||
"supergpqa_Mineralogy,_Petrology,_and_Economic_Geology",
|
||||
"supergpqa_Fluid_Flow_and_Heat_Transfer_in_Chemical_Engineering",
|
||||
"supergpqa_Composition",
|
||||
"supergpqa_Biophysics",
|
||||
"supergpqa_Geriatric_Medicine",
|
||||
"supergpqa_Cell_Biology",
|
||||
"supergpqa_Underwater_Acoustics",
|
||||
"supergpqa_Political_Science",
|
||||
"supergpqa_Atomic_and_Molecular_Physics",
|
||||
"supergpqa_Industrial_Economics",
|
||||
"supergpqa_Marine_Chemistry",
|
||||
"supergpqa_Ophthalmology",
|
||||
"supergpqa_Geochemistry",
|
||||
"supergpqa_Anesthesiology",
|
||||
"supergpqa_Sports_Science_and_Medicine",
|
||||
"supergpqa_Forest_Cultivation_and_Genetic_Breeding",
|
||||
"supergpqa_Philology_and_Bibliography",
|
||||
"supergpqa_Cryptography",
|
||||
"supergpqa_Road_and_Railway_Engineering",
|
||||
"supergpqa_Literary_History",
|
||||
"supergpqa_Mining_and_Safety_Engineering",
|
||||
"supergpqa_Group_Theory",
|
||||
"supergpqa_Crop_Science",
|
||||
"supergpqa_Food_Biochemistry",
|
||||
"supergpqa_Textile_Materials_Science",
|
||||
"supergpqa_Fundamental_Mathematics",
|
||||
"supergpqa_Microelectronics_and_Solid-State_Electronics",
|
||||
"supergpqa_International_Law",
|
||||
"supergpqa_Agricultural_Environment_and_Soil-Water_Engineering",
|
||||
"supergpqa_Environmental_Science",
|
||||
"supergpqa_Urban_Infrastructure_Engineering",
|
||||
"supergpqa_Solid_State_Physics",
|
||||
"supergpqa_Mechatronic_Engineering",
|
||||
"supergpqa_Economic_History",
|
||||
"supergpqa_Power_Electronics_and_Electrical_Drives",
|
||||
"supergpqa_History_and_Theory_of_Journalism_and_Media_Management",
|
||||
"supergpqa_Neurology",
|
||||
"supergpqa_Computer_Networks",
|
||||
"supergpqa_Animal_Nutrition_and_Feed_Science",
|
||||
"supergpqa_Marine_Engineering",
|
||||
"supergpqa_Materials_Physics_and_Chemistry",
|
||||
"supergpqa_Business_and_Accounting_Management",
|
||||
"supergpqa_Basic_Stomatology",
|
||||
"supergpqa_Space_physics",
|
||||
"supergpqa_Transportation_Planning_and_Management",
|
||||
"supergpqa_Information_Management_and_Communication",
|
||||
"supergpqa_Quantitative_Economics",
|
||||
"supergpqa_Elements_of_Chemical_Reaction_Engineering",
|
||||
"supergpqa_Library_and_Archival_Science",
|
||||
"supergpqa_Electrodynamics",
|
||||
"supergpqa_Fluid_Machinery_and_Engineering",
|
||||
"supergpqa_Dynamic_Meteorology",
|
||||
"supergpqa_Functions_of_Real_Variables",
|
||||
"supergpqa_Pharmacology",
|
||||
"supergpqa_Communication_Principles",
|
||||
"supergpqa_Communication_and_Broadcasting",
|
||||
"supergpqa_Musical_Forms_and_Analysis",
|
||||
"supergpqa_Cartography_and_Geographic_Information_Engineering",
|
||||
"supergpqa_Maternal,_Child_and_Adolescent_Health",
|
||||
"supergpqa_Clinical_Stomatology",
|
||||
"supergpqa_Data_Structures",
|
||||
"supergpqa_Optoelectronic_Technology",
|
||||
"supergpqa_Physiology",
|
||||
"supergpqa_Thermodynamics_and_Statistical_Physics",
|
||||
"supergpqa_Pediatrics",
|
||||
"supergpqa_Geodesy_and_Surveying_Engineering",
|
||||
"supergpqa_Theoretical_Mechanics",
|
||||
"supergpqa_Hydraulics_and_Hydrology",
|
||||
"supergpqa_International_Trade",
|
||||
"supergpqa_Military_Chemistry_and_Pyrotechnics",
|
||||
"supergpqa_Finance",
|
||||
"supergpqa_Psychiatry_and_Mental_Health",
|
||||
"supergpqa_Fundamentals_of_Dynamics_and_Control",
|
||||
"supergpqa_Sports_Humanities_and_Sociology",
|
||||
"supergpqa_Harmony",
|
||||
"supergpqa_Control_Theory_and_Control_Engineering",
|
||||
"supergpqa_Surgery",
|
||||
"supergpqa_Analytical_Chemistry",
|
||||
"supergpqa_Political_Economy",
|
||||
"supergpqa_Theory_of_Curriculum_and_Instruction",
|
||||
"supergpqa_High_Voltage_and_Insulation_Technology",
|
||||
"supergpqa_Numerical_Analysis",
|
||||
"supergpqa_Physical_Geography",
|
||||
"supergpqa_Physical_Education_and_Training",
|
||||
"supergpqa_Applied_Optics",
|
||||
"supergpqa_Mathematical_Analysis",
|
||||
"supergpqa_Advanced_Programming_Languages",
|
||||
"supergpqa_Western_Economics",
|
||||
"supergpqa_Organic_Chemistry",
|
||||
"supergpqa_French_Language_and_Literature",
|
||||
"supergpqa_Urban_Planning_and_Design",
|
||||
"supergpqa_Polynomials_and_Series_Expansions",
|
||||
"supergpqa_Functions_of_Complex_Variables",
|
||||
"supergpqa_Advanced_Algebra",
|
||||
"supergpqa_Operating_Systems",
|
||||
"supergpqa_Internal_Combustion_Engineering",
|
||||
"supergpqa_Food_Processing_and_Storage_Engineering",
|
||||
"supergpqa_Educational_Technology_and_Principles",
|
||||
"supergpqa_Acoustics",
|
||||
"supergpqa_Quantum_Mechanics",
|
||||
"supergpqa_Iron_and_Steel_Metallurgy",
|
||||
"supergpqa_Land_Resource_Management_and_Administrative_Management",
|
||||
"supergpqa_Fuzzy_Mathematics",
|
||||
"supergpqa_Special_Education",
|
||||
"supergpqa_Solid_Mechanics",
|
||||
"supergpqa_Zoology",
|
||||
"supergpqa_Demography_and_Anthropology",
|
||||
"supergpqa_Tourism_Management_and_Technological_Economics_Management",
|
||||
"supergpqa_Theoretical_Optics",
|
||||
"supergpqa_Genetics",
|
||||
"supergpqa_Constitutional_and_Administrative_Law",
|
||||
"supergpqa_Structural_Engineering",
|
||||
"supergpqa_Principles_of_Metallurgy",
|
||||
"supergpqa_Medicinal_Chemistry",
|
||||
"supergpqa_Electromagnetic_Field_and_Microwave_Technology",
|
||||
"supergpqa_Clinical_Laboratory_Diagnostics",
|
||||
"supergpqa_Theoretical_Fluid_Mechanics",
|
||||
"supergpqa_Pitch_and_Scales",
|
||||
"supergpqa_Stochastic_Processes",
|
||||
"supergpqa_Ethics",
|
||||
"supergpqa_Circuits_and_Systems",
|
||||
"supergpqa_Engineering_Thermophysics",
|
||||
"supergpqa_Landscape_Plants_and_Ornamental_Horticulture",
|
||||
"supergpqa_Polymer_Physics",
|
||||
"supergpqa_Wood_Science_and_Technology",
|
||||
"supergpqa_Biochemistry_and_Molecular_Biology",
|
||||
"supergpqa_Preschool_Education",
|
||||
"supergpqa_Psychology",
|
||||
"supergpqa_Traditional_Chinese_Health_Preservation",
|
||||
"supergpqa_Modern_and_Contemporary_Chinese_Literature",
|
||||
"supergpqa_Religious_Studies",
|
||||
"supergpqa_Subatomic_and_Atomic_Physics",
|
||||
"supergpqa_Human_Geography",
|
||||
"supergpqa_Water_conservancy_and_Hydropower_Engineering",
|
||||
"supergpqa_Thermal_Energy_Engineering",
|
||||
"supergpqa_Immunology",
|
||||
"supergpqa_Communication_and_Information_Systems",
|
||||
"supergpqa_Meteorology",
|
||||
"supergpqa_Bridge_and_Tunnel_Engineering",
|
||||
"supergpqa_Military_Management",
|
||||
"supergpqa_Russian_Language_and_Literature",
|
||||
"supergpqa_Particle_and_Nuclear_Physics",
|
||||
"supergpqa_Rigid_Body_Mechanics",
|
||||
"supergpqa_Nuclear_Energy_and_Reactor_Technology",
|
||||
"supergpqa_Oncology",
|
||||
"supergpqa_Public_Finance",
|
||||
"supergpqa_Classical_Chinese_Literature",
|
||||
"supergpqa_Ecology",
|
||||
"supergpqa_Principles_of_Computer_Organization",
|
||||
"supergpqa_Pattern_Recognition",
|
||||
"supergpqa_Databases",
|
||||
"supergpqa_Ordinary_Differential_Equations",
|
||||
"supergpqa_Electrochemistry",
|
||||
"supergpqa_Traditional_Chinese_Pharmacy",
|
||||
"supergpqa_Dance_Studies",
|
||||
"supergpqa_Pharmaceutical_Analysis",
|
||||
"supergpqa_Otorhinolaryngology",
|
||||
"supergpqa_Principles_of_Seismic_Exploration",
|
||||
"supergpqa_Physical_Chemistry",
|
||||
"supergpqa_Special_Number_Theory",
|
||||
"supergpqa_Astrophysics",
|
||||
"supergpqa_Physical_Oceanography",
|
||||
"supergpqa_Instrumentation_and_Performance",
|
||||
"supergpqa_Military_Law",
|
||||
"supergpqa_Signal_and_Information_Processing",
|
||||
"supergpqa_Thermodynamics",
|
||||
"supergpqa_Architectural_Design_and_Theory",
|
||||
"supergpqa_Non-ferrous_Metallurgy",
|
||||
"supergpqa_Internal_Medicine",
|
||||
"supergpqa_Film_Studies",
|
||||
"supergpqa_Fluid_Physics",
|
||||
"supergpqa_Refrigeration_and_Cryogenic_Engineering",
|
||||
"supergpqa_Broadcasting_and_Television_Art",
|
||||
"supergpqa_Social_Medicine_and_Health_Management",
|
||||
"supergpqa_Military_Logistics_and_Equipment",
|
||||
"supergpqa_Criminal_Law",
|
||||
"supergpqa_Electrical_Theory_and_New_Technologies",
|
||||
"supergpqa_Nutrition_and_Food_Hygiene",
|
||||
"supergpqa_Literary_Theory",
|
||||
"supergpqa_Instrument_Science_and_Technology",
|
||||
"supergpqa_Legal_Theory_and_Legal_History",
|
||||
"supergpqa_Computer_Architecture",
|
||||
"supergpqa_Chemical_Transport_Engineering",
|
||||
"supergpqa_Military_Thought_and_History",
|
||||
"supergpqa_Archaeology_and_Museology",
|
||||
"supergpqa_Architectural_History",
|
||||
"supergpqa_Microbiology_and_Biochemical_Pharmacy",
|
||||
"supergpqa_Philosophy_of_Science_and_Technology",
|
||||
"supergpqa_Labor_Economics",
|
||||
"supergpqa_Dermatology_and_Venereology",
|
||||
"supergpqa_Materials_Processing_Engineering",
|
||||
"supergpqa_Human_Anatomy_and_Histology-Embryology",
|
||||
"supergpqa_Optical_Fiber_Communication",
|
||||
"supergpqa_Journalism_and_News_Practice",
|
||||
"supergpqa_Emergency_Medicine",
|
||||
"supergpqa_Veterinary_Medicine",
|
||||
"supergpqa_Heat_Transfer",
|
||||
"supergpqa_Information_Management_Science",
|
||||
"supergpqa_Physical_Chemistry_of_Metallurgical_Process",
|
||||
"supergpqa_Radiochemistry",
|
||||
"supergpqa_Guidance,_Navigation_and_Control",
|
||||
"supergpqa_Solid_Earth_Geophysics",
|
||||
"supergpqa_Systems_Science",
|
||||
"supergpqa_Weapon_Systems_Science_and_Engineering",
|
||||
"supergpqa_Manufacturing_Automation",
|
||||
"supergpqa_Engineering_Fluid_Mechanics",
|
||||
"supergpqa_Mineral_Processing_Engineering",
|
||||
"supergpqa_Animal_Rearing_and_Breeding",
|
||||
"supergpqa_Philosophical_Aesthetics",
|
||||
"supergpqa_Solar_System_Science",
|
||||
"supergpqa_Antenna_and_Radio_Communication",
|
||||
"supergpqa_Computational_Mathematics",
|
||||
"supergpqa_Health_Toxicology_and_Environmental_Health",
|
||||
"supergpqa_Design_Arts",
|
||||
"supergpqa_Computer_Software_and_Theory",
|
||||
"supergpqa_Aquaculture",
|
||||
"supergpqa_Nursing_and_Rehabilitation_Medicine",
|
||||
"supergpqa_Inorganic_Chemistry",
|
||||
"supergpqa_Traffic_Information_Engineering_and_Control",
|
||||
"supergpqa_Botany",
|
||||
"supergpqa_Number_Theory",
|
||||
"supergpqa_Hydrogeology",
|
||||
"supergpqa_Marine_Biology",
|
||||
"supergpqa_Law_and_Social_Governance",
|
||||
"supergpqa_Contract_Law",
|
||||
"supergpqa_Vehicle_Operation_Engineering",
|
||||
"supergpqa_Aeronautical_and_Astronautical_Science_and_Technology",
|
||||
"supergpqa_Poromechanics_and_Reservoir_Physics",
|
||||
"supergpqa_Pathogen_Biology",
|
||||
"supergpqa_Power_Systems_and_Automation",
|
||||
"supergpqa_Epidemiology_and_Health_Statistics",
|
||||
"supergpqa_Drama_and_Opera_Studies",
|
||||
"supergpqa_Environmental_Engineering",
|
||||
"supergpqa_Polymer_Chemistry_and_Physics",
|
||||
"supergpqa_Digital_Surveying_and_Remote_Sensing_Applications",
|
||||
"supergpqa_Atmospheric_Physics_and_Atmospheric_Environment",
|
||||
"supergpqa_Education_Economics,_Management_and_Social_Security",
|
||||
"supergpqa_Probability_and_Statistics",
|
||||
"supergpqa_Geometry_and_Topology",
|
||||
"supergpqa_Linguistics_and_Applied_Linguistics",
|
||||
"supergpqa_Astronomical_Observation_and_Technology",
|
||||
"supergpqa_Forensic_Medicine",
|
||||
"supergpqa_Fine_Arts",
|
||||
"supergpqa_Paleontology_and_Stratigraphy",
|
||||
"supergpqa_Management_Science_and_Engineering",
|
||||
"supergpqa_Logic",
|
||||
"supergpqa_Agricultural_Mechanization_Engineering",
|
||||
"supergpqa_Traditional_Chinese_Medicine_Theory",
|
||||
"supergpqa_Obstetrics_and_Gynecology",
|
||||
"supergpqa_Ship_Mechanics_and_Design_Principles",
|
||||
"supergpqa_Statistical_Mechanics",
|
||||
"supergpqa_Combinatorial_Mathematics",
|
||||
"supergpqa_Mass_Transport_and_Separation_Process_in_Chemical_Engineering",
|
||||
"supergpqa_Economic_Statistics",
|
||||
"supergpqa_Operations_Research_and_Cybernetics",
|
||||
"supergpqa_Formal_Languages",
|
||||
"supergpqa_Oil_and_Gas_Field_Development_and_Storage_&_Transportation_Engineering",
|
||||
"supergpqa_Environmental_and_Resource_Protection",
|
||||
"supergpqa_Structural_Geology",
|
||||
"supergpqa_Semiconductor_Physics",
|
||||
"supergpqa_Social_and_Folklore_Studies",
|
||||
"supergpqa_Music_History,_Education,_and_Technology",
|
||||
"supergpqa_Radiation_Protection_and_Nuclear_Technology_Applications",
|
||||
"supergpqa_Pathology_and_Pathophysiology",
|
||||
"supergpqa_Textile_Chemistry_and_Dyeing_Engineering",
|
||||
"supergpqa_Military_Command_and_Information_Systems",
|
||||
"supergpqa_Forest_Engineering",
|
||||
"supergpqa_Graph_Theory",
|
||||
"supergpqa_Radiation_Medicine",
|
||||
"supergpqa_Geological_Resources_and_Geological_Engineering",
|
||||
"supergpqa_Historical_Geography",
|
||||
"supergpqa_Cosmology",
|
||||
"supergpqa_Pharmaceutics",
|
||||
"supergpqa_National_and_Defense_Economics",
|
||||
"supergpqa_Imaging_and_Nuclear_Medicine",
|
||||
"supergpqa_Stellar_and_Interstellar_Evolution"
|
||||
],
|
||||
summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
|
||||
)
|
@ -6,18 +6,34 @@ from datasets import load_dataset
|
||||
import os
|
||||
from datasets import Dataset, DatasetDict
|
||||
from opencompass.datasets.supergpqa.supergpqa_utils import (
|
||||
evaluate_responses, find_file, load_json_or_jsonl,
|
||||
load_json_or_jsonl_with_idx, load_yaml)
|
||||
evaluate_responses,
|
||||
find_file,
|
||||
load_json_or_jsonl,
|
||||
load_json_or_jsonl_with_idx,
|
||||
load_yaml,
|
||||
)
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATORS,LOAD_DATASET
|
||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
||||
import unittest
|
||||
from opencompass.utils import get_data_path
|
||||
from opencompass.datasets.supergpqa.supergpqa_eval import (extract_option_labels,extract_option_content)
|
||||
from opencompass.datasets.supergpqa.supergpqa_eval import (
|
||||
extract_option_labels,
|
||||
extract_option_content,
|
||||
)
|
||||
from ..base import BaseDataset
|
||||
|
||||
|
||||
def _parse(item, template, prompt_mode):
|
||||
prompt_format = [item['question']+'\n'+'\n'.join([f'{chr(65+i)}) {option}' for i, option in enumerate(item['options'])])]
|
||||
prompt_format = [
|
||||
item['question']
|
||||
+ '\n'
|
||||
+ '\n'.join(
|
||||
[
|
||||
f'{chr(65+i)}) {option}'
|
||||
for i, option in enumerate(item['options'])
|
||||
]
|
||||
)
|
||||
]
|
||||
item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
|
||||
item['prompt_mode'] = prompt_mode
|
||||
return item
|
||||
@ -26,70 +42,86 @@ def _parse(item, template, prompt_mode):
|
||||
@LOAD_DATASET.register_module()
|
||||
class SuperGPQADataset(BaseDataset):
|
||||
@staticmethod
|
||||
def load(path: str, prompt_mode: str,category:str, **kwargs):
|
||||
path = get_data_path(path)
|
||||
def load(path: str, prompt_mode: str, **kwargs):
|
||||
path = get_data_path(path, local_mode=True)
|
||||
dataset = load_dataset(path, split='train')
|
||||
dataset = dataset.filter(lambda x: x['subfield'] == category)
|
||||
|
||||
#get prompt template
|
||||
# get prompt template
|
||||
template_path = None
|
||||
if prompt_mode == 'zero-shot':
|
||||
template_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'supergpqa_dataset_config/prompt/zero-shot.yaml')
|
||||
'supergpqa_dataset_config/prompt/zero-shot.yaml',
|
||||
)
|
||||
elif prompt_mode == 'five-shot':
|
||||
template_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'supergpqa_dataset_config/prompt/five-shot.yaml')
|
||||
'supergpqa_dataset_config/prompt/five-shot.yaml',
|
||||
)
|
||||
try:
|
||||
template = load_yaml(template_path)
|
||||
except FileNotFoundError:
|
||||
print(f'[ERROR] Missing prompt template: {template_path}')
|
||||
return Dataset.from_list([])
|
||||
|
||||
dataset =dataset.map(lambda item: _parse(item, template, prompt_mode))
|
||||
dataset = dataset.map(lambda item: _parse(item, template, prompt_mode))
|
||||
return dataset
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class SuperGPQAEvaluator(BaseEvaluator):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def score(self,predictions, references,test_set):
|
||||
mode= test_set[0]['prompt_mode']
|
||||
def score(self, predictions, references, test_set):
|
||||
mode = test_set[0]['prompt_mode']
|
||||
acc = 0
|
||||
count = 0
|
||||
err = 0
|
||||
miss = 0
|
||||
acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
|
||||
count_difficulty = {"hard": 0, "middle": 0, "easy": 0}
|
||||
stats = {
|
||||
'discipline': {},
|
||||
'field': {},
|
||||
'subfield': {}
|
||||
}
|
||||
|
||||
for i,sample in enumerate(test_set):
|
||||
prediction=predictions[i]
|
||||
gold=references[i]
|
||||
stats = {'discipline': {}, 'field': {}, 'subfield': {}}
|
||||
details = []
|
||||
for i, sample in enumerate(test_set):
|
||||
sample["pred"] = prediction = predictions[i]
|
||||
gold = references[i]
|
||||
if mode == 'zero-shot':
|
||||
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
|
||||
if predict == None:
|
||||
predict = extract_option_content(prediction, sample["options"])
|
||||
predict = chr(sample["options"].index(predict) + 65) if predict else None
|
||||
predict = extract_option_content(
|
||||
prediction, sample["options"]
|
||||
)
|
||||
predict = (
|
||||
chr(sample["options"].index(predict) + 65)
|
||||
if predict
|
||||
else None
|
||||
)
|
||||
sample["extracted_answer"] = predict
|
||||
elif mode == 'five-shot':
|
||||
response = prediction.split('Question:')[0]
|
||||
predict = extract_option_labels(response, 'ABCDEFGHIJ')
|
||||
if predict == None:
|
||||
predict = extract_option_content(response, sample["options"])
|
||||
predict = chr(sample["options"].index(predict) + 65) if predict else None
|
||||
predict = extract_option_content(
|
||||
response, sample["options"]
|
||||
)
|
||||
predict = (
|
||||
chr(sample["options"].index(predict) + 65)
|
||||
if predict
|
||||
else None
|
||||
)
|
||||
if predict == None:
|
||||
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
|
||||
if predict == None:
|
||||
predict = extract_option_content(prediction, sample["options"])
|
||||
predict = chr(sample["options"].index(predict) + 65) if predict else None
|
||||
predict = extract_option_content(
|
||||
prediction, sample["options"]
|
||||
)
|
||||
predict = (
|
||||
chr(sample["options"].index(predict) + 65)
|
||||
if predict
|
||||
else None
|
||||
)
|
||||
sample["extracted_answer"] = predict
|
||||
|
||||
discipline = sample.get("discipline", "unknown")
|
||||
@ -99,8 +131,8 @@ class SuperGPQAEvaluator(BaseEvaluator):
|
||||
|
||||
for level, key in [
|
||||
('discipline', discipline),
|
||||
('field', f"{discipline}/{field}"),
|
||||
('subfield', f"{discipline}/{field}/{subfield}")
|
||||
# ('field', f"{discipline}/{field}"),
|
||||
# ('subfield', f"{discipline}/{field}/{subfield}"),
|
||||
]:
|
||||
if key not in stats[level]:
|
||||
stats[level][key] = {
|
||||
@ -114,15 +146,15 @@ class SuperGPQAEvaluator(BaseEvaluator):
|
||||
"difficulty": {
|
||||
"easy": {"correct": 0, "total": 0},
|
||||
"middle": {"correct": 0, "total": 0},
|
||||
"hard": {"correct": 0, "total": 0}
|
||||
}
|
||||
"hard": {"correct": 0, "total": 0},
|
||||
},
|
||||
}
|
||||
|
||||
stats[level][key]["total"] += 1
|
||||
stats[level][key]["difficulty"][difficulty]["total"] += 1
|
||||
|
||||
answer_letter = sample["answer_letter"]
|
||||
assert answer_letter==gold
|
||||
assert answer_letter == gold
|
||||
if predict and answer_letter == predict:
|
||||
acc += 1
|
||||
acc_difficulty[difficulty] += 1
|
||||
@ -141,12 +173,33 @@ class SuperGPQAEvaluator(BaseEvaluator):
|
||||
sample["status"] = "incorrect"
|
||||
count += 1
|
||||
count_difficulty[difficulty] += 1
|
||||
details.append(
|
||||
{
|
||||
'pred': sample['pred'],
|
||||
'answer': sample['answer'],
|
||||
'parsed_answer': sample['extracted_answer'],
|
||||
'correct': True if sample['status'] else False,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
'accuracy': acc / count if count > 0 else 0,
|
||||
# 'error_rate': err / count if count > 0 else 0,
|
||||
# 'miss_rate': miss / count if count > 0 else 0,
|
||||
# 'hard_accuracy': acc_difficulty["hard"] / count_difficulty["hard"] if count_difficulty["hard"] > 0 else 0,
|
||||
# 'middle_accuracy': acc_difficulty["middle"] / count_difficulty["middle"] if count_difficulty["middle"] > 0 else 0,
|
||||
# 'easy_accuracy': acc_difficulty["easy"] / count_difficulty["easy"] if count_difficulty["easy"] > 0 else 0
|
||||
'accuracy': acc / count if count > 0 else 0,
|
||||
'error_rate': err / count if count > 0 else 0,
|
||||
'miss_rate': miss / count if count > 0 else 0,
|
||||
'hard_accuracy': (
|
||||
acc_difficulty["hard"] / count_difficulty["hard"]
|
||||
if count_difficulty["hard"] > 0
|
||||
else 0
|
||||
),
|
||||
'middle_accuracy': (
|
||||
acc_difficulty["middle"] / count_difficulty["middle"]
|
||||
if count_difficulty["middle"] > 0
|
||||
else 0
|
||||
),
|
||||
'easy_accuracy': (
|
||||
acc_difficulty["easy"] / count_difficulty["easy"]
|
||||
if count_difficulty["easy"] > 0
|
||||
else 0
|
||||
),
|
||||
'details': details,
|
||||
}
|
@ -4,7 +4,7 @@ import argparse
|
||||
import os
|
||||
from prettytable import PrettyTable
|
||||
import pandas as pd
|
||||
# from openpyxl.styles import PatternFill, Font, Alignment
|
||||
|
||||
from tqdm import tqdm
|
||||
import timeout_decorator
|
||||
import multiprocessing
|
||||
@ -96,538 +96,3 @@ def extract_option_content(text, options_content=None):
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
def _extract_choice(text, options):
    """Return the option letter recognised in *text*, or None when nothing matches.

    Tries an explicit option label first (via ``extract_option_labels``) and
    falls back to matching an option's content (via ``extract_option_content``),
    converting the matched option's position into a letter ('A' for index 0).
    """
    letter = extract_option_labels(text, 'ABCDEFGHIJ')
    if letter is None:
        content = extract_option_content(text, options)
        letter = chr(options.index(content) + 65) if content else None
    return letter


def _new_category_stats(discipline, field, subfield):
    """Build a zeroed statistics record for one discipline/field/subfield key."""
    return {
        "correct": 0,
        "total": 0,
        "miss": 0,
        "error": 0,
        "discipline": discipline,
        "field": field,
        "subfield": subfield,
        "difficulty": {
            "easy": {"correct": 0, "total": 0},
            "middle": {"correct": 0, "total": 0},
            "hard": {"correct": 0, "total": 0},
        },
    }


def _ratio_or_zero(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def calculate_accuracy(file_path, save_dir, mode):
    """Score one JSONL prediction file and write the annotated samples back out.

    Each line of *file_path* is parsed as a JSON object. For every sample the
    answer letter is extracted from the model output, compared with
    ``sample["answer_letter"]``, and the sample is tagged with
    ``"extracted_answer"`` and ``"status"`` (``correct`` / ``miss`` /
    ``error`` / ``incorrect``). The annotated samples are written to
    ``save_dir/<basename of file_path>``.

    Args:
        file_path: Path of the JSONL file to score.
        save_dir: Directory that receives the annotated copy (created if
            missing).
        mode: ``'zero-shot'`` or ``'five-shot'``. In five-shot mode the text
            before the first ``'Question:'`` is tried first, then the whole
            output.

    Returns:
        A 7-tuple ``(accuracy, error_rate, miss_rate, hard_accuracy,
        middle_accuracy, easy_accuracy, stats)`` where ``stats`` maps
        ``'discipline'`` / ``'field'`` / ``'subfield'`` to per-key counters.
        All rates are 0.0 when there is nothing to divide by.

    Raises:
        ValueError: If *mode* is not one of the two supported values.
    """
    data = []
    acc = 0
    count = 0
    err = 0
    miss = 0
    acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
    count_difficulty = {"hard": 0, "middle": 0, "easy": 0}
    stats = {'discipline': {}, 'field': {}, 'subfield': {}}

    with open(file_path, "r") as file:
        for line in tqdm(file, desc=f"Reading {os.path.basename(file_path)} data", leave=False):
            data.append(json.loads(line))

    if not data:
        print(f"Warning: No data found in {file_path}")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats

    for sample in tqdm(data, desc=f"Processing {os.path.basename(file_path)} samples", leave=False):
        # NOTE(review): the previous body read an undefined name `prediction`
        # (NameError on the first sample). The model output is assumed to be
        # stored under the "response" key of each record -- confirm against
        # whatever writes these JSONL files.
        prediction = sample["response"]
        options = sample["options"]

        if mode == 'zero-shot':
            predict = _extract_choice(prediction, options)
        elif mode == 'five-shot':
            # Prefer the text before the model starts inventing a follow-up
            # question; fall back to the full output if that yields nothing.
            predict = _extract_choice(prediction.split('Question:')[0], options)
            if predict is None:
                predict = _extract_choice(prediction, options)
        else:
            raise ValueError(f"Unsupported mode: {mode!r}")
        sample["extracted_answer"] = predict

        discipline = sample.get("discipline", "unknown")
        field = sample.get("field", "unknown")
        subfield = sample.get("subfield", "unknown")
        difficulty = sample.get("difficulty", "unknown")

        if predict and sample["answer_letter"] == predict:
            status = "correct"
        elif predict is None or predict == "":
            status = "miss"
        elif predict == 'error':
            status = "error"
        else:
            status = "incorrect"
        sample["status"] = status

        for level, key in (
            ('discipline', discipline),
            ('field', f"{discipline}/{field}"),
            ('subfield', f"{discipline}/{field}/{subfield}"),
        ):
            if key not in stats[level]:
                stats[level][key] = _new_category_stats(discipline, field, subfield)
            entry = stats[level][key]
            # Buckets are created on demand so a sample without a known
            # difficulty ("unknown") no longer raises KeyError.
            bucket = entry["difficulty"].setdefault(difficulty, {"correct": 0, "total": 0})
            entry["total"] += 1
            bucket["total"] += 1
            if status == "correct":
                entry["correct"] += 1
                bucket["correct"] += 1
            elif status in ("miss", "error"):
                entry[status] += 1

        if status == "correct":
            acc += 1
            acc_difficulty[difficulty] = acc_difficulty.get(difficulty, 0) + 1
        elif status == "miss":
            miss += 1
        elif status == "error":
            err += 1
        count += 1
        count_difficulty[difficulty] = count_difficulty.get(difficulty, 0) + 1

    if count == 0:
        print(f"Warning: No valid samples found in {file_path}")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats

    accuracy = acc / count
    error_rate = err / count
    miss_rate = miss / count
    # A file may contain no samples of a given difficulty; report 0.0 instead
    # of raising ZeroDivisionError.
    hard_accuracy = _ratio_or_zero(acc_difficulty["hard"], count_difficulty["hard"])
    middle_accuracy = _ratio_or_zero(acc_difficulty["middle"], count_difficulty["middle"])
    easy_accuracy = _ratio_or_zero(acc_difficulty["easy"], count_difficulty["easy"])

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, os.path.basename(file_path))
    with open(save_path, "w") as file:
        for sample in data:
            json.dump(sample, file)
            file.write("\n")

    return accuracy, error_rate, miss_rate, hard_accuracy, middle_accuracy, easy_accuracy, stats
|
||||
|
||||
def calculate_total_row(hierarchy_stats, model_results, metric_name):
    """Build the "Overall" summary rows for one metric sheet.

    Args:
        hierarchy_stats: Dict with ``'discipline'``, ``'field'`` and
            ``'subfield'`` entries, each mapping a category key to its
            counters (``'total'``, ``'difficulty'`` and, optionally,
            ``'model_stats'[model][mode]``).
        model_results: Mapping ``model name -> {mode: ...}``; only the keys
            are used, to decide which ``"<model>_<mode>"`` columns to emit.
        metric_name: ``'Accuracy'``, ``'Error Rate'``, ``'Miss Rate'``,
            ``'Hard'``, ``'Middle'`` or ``'Easy'``.

    Returns:
        A list of row dicts. Accuracy and the three difficulty metrics yield
        four rows (sample-wise, then the unweighted mean over subfields,
        fields and disciplines); the other metrics yield a single
        sample-wise row. Values are formatted as ``"12.34%"``.
    """
    is_difficulty = metric_name in ('Hard', 'Middle', 'Easy')
    bucket_name = metric_name.lower()

    def model_stats_of(category, model_name, mode):
        # Categories without statistics for this model contribute nothing.
        return category.get('model_stats', {}).get(model_name, {}).get(mode, {}) \
            if 'model_stats' in category and model_name in category['model_stats'] else {}

    # Number of categories per dimension (for difficulty metrics: only those
    # that actually contain samples of that difficulty).
    category_counts = {}
    for dimension in ('subfield', 'field', 'discipline'):
        categories = hierarchy_stats[dimension]
        if is_difficulty:
            category_counts[dimension] = sum(
                cat['difficulty'][bucket_name]['total'] > 0 for cat in categories.values()
            )
        else:
            category_counts[dimension] = len(categories)

    total_samples = 0
    for discipline_stats in hierarchy_stats['discipline'].values():
        if is_difficulty:
            total_samples += discipline_stats['difficulty'][bucket_name]['total']
        else:
            total_samples += discipline_stats['total']

    if metric_name == 'Accuracy' or is_difficulty:
        qualifier = f'{bucket_name} ' if is_difficulty else ''
        row_types = [
            (f'Overall (sample-wise) (Total {qualifier}samples: {total_samples})', 'sample'),
            (f'Overall (subfield-wise) (Total {qualifier}subfields: {category_counts["subfield"]})', 'subfield'),
            (f'Overall (field-wise) (Total {qualifier}fields: {category_counts["field"]})', 'field'),
            (f'Overall (discipline-wise) (Total {qualifier}disciplines: {category_counts["discipline"]})', 'discipline'),
        ]
    else:  # Error Rate and Miss Rate
        row_types = [(f'Overall (Total samples: {total_samples})', 'sample')]

    total_rows = []
    for row_name, stat_type in row_types:
        total_row = {'Discipline': row_name, 'Field': '', 'Subfield': ''}

        for model_name, modes in model_results.items():
            for mode in modes:
                if stat_type == 'sample':
                    # Sample-wise: pool the counters of every discipline.
                    pooled = {'total': 0, 'correct': 0, 'error': 0, 'miss': 0}
                    for discipline_stats in hierarchy_stats['discipline'].values():
                        curr_stats = model_stats_of(discipline_stats, model_name, mode)
                        if is_difficulty:
                            bucket = curr_stats.get('difficulty', {}).get(bucket_name, {})
                            pooled['total'] += bucket.get('total', 0)
                            pooled['correct'] += bucket.get('correct', 0)
                        else:
                            for key in pooled:
                                pooled[key] += curr_stats.get(key, 0)

                    if pooled['total'] > 0:
                        if is_difficulty or metric_name == 'Accuracy':
                            value = pooled['correct'] / pooled['total']
                        elif metric_name == 'Error Rate':
                            value = pooled['error'] / pooled['total']
                        else:  # Miss Rate
                            value = pooled['miss'] / pooled['total']
                    else:
                        value = 0
                else:
                    # Category-wise: unweighted mean of per-category accuracy.
                    scores = []
                    for cat_stats in hierarchy_stats[stat_type].values():
                        curr_stats = model_stats_of(cat_stats, model_name, mode)
                        if is_difficulty:
                            bucket = curr_stats.get('difficulty', {}).get(bucket_name, {})
                            if bucket.get('total', 0) > 0:
                                scores.append(bucket['correct'] / bucket['total'])
                        elif curr_stats.get('total', 0) > 0 and metric_name == 'Accuracy':
                            scores.append(curr_stats['correct'] / curr_stats['total'])
                    value = sum(scores) / len(scores) if scores else 0

                total_row[f'{model_name}_{mode}'] = f"{value:.2%}"

        total_rows.append(total_row)

    return total_rows
|
||||
|
||||
_REPORT_DIFFICULTY_METRICS = ('Hard', 'Middle', 'Easy')


def _report_metric_cell(stats, metric_name):
    """Format one model/mode cell of the report as a ``"12.34%"`` string."""
    if metric_name in _REPORT_DIFFICULTY_METRICS:
        bucket = stats.get('difficulty', {}).get(metric_name.lower(), {})
        if bucket.get('total', 0) > 0:
            return f"{bucket['correct'] / bucket['total']:.2%}"
        return '0.00%'
    if stats.get('total', 0) > 0:
        if metric_name == 'Accuracy':
            return f"{stats['correct'] / stats['total']:.2%}"
        if metric_name == 'Error Rate':
            return f"{stats['error'] / stats['total']:.2%}"
        return f"{stats['miss'] / stats['total']:.2%}"  # Miss Rate
    return '0.00%'


def _report_sample_count(category_stats, metric_name):
    """Return the sample count shown in a row label for the given metric."""
    if metric_name in _REPORT_DIFFICULTY_METRICS:
        return category_stats['difficulty'][metric_name.lower()]['total']
    return category_stats['total']


def _report_model_cells(category_stats, model_results, metric_name):
    """Return the ``"<model>_<mode>" -> formatted value`` cells for one row."""
    # A category with no per-model statistics renders as 0.00% rather than
    # raising KeyError, matching the guard used by calculate_total_row.
    per_model = category_stats.get('model_stats', {})
    cells = {}
    for model_name, modes in model_results.items():
        for mode in modes:
            stats = per_model.get(model_name, {}).get(mode, {})
            cells[f'{model_name}_{mode}'] = _report_metric_cell(stats, metric_name)
    return cells


def create_excel_report_from_stats(model_results, hierarchy_stats, save_path):
    """Write the per-metric Excel report for all models.

    One sheet is produced for each of Accuracy, Error Rate, Miss Rate, Hard,
    Middle and Easy. Within a sheet, rows are emitted per discipline (sorted):
    every subfield row of a field, then that field's summary row, and after
    all fields the discipline summary row. The "Overall" rows from
    ``calculate_total_row`` are appended last, and each sheet is passed to
    ``format_worksheet``.

    Args:
        model_results: Mapping ``model name -> {mode: ...}``; the keys select
            the ``"<model>_<mode>"`` columns.
        hierarchy_stats: Dict with ``'discipline'``, ``'field'`` and
            ``'subfield'`` entries. Field keys are expected to start with
            ``"<discipline>/"`` and subfield keys with
            ``"<discipline>/<field>/"``.
        save_path: Destination path of the Excel file.
    """
    print("Starting Excel report generation...")

    # One row list per sheet; 'color' is handed to format_worksheet.
    metrics = {
        'Accuracy': {'rows': [], 'color': '000000'},  # black
        'Error Rate': {'rows': [], 'color': '000000'},  # black
        'Miss Rate': {'rows': [], 'color': '000000'},  # black
        'Hard': {'rows': [], 'color': '000000'},  # black
        'Middle': {'rows': [], 'color': '000000'},  # black
        'Easy': {'rows': [], 'color': '000000'},  # black
    }

    def add_row(metric_name, labels, category_stats):
        row = dict(labels)
        row.update(_report_model_cells(category_stats, model_results, metric_name))
        metrics[metric_name]['rows'].append(row)

    for discipline in tqdm(sorted(hierarchy_stats['discipline'].keys()), desc="Processing discipline level"):
        discipline_stats = hierarchy_stats['discipline'][discipline]

        field_keys = sorted(
            k for k in hierarchy_stats['field'] if k.startswith(f"{discipline}/")
        )
        for field_key in field_keys:
            field_stats = hierarchy_stats['field'][field_key]
            field = field_stats['field']

            subfield_keys = sorted(
                k for k in hierarchy_stats['subfield'] if k.startswith(f"{discipline}/{field}/")
            )
            for subfield_key in subfield_keys:
                subfield_stats = hierarchy_stats['subfield'][subfield_key]
                for metric_name in metrics:
                    count = _report_sample_count(subfield_stats, metric_name)
                    add_row(
                        metric_name,
                        {
                            'Discipline': discipline,
                            'Field': field,
                            'Subfield': f"{subfield_stats['subfield']} ({count})",
                        },
                        subfield_stats,
                    )

            # Field summary row
            for metric_name in metrics:
                count = _report_sample_count(field_stats, metric_name)
                add_row(
                    metric_name,
                    {
                        'Discipline': discipline,
                        'Field': f"{field} (Total: {count})",
                        'Subfield': '',
                    },
                    field_stats,
                )

        # Discipline summary row
        for metric_name in metrics:
            count = _report_sample_count(discipline_stats, metric_name)
            add_row(
                metric_name,
                {
                    'Discipline': f"{discipline} (Total: {count})",
                    'Field': '',
                    'Subfield': '',
                },
                discipline_stats,
            )

    # Build one DataFrame per metric and append its overall summary rows.
    dfs = {}
    for metric_name, data in metrics.items():
        total_rows = calculate_total_row(hierarchy_stats, model_results, metric_name)
        dfs[metric_name] = pd.concat(
            [pd.DataFrame(data['rows']), pd.DataFrame(total_rows)], ignore_index=True
        )

    # Save to Excel, one sheet per metric
    with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
        for metric_name, df in dfs.items():
            df.to_excel(writer, sheet_name=metric_name, index=False)
            format_worksheet(writer.sheets[metric_name], df, metrics[metric_name]['color'])

    print(f"Report generation completed, Excel file saved: {save_path}")
|
||||
|
||||
def format_worksheet(worksheet, df, color):
    """Apply fonts, column widths, merges, fills and alignment to one sheet.

    Args:
        worksheet: The openpyxl worksheet to format in place. Row 1 is treated
            as the header; columns A/B/C hold Discipline/Field/Subfield and
            the remaining columns hold percentage strings.
        df: Accepted for interface compatibility; not read by this function.
        color: Accepted for interface compatibility; not read by this
            function (every cell gets a black Arial font).
    """
    # The module-level `from openpyxl.styles import ...` line is commented out
    # in this file, so the style classes are imported here to avoid a
    # NameError. openpyxl is already required to produce the worksheet.
    from openpyxl.styles import Alignment, Font, PatternFill

    def solid(hex_color):
        return PatternFill(start_color=hex_color, end_color=hex_color, fill_type='solid')

    # Set default font
    for row in worksheet.rows:
        for cell in row:
            cell.font = Font(name='Arial', color='000000')  # Use black font uniformly

    discipline_fill = solid('FFFF00')
    field_fill = solid('FFFFD4')

    # Overall-row fills, checked in this order.
    overall_fills = (
        ('Overall (sample-wise)', solid('B8CCE4')),
        ('Overall (subfield-wise)', solid('DCE6F1')),
        ('Overall (field-wise)', solid('E9EEF5')),
        ('Overall (discipline-wise)', solid('F2F5F9')),
    )
    # Sheets whose single "Overall" row carries no "(…-wise)" marker.
    title_fills = {
        'Error Rate': solid('FFB6C1'),
        'Miss Rate': solid('D3D3D3'),
    }

    # Set column width to the longest rendered value plus padding.
    for column in worksheet.columns:
        cells = list(column)
        max_length = max((len(str(cell.value)) for cell in cells), default=0)
        worksheet.column_dimensions[cells[0].column_letter].width = max_length + 2

    # Merge runs of identical Discipline (column A) / Field (column B) values
    # and colour the "Total:" summary rows that terminate each run.
    trackers = [
        {'letter': 'A', 'index': 0, 'fill': discipline_fill, 'current': None, 'start': None},
        {'letter': 'B', 'index': 1, 'fill': field_fill, 'current': None, 'start': None},
    ]

    def merge_open_run(tracker, end_row):
        if tracker['start'] and tracker['current']:
            letter = tracker['letter']
            start = tracker['start']
            worksheet.merge_cells(f'{letter}{start}:{letter}{end_row}')

    for row_idx, row in enumerate(worksheet.iter_rows(min_row=2), start=2):
        for tracker in trackers:
            value = row[tracker['index']].value
            if value and "Total:" in str(value):
                merge_open_run(tracker, row_idx - 1)
                for cell in row:
                    cell.fill = tracker['fill']
                tracker['current'] = None
                tracker['start'] = None
            elif value and value != tracker['current']:
                merge_open_run(tracker, row_idx - 1)
                tracker['current'] = value
                tracker['start'] = row_idx

    # Process last unmerged cells
    last_row = worksheet.max_row
    for tracker in trackers:
        merge_open_run(tracker, last_row)

    # Apply special background color to Overall rows
    for row in worksheet.iter_rows(min_row=2):
        cell_value = row[0].value
        if not cell_value:
            continue
        text = str(cell_value)
        fill = next((f for marker, f in overall_fills if marker in text), None)
        if fill is None and 'Overall' in text:
            fill = title_fills.get(worksheet.title)
        if fill is not None:
            for cell in row:
                cell.fill = fill

    # Normalise percentage strings to two decimal places
    for row in worksheet.iter_rows(min_row=2):
        for cell in row[3:]:  # skip Discipline, Field, Subfield columns
            if isinstance(cell.value, str) and '%' in cell.value:
                try:
                    value = float(cell.value.strip('%')) / 100
                except ValueError:
                    continue
                cell.value = f"{value:.2%}"

    # Set all cells to center alignment
    for row in worksheet.rows:
        for cell in row:
            cell.alignment = Alignment(horizontal='center', vertical='center')
|
||||
|
||||
def format_cell_value(stats):
    """Format counters as an ``"acc/error/miss"`` percentage string.

    Args:
        stats: Dict with ``'total'``, ``'correct'``, ``'error'`` and
            ``'miss'`` counts. Missing counters are treated as 0, in line
            with the ``.get(..., 0)`` reads used elsewhere in this module.

    Returns:
        ``'0%/0%/0%'`` when the total is zero, otherwise the three rates
        with one decimal place, e.g. ``'50.0%/25.0%/25.0%'``.
    """
    total = stats.get('total', 0)
    if total == 0:
        return '0%/0%/0%'

    acc = stats.get('correct', 0) / total
    error = stats.get('error', 0) / total
    miss = stats.get('miss', 0) / total

    return f"{acc:.1%}/{error:.1%}/{miss:.1%}"
|
@ -1,4 +1,5 @@
|
||||
"""Base Evaluator."""
|
||||
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Union
|
||||
@ -46,8 +47,9 @@ class BaseEvaluator:
|
||||
# please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
|
||||
return self._out_dir
|
||||
|
||||
def group(self, n: int, details: List[Dict[str, Any]],
|
||||
test_set: Dataset) -> Dict[str, Any]:
|
||||
def group(
|
||||
self, n: int, details: List[Dict[str, Any]], test_set: Dataset
|
||||
) -> Dict[str, Any]:
|
||||
example2replications = {}
|
||||
for detail, example in zip(details, test_set):
|
||||
example_abbr = f"{example['subdivision']}_{example['idx']}"
|
||||
@ -62,27 +64,37 @@ class BaseEvaluator:
|
||||
def reduce(self, details: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
g_passk_details = OrderedDict()
|
||||
all_subdivisions = set(
|
||||
[detail['example_abbr'].split('_')[0] for detail in details])
|
||||
[detail['example_abbr'].split('_')[0] for detail in details]
|
||||
)
|
||||
all_metrics = list(details[0].keys())
|
||||
|
||||
for subdivision in sorted(list(all_subdivisions)):
|
||||
for metric in all_metrics:
|
||||
if metric in ['predictions', 'example_abbr']:
|
||||
continue
|
||||
g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean([
|
||||
detail[metric] for detail in details
|
||||
if detail['example_abbr'].split('_')[0] == subdivision
|
||||
])
|
||||
g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean(
|
||||
[
|
||||
detail[metric]
|
||||
for detail in details
|
||||
if detail['example_abbr'].split('_')[0] == subdivision
|
||||
]
|
||||
)
|
||||
|
||||
for metric in all_metrics:
|
||||
if metric in ['predictions', 'example_abbr']:
|
||||
continue
|
||||
g_passk_details[metric] = 100. * np.mean(
|
||||
[detail[metric] for detail in details])
|
||||
g_passk_details[metric] = 100.0 * np.mean(
|
||||
[detail[metric] for detail in details]
|
||||
)
|
||||
return g_passk_details
|
||||
|
||||
def evaluate(self, k: Union[int, List[int]], n: int,
|
||||
original_dataset: Dataset, **score_kwargs):
|
||||
def evaluate(
|
||||
self,
|
||||
k: Union[int, List[int]],
|
||||
n: int,
|
||||
original_dataset: Dataset,
|
||||
**score_kwargs,
|
||||
):
|
||||
real_size = len(original_dataset) // n
|
||||
all_details = []
|
||||
all_results = []
|
||||
@ -92,7 +104,7 @@ class BaseEvaluator:
|
||||
if isinstance(x, Dataset):
|
||||
return x.select(range(i * real_size, (i + 1) * real_size))
|
||||
elif isinstance(x, Iterable):
|
||||
return x[i * real_size:(i + 1) * real_size]
|
||||
return x[i * real_size : (i + 1) * real_size]
|
||||
else:
|
||||
return x
|
||||
|
||||
@ -100,7 +112,8 @@ class BaseEvaluator:
|
||||
**{
|
||||
key: select_fn(i, real_size, value)
|
||||
for key, value in score_kwargs.items()
|
||||
})
|
||||
}
|
||||
)
|
||||
details = results.pop('details', None)
|
||||
if details is not None:
|
||||
if isinstance(details, Dict):
|
||||
@ -116,10 +129,12 @@ class BaseEvaluator:
|
||||
eval_results[key].append(single_results[key])
|
||||
for key in deepcopy(eval_results):
|
||||
if isinstance(eval_results[key][0], float) or isinstance(
|
||||
eval_results[key][0], int):
|
||||
eval_results[key][0], int
|
||||
):
|
||||
if n > 1:
|
||||
eval_results[key + f' ({n} runs average)'] = np.mean(
|
||||
eval_results[key])
|
||||
eval_results[key]
|
||||
)
|
||||
eval_results.pop(key)
|
||||
else:
|
||||
eval_results[key] = np.mean(eval_results[key])
|
||||
@ -146,24 +161,46 @@ class BaseEvaluator:
|
||||
|
||||
if can_calculate and n > 1 and k > 1:
|
||||
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
for _k in ([k] if isinstance(k, int) else k):
|
||||
for _k in [k] if isinstance(k, int) else k:
|
||||
for threshold in thresholds:
|
||||
g_pass = compute_g_pass_at_k(n=n,
|
||||
c=c,
|
||||
k=_k,
|
||||
t=threshold)
|
||||
g_pass = compute_g_pass_at_k(
|
||||
n=n, c=c, k=_k, t=threshold
|
||||
)
|
||||
detail[f'G-Pass@{_k}_{threshold}'] = g_pass
|
||||
detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(n=n,
|
||||
c=c,
|
||||
k=_k)
|
||||
detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(
|
||||
n=n, c=c, k=_k
|
||||
)
|
||||
|
||||
eval_details.append(detail)
|
||||
|
||||
if can_calculate and n > 1 and k > 1:
|
||||
eval_results.update(self.reduce(eval_details))
|
||||
|
||||
# Store eval_details in eval_results
|
||||
eval_results['details'] = eval_details
|
||||
|
||||
return eval_results
|
||||
# Process details to flatten the predictions
|
||||
for detail in eval_details:
|
||||
# Extract all prediction fields and flatten them
|
||||
flattened_predictions = {}
|
||||
for pred in detail['predictions']:
|
||||
for k, v in pred.items():
|
||||
if k not in flattened_predictions:
|
||||
flattened_predictions[k] = [v]
|
||||
else:
|
||||
flattened_predictions[k].append(v)
|
||||
|
||||
# Replace the predictions list with the flattened dictionary
|
||||
for k, v in flattened_predictions.items():
|
||||
detail[k] = v
|
||||
|
||||
# Remove the original predictions field
|
||||
detail.pop('predictions')
|
||||
import ipdb; ipdb.set_trace()
|
||||
return eval_results
|
||||
|
||||
# If there are no details, return an empty dictionary
|
||||
return {}
|
||||
|
||||
def score(self):
|
||||
raise NotImplementedError("Method hasn't been implemented yet")
|
||||
|
Loading…
Reference in New Issue
Block a user