mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support SuperGPQA (#1924)
* support supergpqa * remove unnecessary code * remove unnecessary code * Add Readme * Add Readme * fix lint * fix lint * update * update --------- Co-authored-by: mkj3085003 <mkj3085003@gmail.com> Co-authored-by: MaiziXiao <xxllcc1993@gmail.com>
This commit is contained in:
parent
e403fd21be
commit
59e49aedf1
@ -57,6 +57,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
|
||||
|
||||
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
|
||||
|
||||
- **\[2025.03.11\]** We have supported evaluation for `SuperGPQA` which is a great benchmark for measuring LLM knowledge ability 🔥🔥🔥
|
||||
- **\[2025.02.28\]** We have added a tutorial for `DeepSeek-R1` series model, please check [Evaluating Reasoning Model](docs/en/user_guides/deepseek_r1.md) for more details! 🔥🔥🔥
|
||||
- **\[2025.02.15\]** We have added two powerful evaluation tools: `GenericLLMEvaluator` for LLM-as-judge evaluations and `MATHEvaluator` for mathematical reasoning assessments. Check out the documentation for [LLM Judge](docs/en/advanced_guides/llm_judge.md) and [Math Evaluation](docs/en/advanced_guides/general_math.md) for more details! 🔥🔥🔥
|
||||
- **\[2025.01.16\]** We now support the [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct) model which has enhanced performance on reasoning and knowledge-intensive tasks.
|
||||
|
@ -57,6 +57,7 @@
|
||||
|
||||
## 🚀 最新进展 <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
|
||||
|
||||
- **\[2025.03.11\]** 现已支持 `SuperGPQA` 覆盖285 个研究生学科的知识能力评测,欢迎尝试!🔥🔥🔥
|
||||
- **\[2025.02.28\]** 我们为 `DeepSeek-R1` 系列模型添加了教程,请查看 [评估推理模型](docs/en/user_guides/deepseek_r1.md) 了解更多详情!🔥🔥🔥
|
||||
- **\[2025.02.15\]** 我们新增了两个实用的评测工具:用于LLM作为评判器的`GenericLLMEvaluator`和用于数学推理评估的`MATHEvaluator`。查看[LLM评判器](docs/zh_cn/advanced_guides/llm_judge.md)和[数学能力评测](docs/zh_cn/advanced_guides/general_math.md)文档了解更多详情!🔥🔥🔥
|
||||
- **\[2025.01.16\]** 我们现已支持 [InternLM3-8B-Instruct](https://huggingface.co/internlm/internlm3-8b-instruct) 模型,该模型在推理、知识类任务上取得同量级最优性能,欢迎尝试。
|
||||
|
@ -734,6 +734,8 @@
|
||||
category: Understanding
|
||||
paper: https://arxiv.org/pdf/1808.08745
|
||||
configpath: opencompass/configs/datasets/Xsum
|
||||
|
||||
|
||||
|
||||
- supergpqa:
|
||||
name: SuperGPQA
|
||||
category: Knowledge
|
||||
paper: https://arxiv.org/pdf/2502.14739
|
||||
configpath: opencompass/configs/datasets/supergpqa
|
||||
|
57
opencompass/configs/datasets/supergpqa/supergpqa_gen.py
Normal file
57
opencompass/configs/datasets/supergpqa/supergpqa_gen.py
Normal file
@ -0,0 +1,57 @@
|
||||
from opencompass.datasets.supergpqa.supergpqa import (
|
||||
SuperGPQADataset,
|
||||
SuperGPQAEvaluator,
|
||||
)
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
|
||||
|
||||
# Reader configuration
|
||||
reader_cfg = dict(
|
||||
input_columns=[
|
||||
'question',
|
||||
'options',
|
||||
'discipline',
|
||||
'field',
|
||||
'subfield',
|
||||
'difficulty',
|
||||
'infer_prompt',
|
||||
'prompt_mode',
|
||||
],
|
||||
output_column='answer_letter',
|
||||
)
|
||||
|
||||
# Inference configuration
|
||||
infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt='{infer_prompt}',
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
# Evaluation configuration
|
||||
eval_cfg = dict(
|
||||
evaluator=dict(type=SuperGPQAEvaluator),
|
||||
pred_role='BOT',
|
||||
)
|
||||
supergpqa_dataset = dict(
|
||||
type=SuperGPQADataset,
|
||||
abbr='supergpqa',
|
||||
path='m-a-p/SuperGPQA',
|
||||
prompt_mode='zero-shot',
|
||||
reader_cfg=reader_cfg,
|
||||
infer_cfg=infer_cfg,
|
||||
eval_cfg=eval_cfg,
|
||||
)
|
||||
|
||||
supergpqa_datasets = [supergpqa_dataset]
|
@ -127,6 +127,7 @@ from .strategyqa import * # noqa: F401, F403
|
||||
from .subjective import * # noqa: F401, F403
|
||||
from .summedits import * # noqa: F401, F403
|
||||
from .summscreen import * # noqa: F401, F403
|
||||
from .supergpqa import * # noqa: F401, F403
|
||||
from .svamp import * # noqa: F401, F403
|
||||
from .tabmwp import * # noqa: F401, F403
|
||||
from .taco import * # noqa: F401, F403
|
||||
|
0
opencompass/datasets/supergpqa/__init__.py
Normal file
0
opencompass/datasets/supergpqa/__init__.py
Normal file
184
opencompass/datasets/supergpqa/supergpqa.py
Normal file
184
opencompass/datasets/supergpqa/supergpqa.py
Normal file
@ -0,0 +1,184 @@
|
||||
import os
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
|
||||
from opencompass.datasets.supergpqa.supergpqa_eval import (
|
||||
extract_option_content, extract_option_labels)
|
||||
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
||||
from opencompass.utils import get_data_path
|
||||
|
||||
from ..base import BaseDataset
|
||||
|
||||
|
||||
def _parse(item, template, prompt_mode):
|
||||
prompt_format = [
|
||||
item['question'] + '\n' + '\n'.join([
|
||||
f'{chr(65+i)}) {option}'
|
||||
for i, option in enumerate(item['options'])
|
||||
])
|
||||
]
|
||||
item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
|
||||
item['prompt_mode'] = prompt_mode
|
||||
return item
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class SuperGPQADataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, prompt_mode: str, **kwargs):
|
||||
path = get_data_path(path, local_mode=True)
|
||||
dataset = load_dataset(path, split='train')
|
||||
|
||||
# get prompt template
|
||||
template_path = None
|
||||
if prompt_mode == 'zero-shot':
|
||||
template_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'supergpqa_dataset_config/prompt/zero-shot.yaml',
|
||||
)
|
||||
elif prompt_mode == 'five-shot':
|
||||
template_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'supergpqa_dataset_config/prompt/five-shot.yaml',
|
||||
)
|
||||
try:
|
||||
template = load_yaml(template_path)
|
||||
except FileNotFoundError:
|
||||
print(f'[ERROR] Missing prompt template: {template_path}')
|
||||
return Dataset.from_list([])
|
||||
|
||||
dataset = dataset.map(lambda item: _parse(item, template, prompt_mode))
|
||||
return dataset
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class SuperGPQAEvaluator(BaseEvaluator):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def score(self, predictions, references, test_set):
|
||||
mode = test_set[0]['prompt_mode']
|
||||
acc = 0
|
||||
count = 0
|
||||
err = 0
|
||||
miss = 0
|
||||
acc_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
|
||||
count_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
|
||||
stats = {'discipline': {}, 'field': {}, 'subfield': {}}
|
||||
details = []
|
||||
for i, sample in enumerate(test_set):
|
||||
sample['pred'] = prediction = predictions[i]
|
||||
gold = references[i]
|
||||
if mode == 'zero-shot':
|
||||
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
|
||||
if predict is None:
|
||||
predict = extract_option_content(prediction,
|
||||
sample['options'])
|
||||
predict = (chr(sample['options'].index(predict) +
|
||||
65) if predict else None)
|
||||
sample['extracted_answer'] = predict
|
||||
elif mode == 'five-shot':
|
||||
response = prediction.split('Question:')[0]
|
||||
predict = extract_option_labels(response, 'ABCDEFGHIJ')
|
||||
if predict is None:
|
||||
predict = extract_option_content(response,
|
||||
sample['options'])
|
||||
predict = (chr(sample['options'].index(predict) +
|
||||
65) if predict else None)
|
||||
if predict is None:
|
||||
predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
|
||||
if predict is None:
|
||||
predict = extract_option_content(
|
||||
prediction, sample['options'])
|
||||
predict = (chr(sample['options'].index(predict) +
|
||||
65) if predict else None)
|
||||
sample['extracted_answer'] = predict
|
||||
|
||||
discipline = sample.get('discipline', 'unknown')
|
||||
field = sample.get('field', 'unknown')
|
||||
subfield = sample.get('subfield', 'unknown')
|
||||
difficulty = sample.get('difficulty', 'unknown')
|
||||
|
||||
for level, key in [
|
||||
('discipline', discipline),
|
||||
# ('field', f"{discipline}/{field}"),
|
||||
# ('subfield', f"{discipline}/{field}/{subfield}"),
|
||||
]:
|
||||
if key not in stats[level]:
|
||||
stats[level][key] = {
|
||||
'correct': 0,
|
||||
'total': 0,
|
||||
'miss': 0,
|
||||
'error': 0,
|
||||
'discipline': discipline,
|
||||
'field': field,
|
||||
'subfield': subfield,
|
||||
'difficulty': {
|
||||
'easy': {
|
||||
'correct': 0,
|
||||
'total': 0
|
||||
},
|
||||
'middle': {
|
||||
'correct': 0,
|
||||
'total': 0
|
||||
},
|
||||
'hard': {
|
||||
'correct': 0,
|
||||
'total': 0
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stats[level][key]['total'] += 1
|
||||
stats[level][key]['difficulty'][difficulty]['total'] += 1
|
||||
|
||||
answer_letter = sample['answer_letter']
|
||||
assert answer_letter == gold
|
||||
if predict and answer_letter == predict:
|
||||
acc += 1
|
||||
acc_difficulty[difficulty] += 1
|
||||
sample['status'] = 'correct'
|
||||
stats[level][key]['correct'] += 1
|
||||
stats[level][key]['difficulty'][difficulty]['correct'] += 1
|
||||
elif predict is None or predict == '':
|
||||
miss += 1
|
||||
sample['status'] = 'miss'
|
||||
stats[level][key]['miss'] += 1
|
||||
elif predict == 'error':
|
||||
err += 1
|
||||
sample['status'] = 'error'
|
||||
stats[level][key]['error'] += 1
|
||||
else:
|
||||
sample['status'] = 'incorrect'
|
||||
count += 1
|
||||
count_difficulty[difficulty] += 1
|
||||
details.append({
|
||||
'pred': sample['pred'],
|
||||
'answer': sample['answer'],
|
||||
'parsed_answer': sample['extracted_answer'],
|
||||
'correct': True if sample['status'] else False,
|
||||
})
|
||||
|
||||
return {
|
||||
'accuracy':
|
||||
acc / count if count > 0 else 0,
|
||||
'error_rate':
|
||||
err / count if count > 0 else 0,
|
||||
'miss_rate':
|
||||
miss / count if count > 0 else 0,
|
||||
'hard_accuracy':
|
||||
(acc_difficulty['hard'] /
|
||||
count_difficulty['hard'] if count_difficulty['hard'] > 0 else 0),
|
||||
'middle_accuracy':
|
||||
(acc_difficulty['middle'] / count_difficulty['middle']
|
||||
if count_difficulty['middle'] > 0 else 0),
|
||||
'easy_accuracy':
|
||||
(acc_difficulty['easy'] /
|
||||
count_difficulty['easy'] if count_difficulty['easy'] > 0 else 0),
|
||||
'details':
|
||||
details,
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
response_key: 'response'
|
||||
error_key: 'error'
|
||||
id_key:
|
||||
- 'uuid'
|
||||
prompt_key: 'prompt'
|
||||
|
||||
|
||||
|
||||
history_key: 'history'
|
||||
status_key: 'status'
|
||||
|
||||
save_prompt: True
|
||||
max_tokens: 4096
|
||||
temperatrue: 0.0
|
||||
|
||||
max_rounds: 30
|
||||
BoN: 32
|
@ -0,0 +1,17 @@
|
||||
response_key: 'response'
|
||||
error_key: 'error'
|
||||
id_key:
|
||||
- 'uuid'
|
||||
prompt_key: 'prompt'
|
||||
|
||||
|
||||
|
||||
history_key: 'history'
|
||||
status_key: 'status'
|
||||
|
||||
save_prompt: True
|
||||
max_tokens: 32768
|
||||
temperatrue: 0.0
|
||||
|
||||
max_rounds: 30
|
||||
BoN: 32
|
@ -0,0 +1,88 @@
|
||||
import yaml
|
||||
|
||||
|
||||
class ConfigWrapper:
|
||||
|
||||
def __init__(self, config_path):
|
||||
self._config = {}
|
||||
with open(config_path, 'r') as file:
|
||||
self._config = yaml.safe_load(file)
|
||||
for key, value in self._config.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key.startswith('_'):
|
||||
super().__setattr__(key, value)
|
||||
else:
|
||||
self._config[key] = value
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self._config:
|
||||
return self._config[key]
|
||||
raise AttributeError(
|
||||
f"'ConfigWrapper' object has no attribute '{key}'")
|
||||
|
||||
def get_id(self, data):
|
||||
if isinstance(self._config.get('id_key'), str):
|
||||
return data.get(self._config.get('id_key'), None)
|
||||
elif isinstance(self._config.get('id_key'), list):
|
||||
return '_'.join([
|
||||
str(data[key]) for key in self._config.get('id_key')
|
||||
if key in data
|
||||
])
|
||||
|
||||
def print_all_keys(self):
|
||||
print('config keys:')
|
||||
for key, value in self._config.items():
|
||||
print(f' - {key}: {value}')
|
||||
|
||||
|
||||
config_wrapper = None
|
||||
|
||||
|
||||
def initialize_config(config_path):
|
||||
global config_wrapper
|
||||
config_wrapper = ConfigWrapper(config_path)
|
||||
|
||||
|
||||
def get_config_wrapper():
|
||||
global config_wrapper
|
||||
if config_wrapper is None:
|
||||
raise RuntimeError(
|
||||
'ConfigWrapper not initialized. Call initialize_config first.')
|
||||
return config_wrapper
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config_path = 'config/config.yaml'
|
||||
initialize_config(config_path)
|
||||
data = {
|
||||
'idx':
|
||||
'50',
|
||||
'step':
|
||||
21,
|
||||
'question':
|
||||
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
|
||||
'Please provide the decrypted answer, encapsulated in double square'
|
||||
' brackets. For example, the format should be: [[decrypted answer]].',
|
||||
'answer':
|
||||
'[[P]]',
|
||||
'category':
|
||||
'Decryption',
|
||||
'rule_id':
|
||||
'23',
|
||||
'input':
|
||||
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
|
||||
'steps_num':
|
||||
23,
|
||||
'description':
|
||||
'For a number c=228 in the ciphertext:\n'
|
||||
'Calculate z = c^e mod n. Here ^ means multiplication.\nz is 80.'
|
||||
'\nBased on the decimal number represented by z, use the ascii '
|
||||
'code to find the corresponding letter as the plaintext letter p.'
|
||||
'\nPlease give the letter p in [[...]] format.\n',
|
||||
'atom':
|
||||
80,
|
||||
}
|
||||
print(config_wrapper.get_id(data))
|
@ -0,0 +1,91 @@
|
||||
prompt_format:
|
||||
- |
|
||||
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
|
||||
|
||||
Question:
|
||||
A refracting telescope consists of two converging lenses separated by 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification of the telescope is
|
||||
A) 10
|
||||
B) 40
|
||||
C) 6
|
||||
D) 25
|
||||
E) 15
|
||||
F) 50
|
||||
G) 30
|
||||
H) 4
|
||||
I) 5
|
||||
J) 20
|
||||
|
||||
Answer: Let's think step by step. In a refracting telescope, if both lenses are converging, the focus of both lenses must be between the two lenses, and thus the focal lengths of the two lenses must add up to their separation. Since the focal length of one lens is 20 cm, the focal length of the other must be 80 cm. The magnification is the ratio of these two focal lengths, or 4.
|
||||
Answer: H.
|
||||
|
||||
Question:
|
||||
Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?
|
||||
A) 1000 times more
|
||||
B) 50 times more
|
||||
C) 5000 times more
|
||||
D) 500 times more
|
||||
E) 10000 times more
|
||||
F) 20000 times more
|
||||
G) 2000 times more
|
||||
H) 100 times more
|
||||
I) 10 times more
|
||||
J) N/A
|
||||
|
||||
Answer: Let's think step by step. The amount of light a telescope can gather compared to the human eye is proportional to the area of its apertures. The area of a circle is given by the formula $A = \pi \left(\frac{{D}}{{2}}\right)^2$, where $D$ is the diameter. Therefore, the relative light-gathering power is calculated as:
|
||||
\[
|
||||
\frac{{\left(\frac{{50 \text{{ cm}}}}{{2}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{2}}\right)^2}} = \frac{{\left(\frac{{50 \text{{ cm}}}}{{0.1 \text{{ cm}}}}\right)^2}}{{\left(\frac{{5 \text{{ mm}}}}{{0.1 \text{{ cm}}}}\right)^2}} = \frac{{500^2}}{{5^2}} = 10000.
|
||||
\]
|
||||
Answer: E.
|
||||
|
||||
Question:
|
||||
Where do most short-period comets come from and how do we know?
|
||||
A) The Kuiper belt; short period comets tend to be in the plane of the solar system like the Kuiper belt.
|
||||
B) The asteroid belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the asteroid belt.
|
||||
C) The asteroid belt; short period comets tend to be in the plane of the solar system just like the asteroid belt.
|
||||
D) The Oort cloud; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the Oort cloud.
|
||||
E) The Oort Cloud; short period comets tend to come from random directions indicating a spherical distribution of comets called the Oort Cloud.
|
||||
F) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.
|
||||
G) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt.
|
||||
Answer: Let's think step by step. Most short-period comets originate from the Kuiper belt. This is deduced from the observation that these comets tend to follow orbits that lie in the plane of the solar system, similar to the distribution of objects in the Kuiper belt itself. Thus, the alignment of these cometary orbits with the ecliptic plane points to their Kuiper belt origin.
|
||||
Answer: A.
|
||||
|
||||
Question:
|
||||
Colors in a soap bubble result from light
|
||||
A) dispersion
|
||||
B) deflection
|
||||
C) refraction
|
||||
D) reflection
|
||||
E) interference
|
||||
F) converted to a different frequency
|
||||
G) polarization
|
||||
H) absorption
|
||||
I) diffraction
|
||||
J) transmission
|
||||
|
||||
Answer: Let's think step by step. The colorful patterns observed in a soap bubble are caused by the phenomenon of light interference. This occurs when light waves bounce between the two surfaces of the soap film, combining constructively or destructively based on their phase differences and the varying thickness of the film. These interactions result in vibrant color patterns due to variations in the intensity of different wavelengths of light.
|
||||
Answer: E.
|
||||
|
||||
Question:
|
||||
A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?
|
||||
A) 240 W
|
||||
B) 120 W
|
||||
C) 10 W
|
||||
D) 480 W
|
||||
E) 360 W
|
||||
F) 200 W
|
||||
G) 30 W
|
||||
H) 150 W
|
||||
I) 60 W
|
||||
J) 300 W
|
||||
|
||||
Answer: Let's think step by step. The rate of energy usage, known as power, in an electrical circuit is calculated by the product of voltage and current. For a microwave oven connected to a 120 V outlet and drawing a current of 2 amps, the power consumption can be calculated as follows:
|
||||
\[
|
||||
\text{{Power}} = \text{{Voltage}} \times \text{{Current}} = 120 \, \text{{V}} \times 2 \, \text{{A}} = 240 \, \text{{W}}.
|
||||
\]
|
||||
Therefore, the microwave oven uses energy at a rate of 240 watts.
|
||||
Answer: A.
|
||||
|
||||
Question:
|
||||
{}
|
||||
|
||||
Answer: Let's think step by step.
|
@ -0,0 +1,23 @@
|
||||
initial_prompt_0:
|
||||
- |
|
||||
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
|
||||
|
||||
{}
|
||||
|
||||
initial_prompt_1:
|
||||
- |
|
||||
You are a helpful assistant. Answer the given multiple-choice question. Only one option is correct. The last line of your response should be in the format 'The correct answer is: $LETTER', where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
|
||||
|
||||
{}
|
||||
|
||||
initial_prompt_2:
|
||||
- |
|
||||
Select the correct answer for the following multiple-choice question. There is only one valid choice. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
|
||||
|
||||
{}
|
||||
|
||||
initial_prompt_3:
|
||||
- |
|
||||
Review the following multiple-choice question and choose the one correct answer. Ensure that your response concludes with a line exactly formatted as 'The correct answer is: $LETTER', where LETTER represents one of A, B, C, D, E, F, G, H, I, or J.
|
||||
|
||||
{}
|
@ -0,0 +1,5 @@
|
||||
prompt_format:
|
||||
- |
|
||||
Answer the following multiple choice question about {}. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
|
||||
|
||||
{}
|
@ -0,0 +1,5 @@
|
||||
prompt_format:
|
||||
- |
|
||||
Answer the following multiple choice question. There is only one correct answer. The last line of your response should be in the format 'Answer: $LETTER' (without quotes), where LETTER is one of A, B, C, D, E, F, G, H, I, or J.
|
||||
|
||||
{}
|
96
opencompass/datasets/supergpqa/supergpqa_eval.py
Normal file
96
opencompass/datasets/supergpqa/supergpqa_eval.py
Normal file
@ -0,0 +1,96 @@
|
||||
# flake8: noqa: W605
|
||||
import re
|
||||
|
||||
import timeout_decorator
|
||||
|
||||
|
||||
@timeout_decorator.timeout(5) # 5 seconds timeout
|
||||
def safe_regex_search(pattern, text, flags=0):
|
||||
try:
|
||||
return re.search(pattern, text, flags)
|
||||
except timeout_decorator.TimeoutError:
|
||||
print(f'Regex match timeout: pattern={pattern}, text={text[:100]}...')
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f'Regex match error: {str(e)}')
|
||||
return None
|
||||
|
||||
|
||||
def extract_option_labels(text, options='ABCDEFGHIJ'):
|
||||
if not isinstance(text, str) or not isinstance(options, str):
|
||||
return 'error'
|
||||
|
||||
text = text.rstrip()
|
||||
last_line = text.split('\n')[-1]
|
||||
|
||||
option_str = ''.join([chr(65 + i) for i in range(len(options))
|
||||
]) if options else 'ABCDEFGHIJ'
|
||||
|
||||
patterns = [
|
||||
# e.g. "The final answer to this question is: A."
|
||||
# "The best option is $\boxed{B}:"
|
||||
# "The correct answer is (C)."
|
||||
f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is?:?\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
|
||||
|
||||
# e.g. "ANSWER: A"
|
||||
# "Answer: $\boxed{B}."
|
||||
# "ANSWER: (C):"
|
||||
f'(?i:Answer)[\*\s]*:\s*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
|
||||
|
||||
# e.g. "A"
|
||||
# "$\boxed{B}$"
|
||||
# "(C)."
|
||||
# "[D]:"
|
||||
f'^[^\w\r\n]*(?:[\*\$\\{{(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*([{option_str}])(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = safe_regex_search(pattern, last_line, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
for pattern in patterns:
|
||||
match = safe_regex_search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_option_content(text, options_content=None):
|
||||
if not isinstance(text, str) or not isinstance(options_content, list):
|
||||
return 'error'
|
||||
|
||||
escaped_options_content = [
|
||||
re.escape(option_content) for option_content in options_content
|
||||
]
|
||||
escaped_options_content_str = '|'.join(escaped_options_content)
|
||||
|
||||
text = text.rstrip()
|
||||
last_line = text.split('\n')[-1]
|
||||
|
||||
patterns = [
|
||||
f'[Tt]he\s+(?:\w+\s+)?(?:answer|option)(?:\w+\s+)?\s+is:?\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
|
||||
f'(?i:Answer)\s*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
|
||||
f'^[^\w\r\n]*(?:[\*\$\\{{\(\[\\\\(]*?(?:(?:\\\\boxed|\\\\mathbf|\\\\mathrm|\\\\text){{)?)*\s*({escaped_options_content_str})(?:\\\\?\}}?\$?\)?\]?\}}?)*(?:[\s:\.\*)]|$)',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = safe_regex_search(pattern, last_line)
|
||||
if match:
|
||||
if match.group(1) in escaped_options_content:
|
||||
return options_content[escaped_options_content.index(
|
||||
match.group(1))]
|
||||
else:
|
||||
return match.group(1)
|
||||
|
||||
for pattern in patterns:
|
||||
match = safe_regex_search(pattern, text)
|
||||
if match:
|
||||
if match.group(1) in escaped_options_content:
|
||||
return options_content[escaped_options_content.index(
|
||||
match.group(1))]
|
||||
else:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
693
opencompass/datasets/supergpqa/supergpqa_utils.py
Normal file
693
opencompass/datasets/supergpqa/supergpqa_utils.py
Normal file
@ -0,0 +1,693 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import sympy as sp
|
||||
import yaml
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
|
||||
def load_yaml(yaml_path):
|
||||
"""Load a YAML file."""
|
||||
if not os.path.exists(yaml_path):
|
||||
raise FileNotFoundError(f'YAML file not found: {yaml_path}')
|
||||
with open(yaml_path, 'r', encoding='utf-8') as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
|
||||
def load_json_or_jsonl(file_path):
|
||||
"""Load data from a JSON or JSONL file."""
|
||||
if not os.path.exists(file_path):
|
||||
return None
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
if file_path.endswith('.json'):
|
||||
return json.load(file)
|
||||
elif file_path.endswith('.jsonl'):
|
||||
return [json.loads(line) for line in file]
|
||||
return None
|
||||
|
||||
|
||||
def find_file(base_path, sub_path, extensions=('json', 'jsonl')):
|
||||
"""Find the first available file with given extensions."""
|
||||
for ext in extensions:
|
||||
file_path = os.path.join(base_path, f'{sub_path}.{ext}')
|
||||
if os.path.exists(file_path):
|
||||
return file_path
|
||||
return None
|
||||
|
||||
|
||||
def load_json_or_jsonl_with_idx(data_path, split='', idx=None):
|
||||
base_path = os.path.join(data_path, split)
|
||||
if os.path.exists(f'{base_path}.json'):
|
||||
file_path = f'{base_path}.json'
|
||||
elif os.path.exists(f'{base_path}.jsonl'):
|
||||
file_path = f'{base_path}.jsonl'
|
||||
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
|
||||
file_path = base_path
|
||||
else:
|
||||
raise FileNotFoundError('No JSON or JSONL file found.')
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
if file_path.endswith('.json'):
|
||||
data = json.load(file)
|
||||
elif file_path.endswith('.jsonl'):
|
||||
data = [json.loads(line) for line in file]
|
||||
|
||||
if idx is not None:
|
||||
try:
|
||||
return next(item for item in data if item.get('idx') == idx)
|
||||
except StopIteration:
|
||||
raise ValueError(f'No entry found for idx {idx}')
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def load_split_data(base_path, split_name):
|
||||
"""Load the rule and sample data for a specific split."""
|
||||
split_path = os.path.join(base_path, split_name)
|
||||
rule_path = find_file(split_path, 'rule')
|
||||
sample_path = find_file(split_path, 'sample')
|
||||
|
||||
rules = load_json_or_jsonl(rule_path) if rule_path else []
|
||||
samples = load_json_or_jsonl(sample_path) if sample_path else []
|
||||
|
||||
return {'rules': rules, 'samples': samples}
|
||||
|
||||
|
||||
def process_mixed_data(base_path, mode):
|
||||
"""Load and process data for the 'mixed' split and specific mode."""
|
||||
mixed_path = os.path.join(base_path, 'mixed')
|
||||
file_path = find_file(mixed_path, mode)
|
||||
if not file_path:
|
||||
print(f'[WARNING] Missing file for mixed mode: {mode}')
|
||||
return []
|
||||
|
||||
data = load_json_or_jsonl(file_path)
|
||||
template_path = os.path.join(base_path, 'config/prompt/mixed.yaml')
|
||||
template = load_yaml(template_path)
|
||||
|
||||
processed = []
|
||||
for item in data:
|
||||
rules = '\n'.join(item.get('rule_list', []))
|
||||
questions = '\n'.join(item.get('question_list', []))
|
||||
item['prompt'] = template['prompt_format'][0].format(rules, questions)
|
||||
processed.append(item)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
class ConfigWrapper:
|
||||
|
||||
def __init__(self, config_path):
|
||||
self._config = {}
|
||||
with open(config_path, 'r') as file:
|
||||
self._config = yaml.safe_load(file)
|
||||
for key, value in self._config.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key.startswith('_'):
|
||||
super().__setattr__(key, value)
|
||||
else:
|
||||
self._config[key] = value
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self._config:
|
||||
return self._config[key]
|
||||
raise AttributeError(
|
||||
f"'ConfigWrapper' object has no attribute '{key}'")
|
||||
|
||||
def get_id(self, data):
|
||||
if isinstance(self._config.get('id_key'), str):
|
||||
return data.get(self._config.get('id_key'), None)
|
||||
elif isinstance(self._config.get('id_key'), list):
|
||||
return '_'.join([
|
||||
str(data[key]) for key in self._config.get('id_key')
|
||||
if key in data
|
||||
])
|
||||
|
||||
def print_all_keys(self):
|
||||
print('config keys:')
|
||||
for key, value in self._config.items():
|
||||
print(f' - {key}: {value}')
|
||||
|
||||
|
||||
config_wrapper = None
|
||||
|
||||
|
||||
def initialize_config(config_path):
|
||||
global config_wrapper
|
||||
config_wrapper = ConfigWrapper(config_path)
|
||||
|
||||
|
||||
def get_config_wrapper():
|
||||
global config_wrapper
|
||||
if config_wrapper is None:
|
||||
raise RuntimeError(
|
||||
'ConfigWrapper not initialized. Call initialize_config first.')
|
||||
return config_wrapper
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config_path = 'config/config.yaml'
|
||||
initialize_config(config_path)
|
||||
data = {
|
||||
'idx':
|
||||
'50',
|
||||
'step':
|
||||
21,
|
||||
'question':
|
||||
('Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
|
||||
'Please provide the decrypted answer, encapsulated in double '
|
||||
'square brackets. '
|
||||
'For example, the format should be: [[decrypted answer]].'),
|
||||
'answer':
|
||||
'[[P]]',
|
||||
'category':
|
||||
'Decryption',
|
||||
'rule_id':
|
||||
'23',
|
||||
'input':
|
||||
'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
|
||||
'steps_num':
|
||||
23,
|
||||
'description':
|
||||
('For a number c=228 in the ciphertext:\n'
|
||||
'Calculate z = c^e mod n. Here ^ means multiplication.\n'
|
||||
'z is 80.\nBased on the decimal number represented by z, '
|
||||
'use the ascii code to find the corresponding letter '
|
||||
'as the plaintext letter p.\n'
|
||||
'Please give the letter p in [[...]] format.\n'),
|
||||
'atom':
|
||||
80
|
||||
}
|
||||
print(config_wrapper.get_id(data))
|
||||
|
||||
|
||||
def read_yaml(config='default'):
|
||||
if os.path.exists(f'config/prompt/{config}.yaml'):
|
||||
yaml_file = f'config/prompt/{config}.yaml'
|
||||
else:
|
||||
yaml_file = config
|
||||
with open(yaml_file, 'r') as yaml_file:
|
||||
return yaml.safe_load(yaml_file)
|
||||
|
||||
|
||||
def write_jsonl_lines(file, data):
|
||||
config_wrapper = get_config_wrapper()
|
||||
if config_wrapper.save_prompt:
|
||||
json.dump(data, file, ensure_ascii=False)
|
||||
else:
|
||||
data.pop(config_wrapper.prompt_key)
|
||||
json.dump(data, file, ensure_ascii=False)
|
||||
file.write('\n')
|
||||
file.flush()
|
||||
|
||||
|
||||
def print_info(info):
|
||||
print('-' * 100)
|
||||
print('[INFO] model_name:', info['model_name'])
|
||||
print('[INFO] splits:', info['splits'])
|
||||
print('[INFO] modes:', info['modes'])
|
||||
print('[INFO] output_dir:', info['output_dir'])
|
||||
print('[INFO] Infer Limit:',
|
||||
'No limit' if info['infer_limit'] is None else info['infer_limit'])
|
||||
print('[INFO] Number of Workers:', info['num_workers'])
|
||||
print('[INFO] Batch Size:', info['batch_size'])
|
||||
print('[INFO] Use Accel:', info['use_accel'])
|
||||
print('-' * 100)
|
||||
|
||||
|
||||
def read_json_or_jsonl(data_path, split='', mapping_key=None):
|
||||
base_path = os.path.join(data_path, split)
|
||||
if os.path.exists(f'{base_path}.json'):
|
||||
file_path = f'{base_path}.json'
|
||||
elif os.path.exists(f'{base_path}.jsonl'):
|
||||
file_path = f'{base_path}.jsonl'
|
||||
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
|
||||
file_path = base_path
|
||||
else:
|
||||
raise FileNotFoundError('No JSON or JSONL file found.')
|
||||
|
||||
with open(file_path, 'r') as file:
|
||||
if file_path.endswith('.json'):
|
||||
data = json.load(file)
|
||||
elif file_path.endswith('.jsonl'):
|
||||
data = [json.loads(line) for line in file]
|
||||
|
||||
if mapping_key:
|
||||
return {
|
||||
item[mapping_key]: item
|
||||
for item in data if mapping_key in item
|
||||
}
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def read_json_or_jsonl_with_idx(data_path, split='', idx=None):
|
||||
base_path = os.path.join(data_path, split)
|
||||
if os.path.exists(f'{base_path}.json'):
|
||||
file_path = f'{base_path}.json'
|
||||
elif os.path.exists(f'{base_path}.jsonl'):
|
||||
file_path = f'{base_path}.jsonl'
|
||||
elif base_path.endswith('.json') or base_path.endswith('.jsonl'):
|
||||
file_path = base_path
|
||||
else:
|
||||
raise FileNotFoundError('No JSON or JSONL file found.')
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
if file_path.endswith('.json'):
|
||||
data = json.load(file)
|
||||
elif file_path.endswith('.jsonl'):
|
||||
data = [json.loads(line) for line in file]
|
||||
|
||||
if idx is not None:
|
||||
try:
|
||||
return next(item for item in data if item.get('idx') == idx)
|
||||
except StopIteration:
|
||||
raise ValueError(f'No entry found for idx {idx}')
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
idx_ranges = [
|
||||
[18],
|
||||
[73, 74, 77],
|
||||
[94],
|
||||
[115, 116, 117],
|
||||
[121, 122, 123, 125],
|
||||
[131, 132, 134, 135, 136],
|
||||
[141, 143, 149],
|
||||
list(range(145, 148)),
|
||||
list(range(151, 157)),
|
||||
[160, 161, 162],
|
||||
[164, 165, 166],
|
||||
[170],
|
||||
[206, 209],
|
||||
list(range(211, 216)),
|
||||
[217, 218],
|
||||
]
|
||||
|
||||
|
||||
def clean_json_string(json_str):
|
||||
json_str = re.sub(r'[\x00-\x1F\x7F]', '', json_str)
|
||||
return json_str
|
||||
|
||||
|
||||
def is_in_idx_ranges(idx, idx_ranges):
|
||||
for range_list in idx_ranges:
|
||||
if int(idx) in range_list:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def extract_json(text):
|
||||
matches = re.findall(r'{.*}', text, re.DOTALL)
|
||||
if matches:
|
||||
json_str = matches[-1]
|
||||
json_str = clean_json_string(json_str)
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
return data
|
||||
except json.JSONDecodeError as e:
|
||||
print(f'Error decoding JSON: {e}')
|
||||
return 'NULL'
|
||||
return 'NULL'
|
||||
|
||||
|
||||
def extract_all_responses_from_json(response_json):
|
||||
results = []
|
||||
for key, value in response_json.items():
|
||||
results.append(str(value))
|
||||
return results
|
||||
|
||||
|
||||
def clean_latex(latex_expr):
|
||||
if '=' in latex_expr:
|
||||
latex_expr = latex_expr.rsplit('=', 1)[1]
|
||||
latex_expr = re.sub(r'\\[()\[\]]', '', latex_expr)
|
||||
latex_expr = re.sub(r'\\text\{.*?\}', '', latex_expr)
|
||||
latex_expr = re.sub(r'\\(left|right|displaystyle)', '', latex_expr)
|
||||
latex_expr = latex_expr.replace('\\\\', '\\')
|
||||
return latex_expr
|
||||
|
||||
|
||||
def extract_text_from_brackets(text, clean_level='basic'):
|
||||
matches = re.findall(r'\[\[\s*(.*?)\s*\]\]', text, re.DOTALL)
|
||||
if not matches:
|
||||
matches = re.findall(r'\$\\boxed\{(.*?)\}\$', text, re.DOTALL)
|
||||
if not matches:
|
||||
matches = re.findall(r'\[\s*(.*?)\s*\]', text, re.DOTALL)
|
||||
if matches:
|
||||
match_str = matches[0].strip()
|
||||
if clean_level == 'clean':
|
||||
match_str = match_str.replace('"', '').replace('\n', '').replace(
|
||||
' ', '').replace('[', '').replace(']', '')
|
||||
elif clean_level == 'logic':
|
||||
match_str = match_str.replace('"', '').replace('\n', '').replace(
|
||||
' ', '').replace('.', '')
|
||||
elif clean_level == 'math':
|
||||
match_str = match_str.replace('"', '').replace('\n', '').replace(
|
||||
'[', '').replace(']', '').replace('$', '')
|
||||
return f'{clean_latex(match_str)}'
|
||||
return f'[[{match_str}]]'
|
||||
return 'NULL'
|
||||
|
||||
|
||||
def extract_inner_text_from_brackets(text):
|
||||
if not isinstance(text, str):
|
||||
print(f'text type: {type(text)}, text value: {text}')
|
||||
return 'NULL'
|
||||
match = re.search(r'\[\[(.*?)\]\]', text, re.DOTALL)
|
||||
return match.group(1) if match else 'NULL'
|
||||
|
||||
|
||||
def extract_numbers(str):
|
||||
numbers = re.findall(r'\d+', str)
|
||||
numbers = list(map(int, numbers))
|
||||
return numbers
|
||||
|
||||
|
||||
def extract_and_sort_inequalities(latex_expr):
|
||||
pattern = r'(≥|≤)\s*([-]?\d+\.?\d*)'
|
||||
matches = re.findall(pattern, latex_expr)
|
||||
extracted_inequalities = [''.join(match) for match in matches]
|
||||
sorted_inequalities = sorted(extracted_inequalities)
|
||||
return sorted_inequalities
|
||||
|
||||
|
||||
def rule5_normalize_content(content):
|
||||
parts = [part for part in content.split(';')]
|
||||
sorted_parts = sorted(parts)
|
||||
return sorted_parts
|
||||
|
||||
|
||||
def normalize_string(s):
|
||||
s = re.sub(r'[^0-9]', '', s)
|
||||
pairs = s.split(',')
|
||||
pairs.sort()
|
||||
return pairs
|
||||
|
||||
|
||||
def remove_commas_and_spaces(s):
|
||||
return re.sub(r'[,\s\[\]]+', '', s)
|
||||
|
||||
|
||||
def remove_non_alphanumeric(s):
|
||||
return re.sub(r'\W+', '', s)
|
||||
|
||||
|
||||
def contains_or(answer):
|
||||
return 'or' in answer
|
||||
|
||||
|
||||
def compare_multi_results(response, answer):
|
||||
try:
|
||||
response_text = extract_text_from_brackets(response, 'clean')
|
||||
response_text = re.sub(r'\\text\{or\}', 'or', response_text)
|
||||
if response_text == 'NULL':
|
||||
return False
|
||||
answer = extract_text_from_brackets(answer, 'clean')
|
||||
response_split = response_text.strip('[[]]').split('or')
|
||||
answer_split = answer.strip('[[]]').split('or')
|
||||
response_sorted = sorted([x.strip() for x in response_split])
|
||||
answer_sorted = sorted([x.strip() for x in answer_split])
|
||||
return response_sorted == answer_sorted
|
||||
except Exception as e:
|
||||
print(f'Error during comparison: {e}')
|
||||
return False
|
||||
|
||||
|
||||
def split_or_expression(expression):
|
||||
return [part.strip() for part in expression.split('or')]
|
||||
|
||||
|
||||
def compare_math_expressions(response, answer):
|
||||
response_text = extract_text_from_brackets(response, 'math')
|
||||
answer_text = extract_text_from_brackets(answer, 'math')
|
||||
if response_text == 'NULL':
|
||||
return False
|
||||
if contains_or(answer_text):
|
||||
response_parts = split_or_expression(response_text)
|
||||
answer_parts = split_or_expression(answer_text)
|
||||
try:
|
||||
response_exprs = {
|
||||
sp.simplify(parse_latex(part))
|
||||
for part in response_parts
|
||||
}
|
||||
answer_exprs = {
|
||||
sp.simplify(parse_latex(part))
|
||||
for part in answer_parts
|
||||
}
|
||||
return response_exprs == answer_exprs
|
||||
except Exception as e:
|
||||
print(f'Error during simplification or parsing: {e}')
|
||||
return response_text == answer_text
|
||||
else:
|
||||
try:
|
||||
response_expr = sp.simplify(parse_latex(response_text))
|
||||
answer_expr = sp.simplify(parse_latex(answer_text))
|
||||
return response_expr == answer_expr
|
||||
except Exception as e:
|
||||
print(f'Error during simplification or parsing: {e}')
|
||||
return response_text == answer_text
|
||||
|
||||
|
||||
def method_equal(response_text, answer):
|
||||
return response_text == answer
|
||||
|
||||
|
||||
def method_1(response_text, answer):
|
||||
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
|
||||
cleaned_string = cleaned_string.lower()
|
||||
answer = re.sub(r'[^A-Za-z]', '', answer)
|
||||
answer = answer.lower()
|
||||
return cleaned_string == answer
|
||||
|
||||
|
||||
def method_2(response_text, answer):
|
||||
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
|
||||
cleaned_string = cleaned_string.lower()
|
||||
answer = answer.split(',')
|
||||
return cleaned_string in answer
|
||||
|
||||
|
||||
def method_3(response_text, answer):
|
||||
response_text = response_text.lower()
|
||||
pairs1 = re.split(r'\W+', response_text)
|
||||
pairs2 = answer.split(' ')
|
||||
pairs1 = [word for word in pairs1 if word]
|
||||
pairs1.sort()
|
||||
pairs2.sort()
|
||||
return pairs1 == pairs2
|
||||
|
||||
|
||||
def method_4(response_text, answer):
|
||||
cleaned_string = re.sub(r'[^A-Za-z]', '', response_text)
|
||||
cleaned_string = cleaned_string.lower()
|
||||
return cleaned_string in answer
|
||||
|
||||
|
||||
def method_5(response_text, answer):
|
||||
response_text = re.sub(r'\s+', '', response_text)
|
||||
response_text = response_text.split(',')
|
||||
answer = answer.split(',')
|
||||
response_text.sort()
|
||||
answer.sort()
|
||||
return response_text == answer
|
||||
|
||||
|
||||
def method_9(response_text, answer):
|
||||
response_text = response_text.replace('×', '*').replace('−', '-')
|
||||
answer = answer.replace('×', '*').replace('−', '-')
|
||||
|
||||
def extract_operators(s):
|
||||
return re.findall(r'[+\-*/]', s)
|
||||
|
||||
response_ops = extract_operators(response_text.split('=')[0])
|
||||
answer_ops = extract_operators(answer.split('=')[0])
|
||||
if response_ops != answer_ops:
|
||||
return False
|
||||
match = re.search(r'=\s*(-?\d+)', answer)
|
||||
expected_result = int(match.group(1))
|
||||
try:
|
||||
left_side = response_text.split('=')[0]
|
||||
result = eval(left_side)
|
||||
except Exception as e:
|
||||
print(f'Error during evaluation: {e}')
|
||||
return False
|
||||
return result == expected_result
|
||||
|
||||
|
||||
def method_10(response_text, answer):
|
||||
response_text = response_text.replace('×', '*').replace('−', '-')
|
||||
response_text = response_text.split('=')[0]
|
||||
answer = answer.split('\n')[0].split('=')[0]
|
||||
response_ops = sorted(remove_non_alphanumeric(response_text))
|
||||
answer_ops = sorted(remove_non_alphanumeric(answer))
|
||||
if response_ops != answer_ops:
|
||||
return False
|
||||
try:
|
||||
result = eval(response_text)
|
||||
except Exception as e:
|
||||
print(f'Error during evaluation: {e}')
|
||||
return False
|
||||
return result == 24
|
||||
|
||||
|
||||
def method_18(response_text, answer):
|
||||
cleaned_s1 = remove_commas_and_spaces(response_text)
|
||||
cleaned_s2 = remove_commas_and_spaces(answer)
|
||||
return cleaned_s1 == cleaned_s2
|
||||
|
||||
|
||||
def method_general(response_text, answer):
|
||||
cleaned_s1 = remove_non_alphanumeric(response_text)
|
||||
cleaned_s2 = remove_non_alphanumeric(answer)
|
||||
return cleaned_s1 == cleaned_s2
|
||||
|
||||
|
||||
question_methods = {
|
||||
'1': method_1,
|
||||
'2': method_2,
|
||||
'3': method_3,
|
||||
'4': method_4,
|
||||
'5': method_5,
|
||||
'9': method_9,
|
||||
'10': method_10,
|
||||
'18': method_18,
|
||||
}
|
||||
|
||||
|
||||
def evaluate_response_vs_answer(response, answer, question_type, rule_id, idx):
|
||||
if question_type == 'logic' and rule_id == '5':
|
||||
response_text = extract_text_from_brackets(response, 'logic')
|
||||
answer_text = extract_text_from_brackets(answer, 'logic')
|
||||
if response_text is None:
|
||||
return False
|
||||
normalized_response = rule5_normalize_content(response_text)
|
||||
normalized_answer = rule5_normalize_content(answer)
|
||||
return normalized_response == normalized_answer
|
||||
elif question_type == 'logic':
|
||||
response_text = extract_text_from_brackets(response, 'logic')
|
||||
answer_text = extract_text_from_brackets(answer, 'logic')
|
||||
return response_text == answer_text
|
||||
elif question_type == 'operation' and (idx == '178' or idx == '179'):
|
||||
response_text = extract_text_from_brackets(response, 'clean')
|
||||
response_text = extract_and_sort_inequalities(response_text)
|
||||
answer_text = extract_and_sort_inequalities(answer)
|
||||
# print(response_text, answer_text)
|
||||
return response_text == answer_text
|
||||
elif question_type == 'operation' and rule_id == '18':
|
||||
response_text = extract_text_from_brackets(response, 'clean')
|
||||
answer = extract_inner_text_from_brackets(answer)
|
||||
response_text = ''.join(sorted(re.sub(r'\W+', '', response_text)))
|
||||
answer = ''.join(sorted(re.sub(r'\W+', '', answer)))
|
||||
return response_text == answer
|
||||
elif question_type == 'operation' and rule_id in {'23', '24', '25'}:
|
||||
response_text = extract_text_from_brackets(response, 'clean')
|
||||
if response_text is None:
|
||||
return False
|
||||
response_text = extract_numbers(response_text)
|
||||
answer_text = extract_numbers(answer)
|
||||
return response_text == answer_text
|
||||
elif question_type == 'operation' and is_in_idx_ranges(idx, idx_ranges):
|
||||
return compare_math_expressions(response, answer)
|
||||
elif question_type == 'operation' and contains_or(answer):
|
||||
return compare_multi_results(response, answer)
|
||||
elif question_type == 'puzzle':
|
||||
response_text = extract_inner_text_from_brackets(response)
|
||||
answer = extract_inner_text_from_brackets(answer)
|
||||
method = question_methods.get(rule_id)
|
||||
if method:
|
||||
return method(response_text, answer)
|
||||
return method_general(response_text, answer)
|
||||
else:
|
||||
response_text = extract_text_from_brackets(response, 'clean')
|
||||
return response_text == answer
|
||||
|
||||
|
||||
def compute_one_mixed_question_pass_rate(idx,
|
||||
question_list,
|
||||
response_json,
|
||||
base_path=None):
|
||||
if response_json == 'NULL':
|
||||
result_dict = {
|
||||
'idx': idx,
|
||||
'response': response_json,
|
||||
'details': None,
|
||||
'pass_rate': 0,
|
||||
'is_correct': False
|
||||
}
|
||||
return result_dict
|
||||
response_list = extract_all_responses_from_json(response_json)
|
||||
correct_num = 0
|
||||
results = []
|
||||
for q_idx, question in enumerate(question_list):
|
||||
category, question_idx = question.rsplit('_', 1)
|
||||
question_content = load_json_or_jsonl_with_idx(base_path,
|
||||
os.path.join(
|
||||
category, 'sample'),
|
||||
idx=question_idx)
|
||||
answer = question_content['answer']
|
||||
if q_idx >= len(response_list):
|
||||
break
|
||||
response = response_list[q_idx]
|
||||
response_text = extract_text_from_brackets(response)
|
||||
rule_id = question_content['rule_id']
|
||||
is_correct = evaluate_response_vs_answer(response, answer, category,
|
||||
rule_id, q_idx)
|
||||
if is_correct:
|
||||
correct_num += 1
|
||||
results.append({
|
||||
'question': question,
|
||||
'response_text': response_text,
|
||||
'answer': answer,
|
||||
'is_correct': is_correct
|
||||
})
|
||||
|
||||
pass_rate = correct_num / len(question_list)
|
||||
question_correct = pass_rate == 1.0
|
||||
result_dict = {
|
||||
'idx': idx,
|
||||
'response': response_json,
|
||||
'details': results,
|
||||
'pass_rate': pass_rate,
|
||||
'is_correct': question_correct
|
||||
}
|
||||
return result_dict
|
||||
|
||||
|
||||
def evaluate_responses(data, mode, base_path=None):
|
||||
results = []
|
||||
|
||||
# Iterate over the values of the dictionary (numerical keys)
|
||||
for key, record in data.items():
|
||||
idx = key # Use the dictionary key as the "idx"
|
||||
response = record.get('prediction', '')
|
||||
question_type = record.get('category', '')
|
||||
response_text = extract_text_from_brackets(response)
|
||||
answer = record.get('gold', '')
|
||||
rule_id = record.get('rule_id', '')
|
||||
is_correct = evaluate_response_vs_answer(response, answer,
|
||||
question_type, rule_id, idx)
|
||||
result_dict = {
|
||||
'idx': idx,
|
||||
'response': response,
|
||||
'response_text': response_text,
|
||||
'answer': answer,
|
||||
'is_correct': is_correct
|
||||
}
|
||||
if question_type == 'counterfactual':
|
||||
real_life_answer = record.get('real_life_answer', '')
|
||||
is_real_life = evaluate_response_vs_answer(response,
|
||||
real_life_answer,
|
||||
question_type, rule_id,
|
||||
idx)
|
||||
result_dict['real_life_answer'] = real_life_answer
|
||||
result_dict['is_real_life'] = is_real_life
|
||||
if question_type == 'cipher' and mode == 'subquestions':
|
||||
result_dict['type'] = record.get('type', '')
|
||||
results.append(result_dict)
|
||||
return results
|
@ -1,4 +1,5 @@
|
||||
"""Base Evaluator."""
|
||||
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Union
|
||||
@ -77,12 +78,17 @@ class BaseEvaluator:
|
||||
for metric in all_metrics:
|
||||
if metric in ['predictions', 'example_abbr']:
|
||||
continue
|
||||
g_passk_details[metric] = 100. * np.mean(
|
||||
g_passk_details[metric] = 100.0 * np.mean(
|
||||
[detail[metric] for detail in details])
|
||||
return g_passk_details
|
||||
|
||||
def evaluate(self, k: Union[int, List[int]], n: int,
|
||||
original_dataset: Dataset, **score_kwargs):
|
||||
def evaluate(
|
||||
self,
|
||||
k: Union[int, List[int]],
|
||||
n: int,
|
||||
original_dataset: Dataset,
|
||||
**score_kwargs,
|
||||
):
|
||||
real_size = len(original_dataset) // n
|
||||
all_details = []
|
||||
all_results = []
|
||||
@ -146,7 +152,7 @@ class BaseEvaluator:
|
||||
|
||||
if can_calculate and n > 1 and k > 1:
|
||||
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
for _k in ([k] if isinstance(k, int) else k):
|
||||
for _k in [k] if isinstance(k, int) else k:
|
||||
for threshold in thresholds:
|
||||
g_pass = compute_g_pass_at_k(n=n,
|
||||
c=c,
|
||||
@ -161,9 +167,31 @@ class BaseEvaluator:
|
||||
|
||||
if can_calculate and n > 1 and k > 1:
|
||||
eval_results.update(self.reduce(eval_details))
|
||||
|
||||
# Store eval_details in eval_results
|
||||
eval_results['details'] = eval_details
|
||||
|
||||
return eval_results
|
||||
# Process details to flatten the predictions
|
||||
for detail in eval_details:
|
||||
# Extract all prediction fields and flatten them
|
||||
flattened_predictions = {}
|
||||
for pred in detail['predictions']:
|
||||
for k, v in pred.items():
|
||||
if k not in flattened_predictions:
|
||||
flattened_predictions[k] = [v]
|
||||
else:
|
||||
flattened_predictions[k].append(v)
|
||||
|
||||
# Replace the predictions list with the flattened dictionary
|
||||
for k, v in flattened_predictions.items():
|
||||
detail[k] = v
|
||||
|
||||
# Remove the original predictions field
|
||||
detail.pop('predictions')
|
||||
return eval_results
|
||||
|
||||
# If there are no details, return an empty dictionary
|
||||
return {}
|
||||
|
||||
def score(self):
|
||||
raise NotImplementedError("Method hasn't been implemented yet")
|
||||
|
Loading…
Reference in New Issue
Block a user