This commit is contained in:
MaiziXiao 2025-03-11 09:32:35 +00:00
parent 7938f352d7
commit 7f31ef7357
8 changed files with 205 additions and 196 deletions

View File

@ -57,7 +57,7 @@
## 🚀 最新进展 <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2025.03.11\]** 现已支持 `SuperGPQA` LLM知识能力评测欢迎尝试🔥🔥🔥
- **\[2025.03.11\]** 现已支持 `SuperGPQA` LLM知识能力评测欢迎尝试🔥🔥🔥
- **\[2025.02.28\]** 我们为 `DeepSeek-R1` 系列模型添加了教程,请查看 [评估推理模型](docs/en/user_guides/deepseek_r1.md) 了解更多详情!🔥🔥🔥
- **\[2025.02.15\]** 我们新增了两个实用的评测工具用于LLM作为评判器的`GenericLLMEvaluator`和用于数学推理评估的`MATHEvaluator`。查看[LLM评判器](docs/zh_cn/advanced_guides/llm_judge.md)和[数学能力评测](docs/zh_cn/advanced_guides/general_math.md)文档了解更多详情!🔥🔥🔥
- **\[2025.01.16\]** 我们现已支持 [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct) 模型,该模型在推理、知识类任务上取得同量级最优性能,欢迎尝试。

View File

