[Feature] Support SuperGPQA (#1924)

* support supergpqa

* remove unnecessary code

* remove unnecessary code

* Add Readme

* Add Readme

* fix lint

* fix lint

* update

* update

---------

Co-authored-by: mkj3085003 <mkj3085003@gmail.com>
Co-authored-by: MaiziXiao <xxllcc1993@gmail.com>
Kangreen 2025-03-11 19:32:08 +08:00 committed by GitHub
parent e403fd21be
commit 59e49aedf1
17 changed files with 1317 additions and 8 deletions

View File

@@ -57,6 +57,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2025.03.11\]** We now support evaluation on `SuperGPQA`, a benchmark measuring LLM knowledge across 285 graduate-level disciplines 🔥🔥🔥
- **\[2025.02.28\]** We have added a tutorial for `DeepSeek-R1` series model, please check [Evaluating Reasoning Model](docs/en/user_guides/deepseek_r1.md) for more details! 🔥🔥🔥
- **\[2025.02.15\]** We have added two powerful evaluation tools: `GenericLLMEvaluator` for LLM-as-judge evaluations and `MATHEvaluator` for mathematical reasoning assessments. Check out the documentation for [LLM Judge](docs/en/advanced_guides/llm_judge.md) and [Math Evaluation](docs/en/advanced_guides/general_math.md) for more details! 🔥🔥🔥
- **\[2025.01.16\]** We now support the [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct) model which has enhanced performance on reasoning and knowledge-intensive tasks.

View File

@@ -57,6 +57,7 @@
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2025.03.11\]** We now support `SuperGPQA`, a knowledge benchmark covering 285 graduate-level disciplines. Welcome to try it! 🔥🔥🔥
- **\[2025.02.28\]** We have added a tutorial for the `DeepSeek-R1` series models; see [Evaluating Reasoning Models](docs/en/user_guides/deepseek_r1.md) for more details! 🔥🔥🔥
- **\[2025.02.15\]** We have added two handy evaluation tools: `GenericLLMEvaluator` for LLM-as-judge evaluation and `MATHEvaluator` for mathematical reasoning assessment. See the [LLM Judge](docs/zh_cn/advanced_guides/llm_judge.md) and [Math Evaluation](docs/zh_cn/advanced_guides/general_math.md) docs for more details! 🔥🔥🔥
- **\[2025.01.16\]** We now support the [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct) model, which achieves top performance among models of comparable size on reasoning and knowledge-intensive tasks. Welcome to try it!

View File

@@ -734,6 +734,8 @@
category: Understanding
paper: https://arxiv.org/pdf/1808.08745
configpath: opencompass/configs/datasets/Xsum
- supergpqa:
name: SuperGPQA
category: Knowledge
paper: https://arxiv.org/pdf/2502.14739
configpath: opencompass/configs/datasets/supergpqa

View File

@@ -0,0 +1,57 @@
from opencompass.datasets.supergpqa.supergpqa import (
SuperGPQADataset,
SuperGPQAEvaluator,
)
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
# Reader configuration
reader_cfg = dict(
input_columns=[
'question',
'options',
'discipline',
'field',
'subfield',
'difficulty',
'infer_prompt',
'prompt_mode',
],
output_column='answer_letter',
)
# Inference configuration
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt='{infer_prompt}',
),
],
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
# Evaluation configuration
eval_cfg = dict(
evaluator=dict(type=SuperGPQAEvaluator),
pred_role='BOT',
)
supergpqa_dataset = dict(
type=SuperGPQADataset,
abbr='supergpqa',
path='m-a-p/SuperGPQA',
prompt_mode='zero-shot',
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
supergpqa_datasets = [supergpqa_dataset]
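For reference, this dataset list is meant to be pulled into a top-level evaluation config. A minimal sketch, assuming the standard OpenCompass `read_base()` pattern; the exact module paths below are hypothetical and depend on where this file lands:

```python
from mmengine.config import read_base

with read_base():
    # Import the dataset list defined above (module path is illustrative).
    from opencompass.configs.datasets.supergpqa.supergpqa_gen import \
        supergpqa_datasets
    # Any model config works here; this import is only an example.
    from opencompass.configs.models.hf_internlm.hf_internlm3_8b_instruct import \
        models

datasets = supergpqa_datasets
```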

View File

@@ -127,6 +127,7 @@ from .strategyqa import * # noqa: F401, F403
from .subjective import * # noqa: F401, F403
from .summedits import * # noqa: F401, F403
from .summscreen import * # noqa: F401, F403
from .supergpqa import * # noqa: F401, F403
from .svamp import * # noqa: F401, F403
from .tabmwp import * # noqa: F401, F403
from .taco import * # noqa: F401, F403

View File

@@ -0,0 +1,184 @@
import os
from datasets import Dataset, load_dataset
from opencompass.datasets.supergpqa.supergpqa_eval import (
extract_option_content, extract_option_labels)
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path
from ..base import BaseDataset
def _parse(item, template, prompt_mode):
prompt_format = [
item['question'] + '\n' + '\n'.join([
f'{chr(65+i)}) {option}'
for i, option in enumerate(item['options'])
])
]
item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
item['prompt_mode'] = prompt_mode
return item
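Concretely, `_parse` joins the question with lettered options and substitutes the result into the template's `{}` slot. A small sketch of the expected behavior, using an abbreviated template rather than the real zero-shot prompt:

```python
item = {
    'question': 'Colors in a soap bubble result from light',
    'options': ['dispersion', 'interference'],
}
template = {'prompt_format': ["Answer the question.\n{}\nAnswer:"]}

parsed = _parse(item, template, 'zero-shot')
print(parsed['infer_prompt'])
# Answer the question.
# Colors in a soap bubble result from light
# A) dispersion
# B) interference
# Answer:
```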
@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):
@staticmethod
def load(path: str, prompt_mode: str, **kwargs):
path = get_data_path(path, local_mode=True)
dataset = load_dataset(path, split='train')
# get prompt template
template_path = None
if prompt_mode == 'zero-shot':
template_path = os.path.join(
os.path.dirname(__file__),
'supergpqa_dataset_config/prompt/zero-shot.yaml',
)
elif prompt_mode == 'five-shot':
template_path = os.path.join(
os.path.dirname(__file__),
'supergpqa_dataset_config/prompt/five-shot.yaml',
)
try:
template = load_yaml(template_path)
except FileNotFoundError:
print(f'[ERROR] Missing prompt template: {template_path}')
return Dataset.from_list([])
dataset = dataset.map(lambda item: _parse(item, template, prompt_mode))
return dataset
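The dataset can also be materialized directly, assuming the `m-a-p/SuperGPQA` Hugging Face path resolves (or a local copy is available):

```python
ds = SuperGPQADataset.load(path='m-a-p/SuperGPQA', prompt_mode='zero-shot')
print(ds[0]['answer_letter'], ds[0]['infer_prompt'][:60])
```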
@ICL_EVALUATORS.register_module()
class SuperGPQAEvaluator(BaseEvaluator):
def __init__(self):
super().__init__()
def score(self, predictions, references, test_set):
mode = test_set[0]['prompt_mode']
acc = 0
count = 0
err = 0
miss = 0
acc_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
count_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
stats = {'discipline': {}, 'field': {}, 'subfield': {}}
details = []
for i, sample in enumerate(test_set):
sample['pred'] = prediction = predictions[i]
gold = references[i]
if mode == 'zero-shot':
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict is None:
predict = extract_option_content(prediction,
sample['options'])
predict = (chr(sample['options'].index(predict) +
65) if predict else None)
sample['extracted_answer'] = predict
elif mode == 'five-shot':
response = prediction.split('Question:')[0]
predict = extract_option_labels(response, 'ABCDEFGHIJ')
if predict is None:
predict = extract_option_content(response,
sample['options'])
predict = (chr(sample['options'].index(predict) +
65) if predict else None)
if predict is None:
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
if predict is None:
predict = extract_option_content(
prediction, sample['options'])
predict = (chr(sample['options'].index(predict) +
65) if predict else None)
sample['extracted_answer'] = predict
discipline = sample.get('discipline', 'unknown')
field = sample.get('field', 'unknown')
subfield = sample.get('subfield', 'unknown')
difficulty = sample.get('difficulty', 'unknown')
for level, key in [
('discipline', discipline),
# ('field', f"{discipline}/{field}"),
# ('subfield', f"{discipline}/{field}/{subfield}"),
]:
if key not in stats[level]:
stats[level][key] = {
'correct': 0,
'total': 0,
'miss': 0,
'error': 0,
'discipline': discipline,
'field': field,
'subfield': subfield,
'difficulty': {
'easy': {
'correct': 0,
'total': 0
},
'middle': {
'correct': 0,
'total': 0
},
'hard': {
'correct': 0,
'total': 0
},
},
}
stats[level][key]['total'] += 1
stats[level][key]['difficulty'][difficulty]['total'] += 1
answer_letter = sample['answer_letter']
assert answer_letter == gold
if predict and answer_letter == predict:
acc += 1
acc_difficulty[difficulty] += 1
sample['status'] = 'correct'
stats[level][key]['correct'] += 1
stats[level][key]['difficulty'][difficulty]['correct'] += 1
elif predict is None or predict == '':
miss += 1
sample['status'] = 'miss'
stats[level][key]['miss'] += 1
elif predict == 'error':
err += 1
sample['status'] = 'error'
stats[level][key]['error'] += 1
else:
sample['status'] = 'incorrect'
count += 1
count_difficulty[difficulty] += 1
            details.append({
                'pred': sample['pred'],
                'answer': sample['answer'],
                'parsed_answer': sample['extracted_answer'],
                'correct': sample['status'] == 'correct',
            })
return {
'accuracy':
acc / count if count > 0 else 0,
'error_rate':
err / count if count > 0 else 0,
'miss_rate':
miss / count if count > 0 else 0,
'hard_accuracy':
(acc_difficulty['hard'] /
count_difficulty['hard'] if count_difficulty['hard'] > 0 else 0),
'middle_accuracy':
(acc_difficulty['middle'] / count_difficulty['middle']
if count_difficulty['middle'] > 0 else 0),
'easy_accuracy':
(acc_difficulty['easy'] /
count_difficulty['easy'] if count_difficulty['easy'] > 0 else 0),
'details':
details,
}
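A self-contained scoring sketch, assuming the evaluator can be instantiated standalone; the field names follow the dataset schema used above:

```python
test_set = [{
    'prompt_mode': 'zero-shot',
    'options': ['dispersion', 'interference'],
    'answer': 'interference',
    'answer_letter': 'B',
    'discipline': 'Science',
    'field': 'Physics',
    'subfield': 'Optics',
    'difficulty': 'easy',
}]
result = SuperGPQAEvaluator().score(
    predictions=['Thin-film effects cause this.\nAnswer: B.'],
    references=['B'],
    test_set=test_set,
)
print(result['accuracy'])  # 1.0
```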

View File

@@ -0,0 +1,17 @@
response_key: 'response'
error_key: 'error'
id_key:
- 'uuid'
prompt_key: 'prompt'
history_key: 'history'
status_key: 'status'
save_prompt: True
max_tokens: 4096
temperature: 0.0
max_rounds: 30
BoN: 32

View File

@@ -0,0 +1,17 @@
response_key: 'response'
error_key: 'error'
id_key:
- 'uuid'
prompt_key: 'prompt'
history_key: 'history'
status_key: 'status'
save_prompt: True
max_tokens: 32768
temperature: 0.0
max_rounds: 30
BoN: 32

View File

@@ -0,0 +1,88 @@
import yaml
class ConfigWrapper:
def __init__(self, config_path):
self._config = {}
with open(config_path, 'r') as file:
self._config = yaml.safe_load(file)
for key, value in self._config.items():
setattr(self, key, value)
def __setattr__(self, key, value):
if key.startswith('_'):
super().__setattr__(key, value)
else:
self._config[key] = value
super().__setattr__(key, value)
def __getattr__(self, key):
if key in self._config:
return self._config[key]
raise AttributeError(
f"'ConfigWrapper' object has no attribute '{key}'")
def get_id(self, data):
if isinstance(self._config.get('id_key'), str):
return data.get(self._config.get('id_key'), None)
elif isinstance(self._config.get('id_key'), list):
return '_'.join([
str(data[key]) for key in self._config.get('id_key')
if key in data
])
def print_all_keys(self):
print('config keys:')
for key, value in self._config.items():
print(f' - {key}: {value}')
config_wrapper = None
def initialize_config(config_path):
global config_wrapper
config_wrapper = ConfigWrapper(config_path)
def get_config_wrapper():
global config_wrapper
if config_wrapper is None:
raise RuntimeError(
'ConfigWrapper not initialized. Call initialize_config first.')
return config_wrapper
if __name__ == '__main__':
config_path = 'config/config.yaml'
initialize_config(config_path)
data = {
'idx':
'50',
'step':
21,
'question':
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
'Please provide the decrypted answer, encapsulated in double square'
' brackets. For example, the format should be: [[decrypted answer]].',
'answer':
'[[P]]',
'category':
'Decryption',
'rule_id':
'23',
'input':
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
'steps_num':
23,
'description':
'For a number c=228 in the ciphertext:\n'
'Calculate z = c^e mod n. Here ^ means multiplication.\nz is 80.'
'\nBased on the decimal number represented by z, use the ascii '
'code to find the corresponding letter as the plaintext letter p.'
'\nPlease give the letter p in [[...]] format.\n',
'atom':
80,
}
print(config_wrapper.get_id(data))

View File

@@ -0,0 +1,91 @@
prompt_format:
- |
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
Question:
A refracting telescope consists of two converging lenses separated by 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification of the telescope is
A) 10
B) 40
C) 6
D) 25
E) 15
F) 50
G) 30
H) 4
I) 5
J) 20
Answer: Let's think step by step. In a refracting telescope, if both lenses are converging, the focus of both lenses must be between the two lenses, and thus the focal lengths of the two lenses must add up to their separation. Since the focal length of one lens is 20 cm, the focal length of the other must be 80 cm. The magnification is the ratio of these two focal lengths, or 4.
Answer: H.
Question:
Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?
A) 1000 times more
B) 50 times more
C) 5000 times more
D) 500 times more
E) 10000 times more
F) 20000 times more
G) 2000 times more
H) 100 times more
I) 10 times more
J) N/A
Answer: Let's think step by step. The amount of light a telescope can gather compared to the human eye is proportional to the area of its apertures. The area of a circle is given by the formula $A = \pi \left(\frac{{D}}{{2}}\right)^2$, where $D$ is the diameter. Therefore, the relative light-gathering power is calculated as:
\[
\frac{{\left(\frac{{50 \text{{ cm}}}}{{2}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{2}}\right)^2}} = \frac{{\left(\frac{{50 \text{{ cm}}}}{{0.1 \text{{ cm}}}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{0.1 \text{{ cm}}}}\right)^2}} = \frac{{500^2}}{{5^2}} = 10000.
\]
Answer: E.
Question:
Where do most short-period comets come from and how do we know?
A) The Kuiper belt; short period comets tend to be in the plane of the solar system like the Kuiper belt.
B) The asteroid belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the asteroid belt.
C) The asteroid belt; short period comets tend to be in the plane of the solar system just like the asteroid belt.
D) The Oort cloud; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the Oort cloud.
E) The Oort Cloud; short period comets tend to come from random directions indicating a spherical distribution of comets called the Oort Cloud.
F) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.
G) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt.
Answer: Let's think step by step. Most short-period comets originate from the Kuiper belt. This is deduced from the observation that these comets tend to follow orbits that lie in the plane of the solar system, similar to the distribution of objects in the Kuiper belt itself. Thus, the alignment of these cometary orbits with the ecliptic plane points to their Kuiper belt origin.
Answer: A.
Question:
Colors in a soap bubble result from light
A) dispersion
B) deflection
C) refraction
D) reflection
E) interference
F) converted to a different frequency
G) polarization
H) absorption
I) diffraction
J) transmission
Answer: Let's think step by step. The colorful patterns observed in a soap bubble are caused by the phenomenon of light interference. This occurs when light waves bounce between the two surfaces of the soap film, combining constructively or destructively based on their phase differences and the varying thickness of the film. These interactions result in vibrant color patterns due to variations in the intensity of different wavelengths of light.
Answer: E.
Question:
A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
A) 240 W
B) 120 W
C) 10 W
D) 480 W
E) 360 W
F) 200 W
G) 30 W
H) 150 W
I) 60 W
J) 300 W
Answer: Let's think step by step. The rate of energy usage, known as power, in an electrical circuit is calculated by the product of voltage and current. For a microwave oven connected to a 120 V outlet and drawing a current of 2 amps, the power consumption can be calculated as follows:
\[
\text{{Power}} = \text{{Voltage}} \times \text{{Current}} = 120 \, \text{{V}} \times 2 \, \text{{A}} = 240 \, \text{{W}}.
\]
Therefore, the microwave oven uses energy at a rate of 240 watts.
Answer: A.
Question:
{}
Answer: Let's think step by step.

View File

@@ -0,0 +1,23 @@
initial_prompt_0:
- |
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}
initial_prompt_1:
- |
You are a helpful assistant. Answer the given multiple-choice question. Only one option is correct. The last line of your response should be in the format 'The correct answer is: $LETTER', where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}
initial_prompt_2:
- |
Select the correct answer for the following multiple-choice question. There is only one valid choice. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}
initial_prompt_3:
- |
Review the following multiple-choice question and choose the one correct answer. Ensure that your response concludes with a line exactly formatted as 'The correct answer is: $LETTER', where LETTER represents one of A, B, C, D, E, F, G, H, I, or J.
{}

View File

@@ -0,0 +1,5 @@
prompt_format:
- |
Answer the following multiple choice question about {}. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}

View File

@@ -0,0 +1,5 @@
prompt_format:
- |
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
{}

View File

@@ -0,0 +1,96 @@
# flake8: noqa: W605
import re
import timeout_decorator
@timeout_decorator.timeout(5) # 5 seconds timeout
def safe_regex_search(pattern, text, flags=0):
try:
return re.search(pattern, text, flags)
except timeout_decorator.TimeoutError:
print(f'Regex match timeout: pattern={pattern}, text={text[:100]}...')
return None
except Exception as e:
print(f'Regex match error: {str(e)}')
return None
def extract_option_labels(text, options='ABCDEFGHIJ'):
if not isinstance(text, str) or not isinstance(options, str):
return 'error'
text = text.rstrip()
last_line = text.split('\n')[-1]
option_str = ''.join([chr(65 + i) for i in range(len(options))
]) if options else 'ABCDEFGHIJ'
patterns = [
# e.g. "The final answer to this question is: A."
# "The best option is $\boxed{B}:"
# "The correct answer is (C)."
f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is?:?\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
# e.g. "ANSWER: A"
# "Answer: $\boxed{B}."
# "ANSWER: (C):"
f'(?i:Answer)[\*\s]*:\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
# e.g. "A"
# "$\boxed{B}$"
# "(C)."
# "[D]:"
f'^[^\w\r\n]*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
]
for pattern in patterns:
match = safe_regex_search(pattern, last_line, re.IGNORECASE)
if match:
return match.group(1)
for pattern in patterns:
match = safe_regex_search(pattern, text, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_option_content(text, options_content=None):
if not isinstance(text, str) or not isinstance(options_content, list):
return 'error'
escaped_options_content = [
re.escape(option_content) for option_content in options_content
]
escaped_options_content_str = '|'.join(escaped_options_content)
text = text.rstrip()
last_line = text.split('\n')[-1]
patterns = [
f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is:?\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
f'(?i:Answer)\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
f'^[^\w\r\n]*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
]
for pattern in patterns:
match = safe_regex_search(pattern, last_line)
if match:
if match.group(1) in escaped_options_content:
return options_content[escaped_options_content.index(
match.group(1))]
else:
return match.group(1)
for pattern in patterns:
match = safe_regex_search(pattern, text)
if match:
if match.group(1) in escaped_options_content:
return options_content[escaped_options_content.index(
match.group(1))]
else:
return match.group(1)
return None
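Expected behavior of the two extractors on typical responses (a quick sanity sketch, not taken from the PR's tests):

```python
print(extract_option_labels('Reasoning...\nAnswer: E.'))       # E
print(extract_option_labels('The correct answer is (C).'))     # C
print(extract_option_content('The answer is interference.',
                             ['dispersion', 'interference']))  # interference
```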

View File

@@ -0,0 +1,693 @@
import json
import os
import re
import sympy as sp
import yaml
from sympy.parsing.latex import parse_latex
def load_yaml(yaml_path):
"""Load a YAML file."""
if not os.path.exists(yaml_path):
raise FileNotFoundError(f'YAML file not found: {yaml_path}')
with open(yaml_path, 'r', encoding='utf-8') as file:
return yaml.safe_load(file)
def load_json_or_jsonl(file_path):
"""Load data from a JSON or JSONL file."""
if not os.path.exists(file_path):
return None
with open(file_path, 'r', encoding='utf-8') as file:
if file_path.endswith('.json'):
return json.load(file)
elif file_path.endswith('.jsonl'):
return [json.loads(line) for line in file]
return None
def find_file(base_path, sub_path, extensions=('json', 'jsonl')):
"""Find the first available file with given extensions."""
for ext in extensions:
file_path = os.path.join(base_path, f'{sub_path}.{ext}')
if os.path.exists(file_path):
return file_path
return None
def load_json_or_jsonl_with_idx(data_path, split='', idx=None):
base_path = os.path.join(data_path, split)
if os.path.exists(f'{base_path}.json'):
file_path = f'{base_path}.json'
elif os.path.exists(f'{base_path}.jsonl'):
file_path = f'{base_path}.jsonl'
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
file_path = base_path
else:
raise FileNotFoundError('No JSON or JSONL file found.')
with open(file_path, 'r', encoding='utf-8') as file:
if file_path.endswith('.json'):
data = json.load(file)
elif file_path.endswith('.jsonl'):
data = [json.loads(line) for line in file]
if idx is not None:
try:
return next(item for item in data if item.get('idx') == idx)
except StopIteration:
raise ValueError(f'No entry found for idx {idx}')
else:
return data
def load_split_data(base_path, split_name):
"""Load the rule and sample data for a specific split."""
split_path = os.path.join(base_path, split_name)
rule_path = find_file(split_path, 'rule')
sample_path = find_file(split_path, 'sample')
rules = load_json_or_jsonl(rule_path) if rule_path else []
samples = load_json_or_jsonl(sample_path) if sample_path else []
return {'rules': rules, 'samples': samples}
def process_mixed_data(base_path, mode):
"""Load and process data for the 'mixed' split and specific mode."""
mixed_path = os.path.join(base_path, 'mixed')
file_path = find_file(mixed_path, mode)
if not file_path:
print(f'[WARNING] Missing file for mixed mode: {mode}')
return []
data = load_json_or_jsonl(file_path)
template_path = os.path.join(base_path, 'config/prompt/mixed.yaml')
template = load_yaml(template_path)
processed = []
for item in data:
rules = '\n'.join(item.get('rule_list', []))
questions = '\n'.join(item.get('question_list', []))
item['prompt'] = template['prompt_format'][0].format(rules, questions)
processed.append(item)
return processed
class ConfigWrapper:
def __init__(self, config_path):
self._config = {}
with open(config_path, 'r') as file:
self._config = yaml.safe_load(file)
for key, value in self._config.items():
setattr(self, key, value)
def __setattr__(self, key, value):
if key.startswith('_'):
super().__setattr__(key, value)
else:
self._config[key] = value
super().__setattr__(key, value)
def __getattr__(self, key):
if key in self._config:
return self._config[key]
raise AttributeError(
f"'ConfigWrapper' object has no attribute '{key}'")
def get_id(self, data):
if isinstance(self._config.get('id_key'), str):
return data.get(self._config.get('id_key'), None)
elif isinstance(self._config.get('id_key'), list):
return '_'.join([
str(data[key]) for key in self._config.get('id_key')
if key in data
])
def print_all_keys(self):
print('config keys:')
for key, value in self._config.items():
print(f' - {key}: {value}')
config_wrapper = None
def initialize_config(config_path):
global config_wrapper
config_wrapper = ConfigWrapper(config_path)
def get_config_wrapper():
global config_wrapper
if config_wrapper is None:
raise RuntimeError(
'ConfigWrapper not initialized. Call initialize_config first.')
return config_wrapper
if __name__ == '__main__':
config_path = 'config/config.yaml'
initialize_config(config_path)
data = {
'idx':
'50',
'step':
21,
'question':
('Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
'Please provide the decrypted answer, encapsulated in double '
'square brackets. '
'For example, the format should be: [[decrypted answer]].'),
'answer':
'[[P]]',
'category':
'Decryption',
'rule_id':
'23',
'input':
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
'steps_num':
23,
'description':
('For a number c=228 in the ciphertext:\n'
'Calculate z = c^e mod n. Here ^ means multiplication.\n'
'z is 80.\nBased on the decimal number represented by z, '
'use the ascii code to find the corresponding letter '
'as the plaintext letter p.\n'
'Please give the letter p in [[...]] format.\n'),
'atom':
80
}
print(config_wrapper.get_id(data))
def read_yaml(config='default'):
if os.path.exists(f'config/prompt/{config}.yaml'):
yaml_file = f'config/prompt/{config}.yaml'
else:
yaml_file = config
with open(yaml_file, 'r') as yaml_file:
return yaml.safe_load(yaml_file)
def write_jsonl_lines(file, data):
config_wrapper = get_config_wrapper()
if config_wrapper.save_prompt:
json.dump(data, file, ensure_ascii=False)
else:
data.pop(config_wrapper.prompt_key)
json.dump(data, file, ensure_ascii=False)
file.write('\n')
file.flush()
def print_info(info):
print('-' * 100)
print('[INFO] model_name:', info['model_name'])
print('[INFO] splits:', info['splits'])
print('[INFO] modes:', info['modes'])
print('[INFO] output_dir:', info['output_dir'])
print('[INFO] Infer Limit:',
'No limit' if info['infer_limit'] is None else info['infer_limit'])
print('[INFO] Number of Workers:', info['num_workers'])
print('[INFO] Batch Size:', info['batch_size'])
print('[INFO] Use Accel:', info['use_accel'])
print('-' * 100)
def read_json_or_jsonl(data_path, split='', mapping_key=None):
base_path = os.path.join(data_path, split)
if os.path.exists(f'{base_path}.json'):
file_path = f'{base_path}.json'
elif os.path.exists(f'{base_path}.jsonl'):
file_path = f'{base_path}.jsonl'
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
file_path = base_path
else:
raise FileNotFoundError('No JSON or JSONL file found.')
with open(file_path, 'r') as file:
if file_path.endswith('.json'):
data = json.load(file)
elif file_path.endswith('.jsonl'):
data = [json.loads(line) for line in file]
if mapping_key:
return {
item[mapping_key]: item
for item in data if mapping_key in item
}
else:
return data
def read_json_or_jsonl_with_idx(data_path, split='', idx=None):
base_path = os.path.join(data_path, split)
if os.path.exists(f'{base_path}.json'):
file_path = f'{base_path}.json'
elif os.path.exists(f'{base_path}.jsonl'):
file_path = f'{base_path}.jsonl'
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
file_path = base_path
else:
raise FileNotFoundError('No JSON or JSONL file found.')
with open(file_path, 'r', encoding='utf-8') as file:
if file_path.endswith('.json'):
data = json.load(file)
elif file_path.endswith('.jsonl'):
data = [json.loads(line) for line in file]
if idx is not None:
try:
return next(item for item in data if item.get('idx') == idx)
except StopIteration:
raise ValueError(f'No entry found for idx {idx}')
else:
return data
idx_ranges = [
[18],
[73, 74, 77],
[94],
[115, 116, 117],
[121, 122, 123, 125],
[131, 132, 134, 135, 136],
[141, 143, 149],
list(range(145, 148)),
list(range(151, 157)),
[160, 161, 162],
[164, 165, 166],
[170],
[206, 209],
list(range(211, 216)),
[217, 218],
]
def clean_json_string(json_str):
json_str = re.sub(r'[\x00-\x1F\x7F]', '', json_str)
return json_str
def is_in_idx_ranges(idx, idx_ranges):
for range_list in idx_ranges:
if int(idx) in range_list:
return True
return False
def extract_json(text):
matches = re.findall(r'{.*}', text, re.DOTALL)
if matches:
json_str = matches[-1]
json_str = clean_json_string(json_str)
try:
data = json.loads(json_str)
return data
except json.JSONDecodeError as e:
print(f'Error decoding JSON: {e}')
return 'NULL'
return 'NULL'
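`extract_json` grabs the last `{...}` span in a response and parses it, falling back to `'NULL'`:

```python
print(extract_json('Working... final output: {"answer": "[[P]]"}'))
# {'answer': '[[P]]'}
print(extract_json('no json here'))  # NULL
```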
def extract_all_responses_from_json(response_json):
results = []
for key, value in response_json.items():
results.append(str(value))
return results
def clean_latex(latex_expr):
if '=' in latex_expr:
latex_expr = latex_expr.rsplit('=', 1)[1]
latex_expr = re.sub(r'\\[()\[\]]', '', latex_expr)
latex_expr = re.sub(r'\\text\{.*?\}', '', latex_expr)
latex_expr = re.sub(r'\\(left|right|displaystyle)', '', latex_expr)
latex_expr = latex_expr.replace('\\\\', '\\')
return latex_expr
def extract_text_from_brackets(text, clean_level='basic'):
matches = re.findall(r'\[\[\s*(.*?)\s*\]\]', text, re.DOTALL)
if not matches:
matches = re.findall(r'\$\\boxed\{(.*?)\}\$', text, re.DOTALL)
if not matches:
matches = re.findall(r'\[\s*(.*?)\s*\]', text, re.DOTALL)
if matches:
match_str = matches[0].strip()
if clean_level == 'clean':
match_str = match_str.replace('"', '').replace('\n', '').replace(
' ', '').replace('[', '').replace(']', '')
elif clean_level == 'logic':
match_str = match_str.replace('"', '').replace('\n', '').replace(
' ', '').replace('.', '')
elif clean_level == 'math':
match_str = match_str.replace('"', '').replace('\n', '').replace(
'[', '').replace(']', '').replace('$', '')
return f'{clean_latex(match_str)}'
return f'[[{match_str}]]'
return 'NULL'
def extract_inner_text_from_brackets(text):
if not isinstance(text, str):
print(f'text type: {type(text)}, text value: {text}')
return 'NULL'
match = re.search(r'\[\[(.*?)\]\]', text, re.DOTALL)
return match.group(1) if match else 'NULL'
def extract_numbers(str):
numbers = re.findall(r'\d+', str)
numbers = list(map(int, numbers))
return numbers
def extract_and_sort_inequalities(latex_expr):
pattern = r'(≥|≤)\s*([-]?\d+\.?\d*)'
matches = re.findall(pattern, latex_expr)
extracted_inequalities = [''.join(match) for match in matches]
sorted_inequalities = sorted(extracted_inequalities)
return sorted_inequalities
def rule5_normalize_content(content):
parts = [part for part in content.split(';')]
sorted_parts = sorted(parts)
return sorted_parts
def normalize_string(s):
s = re.sub(r'[^0-9]', '', s)
pairs = s.split(',')
pairs.sort()
return pairs
def remove_commas_and_spaces(s):
return re.sub(r'[,\s\[\]]+', '', s)
def remove_non_alphanumeric(s):
return re.sub(r'\W+', '', s)
def contains_or(answer):
return 'or' in answer
def compare_multi_results(response, answer):
try:
response_text = extract_text_from_brackets(response, 'clean')
response_text = re.sub(r'\\text\{or\}', 'or', response_text)
if response_text == 'NULL':
return False
answer = extract_text_from_brackets(answer, 'clean')
response_split = response_text.strip('[[]]').split('or')
answer_split = answer.strip('[[]]').split('or')
response_sorted = sorted([x.strip() for x in response_split])
answer_sorted = sorted([x.strip() for x in answer_split])
return response_sorted == answer_sorted
except Exception as e:
print(f'Error during comparison: {e}')
return False
def split_or_expression(expression):
return [part.strip() for part in expression.split('or')]
def compare_math_expressions(response, answer):
response_text = extract_text_from_brackets(response, 'math')
answer_text = extract_text_from_brackets(answer, 'math')
if response_text == 'NULL':
return False
if contains_or(answer_text):
response_parts = split_or_expression(response_text)
answer_parts = split_or_expression(answer_text)
try:
response_exprs = {
sp.simplify(parse_latex(part))
for part in response_parts
}
answer_exprs = {
sp.simplify(parse_latex(part))
for part in answer_parts
}
return response_exprs == answer_exprs
except Exception as e:
print(f'Error during simplification or parsing: {e}')
return response_text == answer_text
else:
try:
response_expr = sp.simplify(parse_latex(response_text))
answer_expr = sp.simplify(parse_latex(answer_text))
return response_expr == answer_expr
except Exception as e:
print(f'Error during simplification or parsing: {e}')
return response_text == answer_text
def method_equal(response_text, answer):
return response_text == answer
def method_1(response_text, answer):
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
cleaned_string = cleaned_string.lower()
answer = re.sub(r'[^A-Za-z]', '', answer)
answer = answer.lower()
return cleaned_string == answer
def method_2(response_text, answer):
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
cleaned_string = cleaned_string.lower()
answer = answer.split(',')
return cleaned_string in answer
def method_3(response_text, answer):
response_text = response_text.lower()
pairs1 = re.split(r'\W+', response_text)
pairs2 = answer.split(' ')
pairs1 = [word for word in pairs1 if word]
pairs1.sort()
pairs2.sort()
return pairs1 == pairs2
def method_4(response_text, answer):
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
cleaned_string = cleaned_string.lower()
return cleaned_string in answer
def method_5(response_text, answer):
response_text = re.sub(r'\s+', '', response_text)
response_text = response_text.split(',')
answer = answer.split(',')
response_text.sort()
answer.sort()
return response_text == answer
def method_9(response_text, answer):
    response_text = response_text.replace('×', '*').replace('−', '-')
    answer = answer.replace('×', '*').replace('−', '-')
def extract_operators(s):
return re.findall(r'[+\-*/]', s)
response_ops = extract_operators(response_text.split('=')[0])
answer_ops = extract_operators(answer.split('=')[0])
if response_ops != answer_ops:
return False
match = re.search(r'=\s*(-?\d+)', answer)
expected_result = int(match.group(1))
try:
left_side = response_text.split('=')[0]
result = eval(left_side)
except Exception as e:
print(f'Error during evaluation: {e}')
return False
return result == expected_result
def method_10(response_text, answer):
    response_text = response_text.replace('×', '*').replace('−', '-')
response_text = response_text.split('=')[0]
answer = answer.split('\n')[0].split('=')[0]
response_ops = sorted(remove_non_alphanumeric(response_text))
answer_ops = sorted(remove_non_alphanumeric(answer))
if response_ops != answer_ops:
return False
try:
result = eval(response_text)
except Exception as e:
print(f'Error during evaluation: {e}')
return False
return result == 24
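For reference, `method_10` targets 24-game answers: the response must use the same digits as the reference, and its left-hand side must evaluate to 24. A sanity sketch:

```python
print(method_10('(8-4)*(3+3)=24', '(3+3)*(8-4)=24'))  # True
print(method_10('(8-4)*(3+2)=24', '(3+3)*(8-4)=24'))  # False: digits differ
```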
def method_18(response_text, answer):
cleaned_s1 = remove_commas_and_spaces(response_text)
cleaned_s2 = remove_commas_and_spaces(answer)
return cleaned_s1 == cleaned_s2
def method_general(response_text, answer):
cleaned_s1 = remove_non_alphanumeric(response_text)
cleaned_s2 = remove_non_alphanumeric(answer)
return cleaned_s1 == cleaned_s2
question_methods = {
'1': method_1,
'2': method_2,
'3': method_3,
'4': method_4,
'5': method_5,
'9': method_9,
'10': method_10,
'18': method_18,
}
def evaluate_response_vs_answer(response, answer, question_type, rule_id, idx):
if question_type == 'logic' and rule_id == '5':
response_text = extract_text_from_brackets(response, 'logic')
answer_text = extract_text_from_brackets(answer, 'logic')
if response_text is None:
return False
normalized_response = rule5_normalize_content(response_text)
normalized_answer = rule5_normalize_content(answer)
return normalized_response == normalized_answer
elif question_type == 'logic':
response_text = extract_text_from_brackets(response, 'logic')
answer_text = extract_text_from_brackets(answer, 'logic')
return response_text == answer_text
elif question_type == 'operation' and (idx == '178' or idx == '179'):
response_text = extract_text_from_brackets(response, 'clean')
response_text = extract_and_sort_inequalities(response_text)
answer_text = extract_and_sort_inequalities(answer)
# print(response_text, answer_text)
return response_text == answer_text
elif question_type == 'operation' and rule_id == '18':
response_text = extract_text_from_brackets(response, 'clean')
answer = extract_inner_text_from_brackets(answer)
response_text = ''.join(sorted(re.sub(r'\W+', '', response_text)))
answer = ''.join(sorted(re.sub(r'\W+', '', answer)))
return response_text == answer
elif question_type == 'operation' and rule_id in {'23', '24', '25'}:
response_text = extract_text_from_brackets(response, 'clean')
if response_text is None:
return False
response_text = extract_numbers(response_text)
answer_text = extract_numbers(answer)
return response_text == answer_text
elif question_type == 'operation' and is_in_idx_ranges(idx, idx_ranges):
return compare_math_expressions(response, answer)
elif question_type == 'operation' and contains_or(answer):
return compare_multi_results(response, answer)
elif question_type == 'puzzle':
response_text = extract_inner_text_from_brackets(response)
answer = extract_inner_text_from_brackets(answer)
method = question_methods.get(rule_id)
if method:
return method(response_text, answer)
return method_general(response_text, answer)
else:
response_text = extract_text_from_brackets(response, 'clean')
return response_text == answer
def compute_one_mixed_question_pass_rate(idx,
question_list,
response_json,
base_path=None):
if response_json == 'NULL':
result_dict = {
'idx': idx,
'response': response_json,
'details': None,
'pass_rate': 0,
'is_correct': False
}
return result_dict
response_list = extract_all_responses_from_json(response_json)
correct_num = 0
results = []
for q_idx, question in enumerate(question_list):
category, question_idx = question.rsplit('_', 1)
question_content = load_json_or_jsonl_with_idx(base_path,
os.path.join(
category, 'sample'),
idx=question_idx)
answer = question_content['answer']
if q_idx >= len(response_list):
break
response = response_list[q_idx]
response_text = extract_text_from_brackets(response)
rule_id = question_content['rule_id']
is_correct = evaluate_response_vs_answer(response, answer, category,
rule_id, q_idx)
if is_correct:
correct_num += 1
results.append({
'question': question,
'response_text': response_text,
'answer': answer,
'is_correct': is_correct
})
pass_rate = correct_num / len(question_list)
question_correct = pass_rate == 1.0
result_dict = {
'idx': idx,
'response': response_json,
'details': results,
'pass_rate': pass_rate,
'is_correct': question_correct
}
return result_dict
def evaluate_responses(data, mode, base_path=None):
results = []
# Iterate over the values of the dictionary (numerical keys)
for key, record in data.items():
idx = key # Use the dictionary key as the "idx"
response = record.get('prediction', '')
question_type = record.get('category', '')
response_text = extract_text_from_brackets(response)
answer = record.get('gold', '')
rule_id = record.get('rule_id', '')
is_correct = evaluate_response_vs_answer(response, answer,
question_type, rule_id, idx)
result_dict = {
'idx': idx,
'response': response,
'response_text': response_text,
'answer': answer,
'is_correct': is_correct
}
if question_type == 'counterfactual':
real_life_answer = record.get('real_life_answer', '')
is_real_life = evaluate_response_vs_answer(response,
real_life_answer,
question_type, rule_id,
idx)
result_dict['real_life_answer'] = real_life_answer
result_dict['is_real_life'] = is_real_life
if question_type == 'cipher' and mode == 'subquestions':
result_dict['type'] = record.get('type', '')
results.append(result_dict)
return results
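A hedged sketch of the record shape `evaluate_responses` expects (keys and values here are hypothetical):

```python
data = {'0': {'prediction': 'So the result is [[42]]', 'gold': '[[42]]',
              'category': 'operation', 'rule_id': ''}}
print(evaluate_responses(data, mode='zero-shot')[0]['is_correct'])  # True
```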

View File

@@ -1,4 +1,5 @@
"""Base Evaluator."""
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Union
@@ -77,12 +78,17 @@ class BaseEvaluator:
for metric in all_metrics:
if metric in ['predictions', 'example_abbr']:
continue
g_passk_details[metric] = 100. * np.mean(
g_passk_details[metric] = 100.0 * np.mean(
[detail[metric] for detail in details])
return g_passk_details
def evaluate(self, k: Union[int, List[int]], n: int,
original_dataset: Dataset, **score_kwargs):
def evaluate(
self,
k: Union[int, List[int]],
n: int,
original_dataset: Dataset,
**score_kwargs,
):
real_size = len(original_dataset) // n
all_details = []
all_results = []
@@ -146,7 +152,7 @@ class BaseEvaluator:
if can_calculate and n > 1 and k > 1:
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
for _k in ([k] if isinstance(k, int) else k):
for _k in [k] if isinstance(k, int) else k:
for threshold in thresholds:
g_pass = compute_g_pass_at_k(n=n,
c=c,
@@ -161,9 +167,31 @@
if can_calculate and n > 1 and k > 1:
eval_results.update(self.reduce(eval_details))
# Store eval_details in eval_results
eval_results['details'] = eval_details
return eval_results
# Process details to flatten the predictions
for detail in eval_details:
# Extract all prediction fields and flatten them
flattened_predictions = {}
for pred in detail['predictions']:
for k, v in pred.items():
if k not in flattened_predictions:
flattened_predictions[k] = [v]
else:
flattened_predictions[k].append(v)
# Replace the predictions list with the flattened dictionary
for k, v in flattened_predictions.items():
detail[k] = v
# Remove the original predictions field
detail.pop('predictions')
return eval_results
# If there are no details, return an empty dictionary
return {}
def score(self):
raise NotImplementedError("Method hasn't been implemented yet")
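The prediction-flattening loop added to `evaluate` turns the per-repeat list of prediction dicts into per-key lists on each detail record. An illustration with a hypothetical record:

```python
detail = {'predictions': [{'pred': 'A', 'correct': True},
                          {'pred': 'B', 'correct': False}]}
# After the flattening loop the record becomes:
# {'pred': ['A', 'B'], 'correct': [True, False]}
```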