# OpenCompass/opencompass/datasets/supergpqa/supergpqa_eval.py
import argparse
import json
import multiprocessing
import os
import re
import time
from functools import partial

import pandas as pd
import timeout_decorator
from openpyxl.styles import Alignment, Font, PatternFill
from prettytable import PrettyTable
from tqdm import tqdm


@timeout_decorator.timeout(5)  # 5 seconds timeout
def safe_regex_search(pattern, text, flags=0):
    try:
        return re.search(pattern, text, flags)
    except timeout_decorator.TimeoutError:
        print(f"Regex match timeout: pattern={pattern}, text={text[:100]}...")
        return None
    except Exception as e:
        print(f"Regex match error: {str(e)}")
        return None
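# Usage sketch for safe_regex_search (added comments, not original code): it behaves like
# re.search but aborts after 5 seconds, which matters because the extraction patterns below
# can backtrack heavily on adversarial model outputs. timeout_decorator's default
# signal-based timeout only works in the main thread.
#   safe_regex_search(r'answer\s+is\s+([A-J])', 'The answer is C')  # -> match capturing 'C'
#   safe_regex_search(r'(a+)+$', 'a' * 40 + 'b')                    # -> None after ~5 s instead of hanging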
def extract_option_labels(text, options='ABCDEFGHIJ'):
    if not isinstance(text, str) or not isinstance(options, str):
        return 'error'
    text = text.rstrip()
    last_line = text.split('\n')[-1]

    option_str = ''.join([chr(65 + i) for i in range(len(options))]) if options else 'ABCDEFGHIJ'

    patterns = [
        # e.g. "The final answer to this question is: A."
        #      "The best option is $\boxed{B}:"
        #      "The correct answer is (C)."
        f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is?:?\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',

        # e.g. "ANSWER: A"
        #      "Answer: $\boxed{B}."
        #      "ANSWER: (C):"
        f'(?i:Answer)[\*\s]*:\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',

        # e.g. "A"
        #      "$\boxed{B}$"
        #      "(C)."
        #      "[D]:"
        f'^[^\w\r\n]*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
    ]

    for pattern in patterns:
        match = safe_regex_search(pattern, last_line, re.IGNORECASE)
        if match:
            return match.group(1)

    for pattern in patterns:
        match = safe_regex_search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1)
    return None
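# Illustrative behaviour of extract_option_labels (added comments, not original code): the
# last line of the response is scanned first, then the full text; the return value is the
# bare option letter, None when no pattern matches, or 'error' for non-string input.
#   extract_option_labels("Reasoning...\nThe correct answer is (C).")  # -> 'C'
#   extract_option_labels("ANSWER: A")                                 # -> 'A'
#   extract_option_labels(None)                                        # -> 'error'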
def extract_option_content(text, options_content=None):
    if not isinstance(text, str) or not isinstance(options_content, list):
        return 'error'
    escaped_options_content = [re.escape(option_content) for option_content in options_content]
    escaped_options_content_str = '|'.join(escaped_options_content)

    text = text.rstrip()
    last_line = text.split('\n')[-1]

    patterns = [
        f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is:?\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
        f'(?i:Answer)\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
        f'^[^\w\r\n]*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
    ]

    for pattern in patterns:
        match = safe_regex_search(pattern, last_line)
        if match:
            if match.group(1) in escaped_options_content:
                return options_content[escaped_options_content.index(match.group(1))]
            else:
                return match.group(1)

    for pattern in patterns:
        match = safe_regex_search(pattern, text)
        if match:
            if match.group(1) in escaped_options_content:
                return options_content[escaped_options_content.index(match.group(1))]
            else:
                return match.group(1)
    return None
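# Illustrative behaviour of extract_option_content (added comments, not original code): this
# is the fallback used when no option letter is found. It matches the verbatim option text
# and returns it, so the caller can map it back to a letter via sample["options"].index(...).
#   extract_option_content("The answer is cell membrane.",
#                          ["cell wall", "cell membrane", "nucleus"])  # -> 'cell membrane'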
def calculate_accuracy(file_path, save_dir, mode):
    data = []
    acc = 0
    count = 0
    err = 0
    miss = 0
    acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
    count_difficulty = {"hard": 0, "middle": 0, "easy": 0}
    stats = {
        'discipline': {},
        'field': {},
        'subfield': {}
    }

    with open(file_path, "r") as file:
        for line in tqdm(file, desc=f"Reading {os.path.basename(file_path)} data", leave=False):
            data.append(json.loads(line))

    if not data:
        print(f"Warning: No data found in {file_path}")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats
    for sample in tqdm(data, desc=f"Processing {os.path.basename(file_path)} samples", leave=False):
        # NOTE: assumed field name for the raw model output; the original dump referenced an
        # undefined `prediction` variable here. Adjust the key if the result files store the
        # model response under a different name (e.g. "response").
        prediction = sample.get("prediction", "")
        if mode == 'zero-shot':
            predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
            if predict is None:
                predict = extract_option_content(prediction, sample["options"])
                predict = chr(sample["options"].index(predict) + 65) if predict else None
            sample["extracted_answer"] = predict
        elif mode == 'five-shot':
            response = prediction.split('Question:')[0]
            predict = extract_option_labels(response, 'ABCDEFGHIJ')
            if predict is None:
                predict = extract_option_content(response, sample["options"])
                predict = chr(sample["options"].index(predict) + 65) if predict else None
            if predict is None:
                predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(prediction, sample["options"])
                    predict = chr(sample["options"].index(predict) + 65) if predict else None
            sample["extracted_answer"] = predict
        discipline = sample.get("discipline", "unknown")
        field = sample.get("field", "unknown")
        subfield = sample.get("subfield", "unknown")
        difficulty = sample.get("difficulty", "unknown")

        for level, key in [
            ('discipline', discipline),
            ('field', f"{discipline}/{field}"),
            ('subfield', f"{discipline}/{field}/{subfield}")
        ]:
            if key not in stats[level]:
                stats[level][key] = {
                    "correct": 0,
                    "total": 0,
                    "miss": 0,
                    "error": 0,
                    "discipline": discipline,
                    "field": field,
                    "subfield": subfield,
                    "difficulty": {
                        "easy": {"correct": 0, "total": 0},
                        "middle": {"correct": 0, "total": 0},
                        "hard": {"correct": 0, "total": 0}
                    }
                }
            stats[level][key]["total"] += 1
            stats[level][key]["difficulty"][difficulty]["total"] += 1

            # NOTE: the global counters below are updated once per hierarchy level
            # (discipline/field/subfield), as is `count`, so the ratios computed at the
            # end of the function are unaffected by the uniform 3x scaling.
            answer_letter = sample["answer_letter"]
            if predict and answer_letter == predict:
                acc += 1
                acc_difficulty[difficulty] += 1
                sample["status"] = "correct"
                stats[level][key]["correct"] += 1
                stats[level][key]["difficulty"][difficulty]["correct"] += 1
            elif predict is None or predict == "":
                miss += 1
                sample["status"] = "miss"
                stats[level][key]["miss"] += 1
            elif predict == 'error':
                err += 1
                sample["status"] = "error"
                stats[level][key]["error"] += 1
            else:
                sample["status"] = "incorrect"
            count += 1
            count_difficulty[difficulty] += 1

    if count == 0:
        print(f"Warning: No valid samples found in {file_path}")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, stats

    accuracy = acc / count
    error_rate = err / count
    miss_rate = miss / count
    hard_accuracy = acc_difficulty["hard"] / count_difficulty["hard"]
    middle_accuracy = acc_difficulty["middle"] / count_difficulty["middle"]
    easy_accuracy = acc_difficulty["easy"] / count_difficulty["easy"]

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, os.path.basename(file_path))
    with open(save_path, "w") as file:
        for sample in data:
            json.dump(sample, file)
            file.write("\n")

    return accuracy, error_rate, miss_rate, hard_accuracy, middle_accuracy, easy_accuracy, stats
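# Input/output sketch for calculate_accuracy (added for clarity; field names are taken from
# the accesses above, the concrete values are invented). Each line of `file_path` is expected
# to be a JSON object along the lines of:
#   {"prediction": "...raw model output...", "options": ["...", "..."], "answer_letter": "B",
#    "discipline": "Science", "field": "Physics", "subfield": "Optics", "difficulty": "hard"}
# The annotated samples (with "extracted_answer" and "status") are written to
# save_dir/<basename of file_path>, and the function returns
#   (accuracy, error_rate, miss_rate, hard_accuracy, middle_accuracy, easy_accuracy, stats).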
def calculate_total_row(hierarchy_stats, model_results, metric_name):
    """Build the overall summary rows: a sample-wise average plus unweighted averages
    over subfields, fields, and disciplines."""
    total_rows = []

    # Count the categories in each dimension
    total_samples = 0
    if metric_name in ['Hard', 'Middle', 'Easy']:
        total_subfields = sum(subfield['difficulty'][metric_name.lower()]['total'] > 0 for subfield in hierarchy_stats['subfield'].values())
        total_fields = sum(field['difficulty'][metric_name.lower()]['total'] > 0 for field in hierarchy_stats['field'].values())
        total_disciplines = sum(discipline['difficulty'][metric_name.lower()]['total'] > 0 for discipline in hierarchy_stats['discipline'].values())
    else:
        total_subfields = len(hierarchy_stats['subfield'])
        total_fields = len(hierarchy_stats['field'])
        total_disciplines = len(hierarchy_stats['discipline'])

    # Count the total number of samples
    for discipline_stats in hierarchy_stats['discipline'].values():
        if metric_name in ['Hard', 'Middle', 'Easy']:
            total_samples += discipline_stats['difficulty'][metric_name.lower()]['total']
        else:
            total_samples += discipline_stats['total']

    if metric_name == 'Accuracy':
        row_types = [
            (f'Overall (sample-wise) (Total samples: {total_samples})', 'sample'),
            (f'Overall (subfield-wise) (Total subfields: {total_subfields})', 'subfield'),
            (f'Overall (field-wise) (Total fields: {total_fields})', 'field'),
            (f'Overall (discipline-wise) (Total disciplines: {total_disciplines})', 'discipline')
        ]
    elif metric_name in ['Hard', 'Middle', 'Easy']:
        row_types = [
            (f'Overall (sample-wise) (Total {metric_name.lower()} samples: {total_samples})', 'sample'),
            (f'Overall (subfield-wise) (Total {metric_name.lower()} subfields: {total_subfields})', 'subfield'),
            (f'Overall (field-wise) (Total {metric_name.lower()} fields: {total_fields})', 'field'),
            (f'Overall (discipline-wise) (Total {metric_name.lower()} disciplines: {total_disciplines})', 'discipline')
        ]
    else:  # Error Rate and Miss Rate
        row_types = [(f'Overall (Total samples: {total_samples})', 'sample')]
    for row_name, stat_type in row_types:
        total_row = {
            'Discipline': row_name,
            'Field': '',
            'Subfield': ''
        }

        for model_name in model_results.keys():
            for mode in model_results[model_name].keys():
                if stat_type == 'sample':
                    # sample-wise statistics (weighted by sample count)
                    stats = {'total': 0, 'correct': 0, 'error': 0, 'miss': 0}
                    for discipline_stats in hierarchy_stats['discipline'].values():
                        if 'model_stats' in discipline_stats and model_name in discipline_stats['model_stats']:
                            curr_stats = discipline_stats['model_stats'][model_name].get(mode, {})
                            if metric_name in ['Hard', 'Middle', 'Easy']:
                                difficulty_stats = curr_stats.get('difficulty', {}).get(metric_name.lower(), {})
                                stats['total'] += difficulty_stats.get('total', 0)
                                stats['correct'] += difficulty_stats.get('correct', 0)
                            else:
                                for key in ['total', 'correct', 'error', 'miss']:
                                    stats[key] += curr_stats.get(key, 0)

                    if stats['total'] > 0:
                        if metric_name in ['Hard', 'Middle', 'Easy'] or metric_name == 'Accuracy':
                            value = stats['correct'] / stats['total']
                        elif metric_name == 'Error Rate':
                            value = stats['error'] / stats['total']
                        else:  # Miss Rate
                            value = stats['miss'] / stats['total']
                    else:
                        value = 0
                else:
                    # Other dimension statistics (direct average of correct rates across categories)
                    scores = []
                    if stat_type == 'discipline':
                        categories = hierarchy_stats['discipline']
                    elif stat_type == 'field':
                        categories = hierarchy_stats['field']
                    else:  # subfield
                        categories = hierarchy_stats['subfield']

                    for cat_stats in categories.values():
                        if 'model_stats' in cat_stats and model_name in cat_stats['model_stats']:
                            curr_stats = cat_stats['model_stats'][model_name].get(mode, {})
                            if metric_name in ['Hard', 'Middle', 'Easy']:
                                difficulty_stats = curr_stats.get('difficulty', {}).get(metric_name.lower(), {})
                                if difficulty_stats.get('total', 0) > 0:
                                    score = difficulty_stats['correct'] / difficulty_stats['total']
                                    scores.append(score)
                            else:
                                if curr_stats.get('total', 0) > 0:
                                    if metric_name == 'Accuracy':
                                        score = curr_stats['correct'] / curr_stats['total']
                                        scores.append(score)

                    value = sum(scores) / len(scores) if scores else 0

                total_row[f'{model_name}_{mode}'] = f"{value:.2%}"

        total_rows.append(total_row)
    return total_rows
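# Worked example of the two aggregation modes in calculate_total_row (illustrative numbers
# only): with two disciplines scoring 90/100 and 10/100, the sample-wise row reports
# (90 + 10) / (100 + 100) = 50.00% and the discipline-wise row the unweighted mean
# (0.90 + 0.10) / 2 = 50.00%; with unequal sizes, e.g. 90/100 and 1/10, they diverge:
# 91/110 ≈ 82.73% sample-wise versus (0.90 + 0.10) / 2 = 50.00% discipline-wise.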
def create_excel_report_from_stats(model_results, hierarchy_stats, save_path):
    print("Starting Excel report generation...")

    # Create six different DataFrames for storing different metrics and difficulties
    metrics = {
        'Accuracy': {'rows': [], 'color': '000000'},    # black
        'Error Rate': {'rows': [], 'color': '000000'},  # black
        'Miss Rate': {'rows': [], 'color': '000000'},   # black
        'Hard': {'rows': [], 'color': '000000'},        # black
        'Middle': {'rows': [], 'color': '000000'},      # black
        'Easy': {'rows': [], 'color': '000000'}         # black
    }
    # Organize data by hierarchy
    for discipline in tqdm(sorted(hierarchy_stats['discipline'].keys()), desc="Processing discipline level"):
        discipline_stats = hierarchy_stats['discipline'][discipline]
        discipline_total = discipline_stats['total']

        # Get all fields under this discipline
        categories = [k for k in hierarchy_stats['field'].keys()
                      if k.startswith(f"{discipline}/")]

        for field_key in sorted(categories):
            field_stats = hierarchy_stats['field'][field_key]
            field = field_stats['field']
            field_total = field_stats['total']

            # Get all subfields under this field
            subcategories = [k for k in hierarchy_stats['subfield'].keys()
                             if k.startswith(f"{discipline}/{field}/")]

            # Add subfield rows
            for subfield_key in sorted(subcategories):
                subfield_stats = hierarchy_stats['subfield'][subfield_key]

                # Create base row data for each metric
                for metric_name in metrics:
                    if metric_name in ['Hard', 'Middle', 'Easy']:
                        base_row = {
                            'Discipline': discipline,
                            'Field': field,
                            'Subfield': f"{subfield_stats['subfield']} ({subfield_stats['difficulty'][metric_name.lower()]['total']})"
                        }
                    else:
                        base_row = {
                            'Discipline': discipline,
                            'Field': field,
                            'Subfield': f"{subfield_stats['subfield']} ({subfield_stats['total']})"
                        }

                    row_data = base_row.copy()

                    # Add score for each model
                    for model_name in model_results.keys():
                        for mode in model_results[model_name].keys():
                            stats = subfield_stats['model_stats'].get(model_name, {}).get(mode, {})
                            if metric_name in ['Hard', 'Middle', 'Easy']:
                                difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
                                if difficulty_stats.get('total', 0) > 0:
                                    value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
                                else:
                                    value = '0.00%'
                            else:
                                if stats.get('total', 0) > 0:
                                    if metric_name == 'Accuracy':
                                        value = f"{stats['correct'] / stats['total']:.2%}"
                                    elif metric_name == 'Error Rate':
                                        value = f"{stats['error'] / stats['total']:.2%}"
                                    else:  # Miss Rate
                                        value = f"{stats['miss'] / stats['total']:.2%}"
                                else:
                                    value = '0.00%'
                            row_data[f'{model_name}_{mode}'] = value

                    metrics[metric_name]['rows'].append(row_data)
            # Add field summary row
            for metric_name in metrics:
                if metric_name in ['Hard', 'Middle', 'Easy']:
                    field_row = {
                        'Discipline': discipline,
                        'Field': f"{field} (Total: {field_stats['difficulty'][metric_name.lower()]['total']})",
                        'Subfield': ''
                    }
                else:
                    field_row = {
                        'Discipline': discipline,
                        'Field': f"{field} (Total: {field_total})",
                        'Subfield': ''
                    }

                for model_name in model_results.keys():
                    for mode in model_results[model_name].keys():
                        stats = field_stats['model_stats'].get(model_name, {}).get(mode, {})
                        if metric_name in ['Hard', 'Middle', 'Easy']:
                            difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
                            if difficulty_stats.get('total', 0) > 0:
                                value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
                            else:
                                value = '0.00%'
                        else:
                            if stats.get('total', 0) > 0:
                                if metric_name == 'Accuracy':
                                    value = f"{stats['correct'] / stats['total']:.2%}"
                                elif metric_name == 'Error Rate':
                                    value = f"{stats['error'] / stats['total']:.2%}"
                                else:  # Miss Rate
                                    value = f"{stats['miss'] / stats['total']:.2%}"
                            else:
                                value = '0.00%'
                        field_row[f'{model_name}_{mode}'] = value

                metrics[metric_name]['rows'].append(field_row)
        # Add discipline summary row
        for metric_name in metrics:
            if metric_name in ['Hard', 'Middle', 'Easy']:
                discipline_row = {
                    'Discipline': f"{discipline} (Total: {discipline_stats['difficulty'][metric_name.lower()]['total']})",
                    'Field': '',
                    'Subfield': ''
                }
            else:
                discipline_row = {
                    'Discipline': f"{discipline} (Total: {discipline_total})",
                    'Field': '',
                    'Subfield': ''
                }

            for model_name in model_results.keys():
                for mode in model_results[model_name].keys():
                    stats = discipline_stats['model_stats'].get(model_name, {}).get(mode, {})
                    if metric_name in ['Hard', 'Middle', 'Easy']:
                        difficulty_stats = stats.get('difficulty', {}).get(metric_name.lower(), {})
                        if difficulty_stats.get('total', 0) > 0:
                            value = f"{difficulty_stats['correct'] / difficulty_stats['total']:.2%}"
                        else:
                            value = '0.00%'
                    else:
                        if stats.get('total', 0) > 0:
                            if metric_name == 'Accuracy':
                                value = f"{stats['correct'] / stats['total']:.2%}"
                            elif metric_name == 'Error Rate':
                                value = f"{stats['error'] / stats['total']:.2%}"
                            else:  # Miss Rate
                                value = f"{stats['miss'] / stats['total']:.2%}"
                        else:
                            value = '0.00%'
                    discipline_row[f'{model_name}_{mode}'] = value

            metrics[metric_name]['rows'].append(discipline_row)
    # Create DataFrames
    dfs = {metric: pd.DataFrame(data['rows']) for metric, data in metrics.items()}

    # Add overall summary rows to each DataFrame
    for metric_name, df in dfs.items():
        total_rows = calculate_total_row(hierarchy_stats, model_results, metric_name)
        dfs[metric_name] = pd.concat([df, pd.DataFrame(total_rows)], ignore_index=True)

    # Save to Excel, one sheet per metric
    with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
        for metric_name, df in dfs.items():
            df.to_excel(writer, sheet_name=metric_name, index=False)
            format_worksheet(writer.sheets[metric_name], df, metrics[metric_name]['color'])
    print(f"Report generation completed, Excel file saved: {save_path}")
def format_worksheet(worksheet, df, color):
    """Apply fonts, fills, column widths, merges, and alignment to a metric sheet."""
    # Set default font (black Arial throughout)
    for row in worksheet.rows:
        for cell in row:
            cell.font = Font(name='Arial', color='000000')

    # Background fills for discipline/field summary rows
    discipline_fill = PatternFill(start_color='FFFF00', end_color='FFFF00', fill_type='solid')
    field_fill = PatternFill(start_color='FFFFD4', end_color='FFFFD4', fill_type='solid')

    # Background fills for the Overall rows
    sample_wise_fill = PatternFill(start_color='B8CCE4', end_color='B8CCE4', fill_type='solid')      # medium blue
    subfield_wise_fill = PatternFill(start_color='DCE6F1', end_color='DCE6F1', fill_type='solid')    # light blue
    field_wise_fill = PatternFill(start_color='E9EEF5', end_color='E9EEF5', fill_type='solid')       # lighter blue
    discipline_wise_fill = PatternFill(start_color='F2F5F9', end_color='F2F5F9', fill_type='solid')  # lightest blue
    error_rate_fill = PatternFill(start_color='FFB6C1', end_color='FFB6C1', fill_type='solid')       # light red
    miss_rate_fill = PatternFill(start_color='D3D3D3', end_color='D3D3D3', fill_type='solid')        # gray

    # Auto-fit column widths
    for column in worksheet.columns:
        max_length = 0
        column = list(column)
        for cell in column:
            try:
                if len(str(cell.value)) > max_length:
                    max_length = len(str(cell.value))
            except:
                pass
        adjusted_width = (max_length + 2)
        worksheet.column_dimensions[column[0].column_letter].width = adjusted_width
    # Merge cells and apply background colors for summary rows
    current_discipline = None
    discipline_start = None
    current_field = None
    field_start = None

    for row_idx, row in enumerate(worksheet.iter_rows(min_row=2), start=2):
        discipline = row[0].value
        field = row[1].value

        # Handle Discipline-column merging
        if discipline and "Total:" in str(discipline):
            # Close out any open discipline block
            if discipline_start and current_discipline:
                worksheet.merge_cells(f'A{discipline_start}:A{row_idx-1}')
            # Highlight the discipline summary row
            for cell in row:
                cell.fill = discipline_fill
            # Reset tracking variables
            current_discipline = None
            discipline_start = None
        elif discipline and discipline != current_discipline:
            # Close out any open discipline block
            if discipline_start and current_discipline:
                worksheet.merge_cells(f'A{discipline_start}:A{row_idx-1}')
            current_discipline = discipline
            discipline_start = row_idx

        # Handle Field-column merging
        if field and "Total:" in str(field):
            # Close out any open field block
            if field_start and current_field:
                worksheet.merge_cells(f'B{field_start}:B{row_idx-1}')
            # Highlight the field summary row
            for cell in row:
                cell.fill = field_fill
            # Reset tracking variables
            current_field = None
            field_start = None
        elif field and field != current_field:
            # Close out any open field block
            if field_start and current_field:
                worksheet.merge_cells(f'B{field_start}:B{row_idx-1}')
            current_field = field
            field_start = row_idx

    # Merge any blocks still open at the last row
    last_row = worksheet.max_row
    if discipline_start and current_discipline:
        worksheet.merge_cells(f'A{discipline_start}:A{last_row}')
    if field_start and current_field:
        worksheet.merge_cells(f'B{field_start}:B{last_row}')
    # Apply special background color to Overall rows
    for row_idx, row in enumerate(worksheet.iter_rows(min_row=2), start=2):
        cell_value = row[0].value
        if cell_value:
            if 'Overall (sample-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = sample_wise_fill
            elif 'Overall (subfield-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = subfield_wise_fill
            elif 'Overall (field-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = field_wise_fill
            elif 'Overall (discipline-wise)' in str(cell_value):
                for cell in row:
                    cell.fill = discipline_wise_fill
            elif worksheet.title == 'Error Rate' and 'Overall' in str(cell_value):
                for cell in row:
                    cell.fill = error_rate_fill
            elif worksheet.title == 'Miss Rate' and 'Overall' in str(cell_value):
                for cell in row:
                    cell.fill = miss_rate_fill

    # Set value format to keep two decimal places
    for row in worksheet.iter_rows(min_row=2):
        for cell in row[3:]:  # Start from 4th column (skip Discipline, Field, Subfield columns)
            if isinstance(cell.value, str) and '%' in cell.value:
                try:
                    value = float(cell.value.strip('%')) / 100
                    cell.value = f"{value:.2%}"
                except ValueError:
                    pass

    # Set all cells to center alignment
    for row in worksheet.rows:
        for cell in row:
            cell.alignment = Alignment(horizontal='center', vertical='center')
def format_cell_value(stats):
    """Format cell content, return string with acc/error/miss"""
    total = stats['total']
    if total == 0:
        return '0%/0%/0%'
    acc = stats['correct'] / total
    error = stats['error'] / total
    miss = stats['miss'] / total
    return f"{acc:.1%}/{error:.1%}/{miss:.1%}"