@ -11,13 +11,13 @@ from opencompass.openicl.icl_retriever import ZeroRetriever
reader_cfg = dict(
input_columns=[
'question',
"options",
'options',
'discipline',
'field',
'subfield',
'difficulty',
"infer_prompt",
"prompt_mode",
'infer_prompt',
'prompt_mode',
],
output_column='answer_letter',
)
@ -47,7 +47,7 @@ eval_cfg = dict(
supergpqa_dataset = dict(
type=SuperGPQADataset,
abbr='supergpqa',
path="m-a-p/SuperGPQA",
path='m-a-p/SuperGPQA',
prompt_mode='zero-shot',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,

View File

@ -127,6 +127,7 @@ from .strategyqa import * # noqa: F401, F403
from .subjective import * # noqa: F401, F403
from .summedits import * # noqa: F401, F403
from .summscreen import * # noqa: F401, F403
from .supergpqa import * # noqa: F401, F403
from .svamp import * # noqa: F401, F403
from .tabmwp import * # noqa: F401, F403
from .taco import * # noqa: F401, F403
@ -147,4 +148,3 @@ from .xcopa import * # noqa: F401, F403
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403
from .xlsum import * # noqa: F401, F403
from .xsum import * # noqa: F401, F403
from .supergpqa import *

View File

@ -1,38 +1,23 @@
import csv
import json
import os.path as osp
from os import environ
from datasets import load_dataset
import os
from datasets import Dataset, DatasetDict
from opencompass.datasets.supergpqa.supergpqa_utils import (
evaluate_responses,
find_file,
load_json_or_jsonl,
load_json_or_jsonl_with_idx,
load_yaml,
)
from datasets import Dataset, load_dataset
from opencompass.datasets.supergpqa.supergpqa_eval import (
extract_option_content, extract_option_labels)
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
import unittest
from opencompass.utils import get_data_path
from opencompass.datasets.supergpqa.supergpqa_eval import (
extract_option_labels,
extract_option_content,
)
from ..base import BaseDataset
def _parse(item, template, prompt_mode):
prompt_format = [
item['question']
+ '\n'
+ '\n'.join(
[
item['question'] + '\n' + '\n'.join([
f'{chr(65+i)}) {option}'
for i, option in enumerate(item['options'])
]
)
])
]
item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
item['prompt_mode'] = prompt_mode
@ -41,6 +26,7 @@ def _parse(item, template, prompt_mode):
@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):
@staticmethod
def load(path: str, prompt_mode: str, **kwargs):
path = get_data_path(path, local_mode=True)
@ -80,54 +66,42 @@ class SuperGPQAEvaluator(BaseEvaluator):
count = 0
err = 0
miss = 0
acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
count_difficulty = {"hard": 0, "middle": 0, "easy": 0}
acc_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
count_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
stats = {'discipline': {}, 'field': {}, 'subfield': {}}
details = []
for i, sample in enumerate(test_set):
sample["pred"] = prediction = predictions[i]
sample['pred'] = prediction = predictions[i]
gold = references[i]
if mode == 'zero-shot':
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(
prediction, sample["options"]
)
predict = (
chr(sample["options"].index(predict) + 65)
if predict
else None
)
sample["extracted_answer"] = predict
if predict is None:
predict = extract_option_content(prediction,
sample['options'])
predict = (chr(sample['options'].index(predict) +
65) if predict else None)
sample['extracted_answer'] = predict
elif mode == 'five-shot':
response = prediction.split('Question:')[0]
predict = extract_option_labels(response, 'ABCDEFGHIJ')
if predict == None:
predict = extract_option_content(
response, sample["options"]
)
predict = (
chr(sample["options"].index(predict) + 65)
if predict
else None
)
if predict == None:
if predict is None:
predict = extract_option_content(response,
sample['options'])
predict = (chr(sample['options'].index(predict) +
65) if predict else None)
if predict is None:
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict == None:
if predict is None:
predict = extract_option_content(
prediction, sample["options"]
)
predict = (
chr(sample["options"].index(predict) + 65)
if predict
else None
)
sample["extracted_answer"] = predict
prediction, sample['options'])
predict = (chr(sample['options'].index(predict) +
65) if predict else None)
sample['extracted_answer'] = predict
discipline = sample.get("discipline", "unknown")
field = sample.get("field", "unknown")
subfield = sample.get("subfield", "unknown")
difficulty = sample.get("difficulty", "unknown")
discipline = sample.get('discipline', 'unknown')
field = sample.get('field', 'unknown')
subfield = sample.get('subfield', 'unknown')
difficulty = sample.get('difficulty', 'unknown')
for level, key in [
('discipline', discipline),
@ -136,70 +110,75 @@ class SuperGPQAEvaluator(BaseEvaluator):
]:
if key not in stats[level]:
stats[level][key] = {
"correct": 0,
"total": 0,
"miss": 0,
"error": 0,
"discipline": discipline,
"field": field,
"subfield": subfield,
"difficulty": {
"easy": {"correct": 0, "total": 0},
"middle": {"correct": 0, "total": 0},
"hard": {"correct": 0, "total": 0},
'correct': 0,
'total': 0,
'miss': 0,
'error': 0,
'discipline': discipline,
'field': field,
'subfield': subfield,
'difficulty': {
'easy': {
'correct': 0,
'total': 0
},
'middle': {
'correct': 0,
'total': 0
},
'hard': {
'correct': 0,
'total': 0
},
},
}
stats[level][key]["total"] += 1
stats[level][key]["difficulty"][difficulty]["total"] += 1
stats[level][key]['total'] += 1
stats[level][key]['difficulty'][difficulty]['total'] += 1
answer_letter = sample["answer_letter"]
answer_letter = sample['answer_letter']
assert answer_letter == gold
if predict and answer_letter == predict:
acc += 1
acc_difficulty[difficulty] += 1
sample["status"] = "correct"
stats[level][key]["correct"] += 1
stats[level][key]["difficulty"][difficulty]["correct"] += 1
elif predict == None or predict == "":
sample['status'] = 'correct'
stats[level][key]['correct'] += 1
stats[level][key]['difficulty'][difficulty]['correct'] += 1
elif predict == None or predict == '':
miss += 1
sample["status"] = "miss"
stats[level][key]["miss"] += 1
sample['status'] = 'miss'
stats[level][key]['miss'] += 1
elif predict == 'error':
err += 1
sample["status"] = "error"
stats[level][key]["error"] += 1
sample['status'] = 'error'
stats[level][key]['error'] += 1
else:
sample["status"] = "incorrect"
sample['status'] = 'incorrect'
count += 1
count_difficulty[difficulty] += 1
details.append(
{
details.append({
'pred': sample['pred'],
'answer': sample['answer'],
'parsed_answer': sample['extracted_answer'],
'correct': True if sample['status'] else False,
}
)
})
return {
'accuracy': acc / count if count > 0 else 0,
'error_rate': err / count if count > 0 else 0,
'miss_rate': miss / count if count > 0 else 0,
'hard_accuracy': (
acc_difficulty["hard"] / count_difficulty["hard"]
if count_difficulty["hard"] > 0
else 0
),
'middle_accuracy': (
acc_difficulty["middle"] / count_difficulty["middle"]
if count_difficulty["middle"] > 0
else 0
),
'easy_accuracy': (
acc_difficulty["easy"] / count_difficulty["easy"]
if count_difficulty["easy"] > 0
else 0
),
'details': details,
'accuracy':
acc / count if count > 0 else 0,
'error_rate':
err / count if count > 0 else 0,
'miss_rate':
miss / count if count > 0 else 0,
'hard_accuracy':
(acc_difficulty['hard'] /
count_difficulty['hard'] if count_difficulty['hard'] > 0 else 0),
'middle_accuracy':
(acc_difficulty['middle'] / count_difficulty['middle']
if count_difficulty['middle'] > 0 else 0),
'easy_accuracy':
(acc_difficulty['easy'] /
count_difficulty['easy'] if count_difficulty['easy'] > 0 else 0),
'details':
details,
}

View File

@ -1,7 +1,8 @@
import yaml
import uuid
class ConfigWrapper:
def __init__(self, config_path):
self._config = {}
with open(config_path, 'r') as file:
@ -19,33 +20,69 @@ class ConfigWrapper:
def __getattr__(self, key):
if key in self._config:
return self._config[key]
raise AttributeError(f"'ConfigWrapper' object has no attribute '{key}'")
raise AttributeError(
f"'ConfigWrapper' object has no attribute '{key}'")
def get_id(self, data):
if isinstance(self._config.get('id_key'), str):
return data.get(self._config.get('id_key'), None)
elif isinstance(self._config.get('id_key'), list):
return '_'.join([str(data[key]) for key in self._config.get('id_key') if key in data])
return '_'.join([
str(data[key]) for key in self._config.get('id_key')
if key in data
])
def print_all_keys(self):
print("config keys:")
print('config keys:')
for key, value in self._config.items():
print(f" - {key}: {value}")
print(f' - {key}: {value}')
config_wrapper = None
def initialize_config(config_path):
global config_wrapper
config_wrapper = ConfigWrapper(config_path)
def get_config_wrapper():
global config_wrapper
if config_wrapper is None:
raise RuntimeError("ConfigWrapper not initialized. Call initialize_config first.")
raise RuntimeError(
'ConfigWrapper not initialized. Call initialize_config first.')
return config_wrapper
if __name__ == '__main__':
config_path = 'config/config.yaml'
initialize_config(config_path)
data = {'idx': '50', 'step':21, 'question': 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\nPlease provide the decrypted answer, encapsulated in double square brackets. For example, the format should be: [[decrypted answer]].', 'answer': '[[P]]', 'category': 'Decryption', 'rule_id': '23', 'input': 'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"', 'steps_num': 23, 'description': 'For a number c=228 in the ciphertext:\nCalculate z = c^e mod n. Here ^ means multiplication.\nz is 80.\nBased on the decimal number represented by z, use the ascii code to find the corresponding letter as the plaintext letter p.\nPlease give the letter p in [[...]] format.\n', 'atom': 80}
data = {
'idx':
'50',
'step':
21,
'question':
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
'Please provide the decrypted answer, encapsulated in double square'
' brackets. For example, the format should be: [[decrypted answer]].',
'answer':
'[[P]]',
'category':
'Decryption',
'rule_id':
'23',
'input':
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
'steps_num':
23,
'description':
'For a number c=228 in the ciphertext:\n'
'Calculate z = c^e mod n. Here ^ means multiplication.\nz is 80.'
'\nBased on the decimal number represented by z, use the ascii '
'code to find the corresponding letter as the plaintext letter p.'
'\nPlease give the letter p in [[...]] format.\n',
'atom':
80,
}
print(config_wrapper.get_id(data))

View File

@ -1,27 +1,21 @@
import json
# flake8: noqa: W605
import re
import argparse
import os
from prettytable import PrettyTable
import pandas as pd
from tqdm import tqdm
import timeout_decorator
import multiprocessing
import time
from functools import partial
@timeout_decorator.timeout(5) # 5 seconds timeout
def safe_regex_search(pattern, text, flags=0):
try:
return re.search(pattern, text, flags)
except timeout_decorator.TimeoutError:
print(f"Regex match timeout: pattern={pattern}, text={text[:100]}...")
print(f'Regex match timeout: pattern={pattern}, text={text[:100]}...')
return None
except Exception as e:
print(f"Regex match error: {str(e)}")
print(f'Regex match error: {str(e)}')
return None
def extract_option_labels(text, options='ABCDEFGHIJ'):
if not isinstance(text, str) or not isinstance(options, str):
return 'error'
@ -29,7 +23,8 @@ def extract_option_labels(text, options='ABCDEFGHIJ'):
text = text.rstrip()
last_line = text.split('\n')[-1]
option_str = ''.join([chr(65 + i) for i in range(len(options))]) if options else 'ABCDEFGHIJ'
option_str = ''.join([chr(65 + i) for i in range(len(options))
]) if options else 'ABCDEFGHIJ'
patterns = [
# e.g. "The final answer to this question is: A."
@ -61,11 +56,14 @@ def extract_option_labels(text, options='ABCDEFGHIJ'):
return None
def extract_option_content(text, options_content=None):
if not isinstance(text, str) or not isinstance(options_content, list):
return 'error'
escaped_options_content = [re.escape(option_content) for option_content in options_content]
escaped_options_content = [
re.escape(option_content) for option_content in options_content
]
escaped_options_content_str = '|'.join(escaped_options_content)
text = text.rstrip()
@ -73,9 +71,7 @@ def extract_option_content(text, options_content=None):
patterns = [
f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is:?\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
f'(?i:Answer)\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
f'^[^\w\r\n]*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
]
@ -83,7 +79,8 @@ def extract_option_content(text, options_content=None):
match = safe_regex_search(pattern, last_line)
if match:
if match.group(1) in escaped_options_content:
return options_content[escaped_options_content.index(match.group(1))]
return options_content[escaped_options_content.index(
match.group(1))]
else:
return match.group(1)
@ -91,7 +88,8 @@ def extract_option_content(text, options_content=None):
match = safe_regex_search(pattern, text)
if match:
if match.group(1) in escaped_options_content:
return options_content[escaped_options_content.index(match.group(1))]
return options_content[escaped_options_content.index(
match.group(1))]
else:
return match.group(1)

View File

@ -6,6 +6,7 @@ import sympy as sp
import yaml
from sympy.parsing.latex import parse_latex
def load_yaml(yaml_path):
"""Load a YAML file."""
if not os.path.exists(yaml_path):
@ -670,8 +671,7 @@ def evaluate_responses(data, mode, base_path=None):
answer = record.get('gold', '')
rule_id = record.get('rule_id', '')
is_correct = evaluate_response_vs_answer(response, answer,
question_type, rule_id,
idx)
question_type, rule_id, idx)
result_dict = {
'idx': idx,
'response': response,
@ -681,8 +681,10 @@ def evaluate_responses(data, mode, base_path=None):
}
if question_type == 'counterfactual':
real_life_answer = record.get('real_life_answer', '')
is_real_life = evaluate_response_vs_answer(
response, real_life_answer, question_type, rule_id, idx)
is_real_life = evaluate_response_vs_answer(response,
real_life_answer,
question_type, rule_id,
idx)
result_dict['real_life_answer'] = real_life_answer
result_dict['is_real_life'] = is_real_life
if question_type == 'cipher' and mode == 'subquestions':

View File

@ -47,9 +47,8 @@ class BaseEvaluator:
# please see opencompass/opencompass/tasks/openicl_eval.py Line 197-200
return self._out_dir
def group(
self, n: int, details: List[Dict[str, Any]], test_set: Dataset
) -> Dict[str, Any]:
def group(self, n: int, details: List[Dict[str, Any]],
test_set: Dataset) -> Dict[str, Any]:
example2replications = {}
for detail, example in zip(details, test_set):
example_abbr = f"{example['subdivision']}_{example['idx']}"
@ -64,28 +63,23 @@ class BaseEvaluator:
def reduce(self, details: List[Dict[str, Any]]) -> Dict[str, Any]:
g_passk_details = OrderedDict()
all_subdivisions = set(
[detail['example_abbr'].split('_')[0] for detail in details]
)
[detail['example_abbr'].split('_')[0] for detail in details])
all_metrics = list(details[0].keys())
for subdivision in sorted(list(all_subdivisions)):
for metric in all_metrics:
if metric in ['predictions', 'example_abbr']:
continue
g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean(
[
detail[metric]
for detail in details
g_passk_details[f'{subdivision}/{metric}'] = 100 * np.mean([
detail[metric] for detail in details
if detail['example_abbr'].split('_')[0] == subdivision
]
)
])
for metric in all_metrics:
if metric in ['predictions', 'example_abbr']:
continue
g_passk_details[metric] = 100.0 * np.mean(
[detail[metric] for detail in details]
)
[detail[metric] for detail in details])
return g_passk_details
def evaluate(
@ -112,8 +106,7 @@ class BaseEvaluator:
**{
key: select_fn(i, real_size, value)
for key, value in score_kwargs.items()
}
)
})
details = results.pop('details', None)
if details is not None:
if isinstance(details, Dict):
@ -129,12 +122,10 @@ class BaseEvaluator:
eval_results[key].append(single_results[key])
for key in deepcopy(eval_results):
if isinstance(eval_results[key][0], float) or isinstance(
eval_results[key][0], int
):
eval_results[key][0], int):
if n > 1:
eval_results[key + f' ({n} runs average)'] = np.mean(
eval_results[key]
)
eval_results[key])
eval_results.pop(key)
else:
eval_results[key] = np.mean(eval_results[key])
@ -163,13 +154,14 @@ class BaseEvaluator:
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
for _k in [k] if isinstance(k, int) else k:
for threshold in thresholds:
g_pass = compute_g_pass_at_k(
n=n, c=c, k=_k, t=threshold
)
g_pass = compute_g_pass_at_k(n=n,
c=c,
k=_k,
t=threshold)
detail[f'G-Pass@{_k}_{threshold}'] = g_pass
detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(
n=n, c=c, k=_k
)
detail[f'mG-Pass@{_k}'] = compute_mg_pass_at_k(n=n,
c=c,
k=_k)
eval_details.append(detail)
@ -196,7 +188,8 @@ class BaseEvaluator:
# Remove the original predictions field
detail.pop('predictions')
import ipdb; ipdb.set_trace()
import ipdb
ipdb.set_trace()
return eval_results
# If there are no details, return an empty dictionary