mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
700 lines
23 KiB
Python
700 lines
23 KiB
Python
![]() |
import json
|
|||
|
import os
|
|||
|
import re
|
|||
|
|
|||
|
import sympy as sp
|
|||
|
import yaml
|
|||
|
from sympy.parsing.latex import parse_latex
|
|||
|
|
|||
|
|
|||
|
def load_yaml(yaml_path):
    """Parse the YAML document stored at ``yaml_path``.

    Args:
        yaml_path: Filesystem path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file contents.

    Raises:
        FileNotFoundError: If ``yaml_path`` does not exist.
    """
    if os.path.exists(yaml_path):
        with open(yaml_path, 'r', encoding='utf-8') as stream:
            return yaml.safe_load(stream)
    raise FileNotFoundError(f'YAML file not found: {yaml_path}')
|
|||
|
|
|||
|
|
|||
|
def load_json_or_jsonl(file_path):
    """Load data from a JSON or JSONL file.

    Args:
        file_path: Path to a ``.json`` or ``.jsonl`` file.

    Returns:
        The decoded JSON document for ``.json`` files, a list with one
        decoded object per non-blank line for ``.jsonl`` files, or ``None``
        when the file does not exist or has any other extension.
    """
    if not os.path.exists(file_path):
        return None
    with open(file_path, 'r', encoding='utf-8') as file:
        if file_path.endswith('.json'):
            return json.load(file)
        if file_path.endswith('.jsonl'):
            # Blank lines (e.g. a trailing empty line) are not JSON values;
            # skip them instead of letting json.loads raise.
            return [json.loads(line) for line in file if line.strip()]
    return None
|
|||
|
|
|||
|
|
|||
|
def find_file(base_path, sub_path, extensions=('json', 'jsonl')):
    """Return the first existing ``<base_path>/<sub_path>.<ext>`` path.

    Args:
        base_path: Directory to look in.
        sub_path: File name (without extension) relative to ``base_path``.
        extensions: Extensions to try, in priority order.

    Returns:
        The first candidate path that exists, or ``None`` if none does.
    """
    candidates = (os.path.join(base_path, f'{sub_path}.{suffix}')
                  for suffix in extensions)
    return next((path for path in candidates if os.path.exists(path)), None)
|
|||
|
|
|||
|
|
|||
|
def load_json_or_jsonl_with_idx(data_path, split='', idx=None):
    """Load a JSON/JSONL file and optionally pick the record with ``idx``.

    The file is resolved from ``os.path.join(data_path, split)``: first with
    a ``.json`` suffix appended, then ``.jsonl``, and finally as-is when the
    joined path already ends in one of those extensions.

    Args:
        data_path: Base directory (or path prefix) of the data.
        split: Path component joined onto ``data_path``.
        idx: When not ``None``, the value to match against each record's
            ``'idx'`` field.

    Returns:
        The whole decoded data set when ``idx`` is ``None``; otherwise the
        first record whose ``'idx'`` equals ``idx``.

    Raises:
        FileNotFoundError: If no matching JSON/JSONL file can be resolved.
        ValueError: If ``idx`` is given but no record carries it.
    """
    base_path = os.path.join(data_path, split)
    if os.path.exists(f'{base_path}.json'):
        file_path = f'{base_path}.json'
    elif os.path.exists(f'{base_path}.jsonl'):
        file_path = f'{base_path}.jsonl'
    elif base_path.endswith(('.json', '.jsonl')):
        file_path = base_path
    else:
        raise FileNotFoundError('No JSON or JSONL file found.')

    with open(file_path, 'r', encoding='utf-8') as file:
        if file_path.endswith('.jsonl'):
            # Skip blank lines; json.loads would raise on them.
            data = [json.loads(line) for line in file if line.strip()]
        else:
            data = json.load(file)

    if idx is None:
        return data
    for item in data:
        if item.get('idx') == idx:
            return item
    # Raised outside any except block so the traceback is not chained to an
    # internal StopIteration.
    raise ValueError(f'No entry found for idx {idx}')
|
|||
|
|
|||
|
|
|||
|
def load_split_data(base_path, split_name):
    """Load the rule and sample data for a specific split.

    Looks for ``rule`` and ``sample`` files (JSON or JSONL) inside
    ``<base_path>/<split_name>``.

    Args:
        base_path: Root directory of the data set.
        split_name: Name of the split sub-directory.

    Returns:
        A dict with keys ``'rules'`` and ``'samples'``; each value is the
        loaded file content, or an empty list when the file is missing.
    """
    split_dir = os.path.join(base_path, split_name)
    loaded = {}
    for key, stem in (('rules', 'rule'), ('samples', 'sample')):
        located = find_file(split_dir, stem)
        loaded[key] = load_json_or_jsonl(located) if located else []
    return loaded
|
|||
|
|
|||
|
|
|||
|
def process_mixed_data(base_path, mode):
    """Load the 'mixed' split for ``mode`` and attach a prompt to each item.

    The data file is ``<base_path>/mixed/<mode>.json[l]`` and the prompt
    template is read from ``<base_path>/config/prompt/mixed.yaml``; its first
    ``prompt_format`` entry is formatted with the newline-joined rule list
    and question list of every item.

    Args:
        base_path: Root directory holding the data and prompt config.
        mode: Name of the mixed-mode file (without extension).

    Returns:
        A list of the loaded items, each with a ``'prompt'`` key added in
        place; an empty list (after printing a warning) when the data file
        is missing.
    """
    source_file = find_file(os.path.join(base_path, 'mixed'), mode)
    if not source_file:
        print(f'[WARNING] Missing file for mixed mode: {mode}')
        return []

    records = load_json_or_jsonl(source_file)
    prompt_spec = load_yaml(
        os.path.join(base_path, 'config/prompt/mixed.yaml'))

    for record in records:
        rule_block = '\n'.join(record.get('rule_list', []))
        question_block = '\n'.join(record.get('question_list', []))
        record['prompt'] = prompt_spec['prompt_format'][0].format(
            rule_block, question_block)
    return list(records)
|
|||
|
|
|||
|
|
|||
|
class ConfigWrapper:
    """Attribute-style access to the settings of a YAML config file.

    Every top-level key of the YAML mapping is mirrored both in the private
    ``_config`` dict and as an instance attribute. Assigning a public
    attribute later keeps the two in sync.
    """

    def __init__(self, config_path):
        """Load ``config_path`` and expose its keys as attributes.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self._config = {}
        with open(config_path, 'r', encoding='utf-8') as file:
            # An empty YAML document parses to None; treat it as "no keys"
            # instead of failing on None.items().
            self._config = yaml.safe_load(file) or {}
        # Iterate over a snapshot because setattr writes back into _config.
        for key, value in list(self._config.items()):
            setattr(self, key, value)

    def __setattr__(self, key, value):
        """Set an attribute; public names are also recorded in ``_config``."""
        if not key.startswith('_'):
            self._config[key] = value
        super().__setattr__(key, value)

    def __getattr__(self, key):
        """Fall back to ``_config`` for names not found by normal lookup.

        Raises:
            AttributeError: If ``key`` is not a known config entry.
        """
        # Read _config through __dict__: going through attribute access
        # would re-enter __getattr__ and recurse without bound on instances
        # whose _config has not been set (e.g. created via __new__).
        config = self.__dict__.get('_config', {})
        if key in config:
            return config[key]
        raise AttributeError(
            f"'ConfigWrapper' object has no attribute '{key}'")

    def get_id(self, data):
        """Build the identifier of ``data`` from the configured ``id_key``.

        Args:
            data: Mapping describing one record.

        Returns:
            ``data[id_key]`` (or ``None`` if absent) when ``id_key`` is a
            string; the ``'_'``-joined string values of the listed keys that
            are present in ``data`` when ``id_key`` is a list; ``None`` for
            any other ``id_key`` configuration.
        """
        id_key = self._config.get('id_key')
        if isinstance(id_key, str):
            return data.get(id_key, None)
        if isinstance(id_key, list):
            return '_'.join([str(data[key]) for key in id_key if key in data])
        return None

    def print_all_keys(self):
        """Print every config key together with its value."""
        print('config keys:')
        for key, value in self._config.items():
            print(f'  - {key}: {value}')
|
|||
|
|
|||
|
|
|||
|
# Module-wide ConfigWrapper instance; stays None until initialize_config runs.
config_wrapper = None


def initialize_config(config_path):
    """Create the module-wide ConfigWrapper from ``config_path``.

    Args:
        config_path: Path to the YAML configuration file.
    """
    global config_wrapper
    config_wrapper = ConfigWrapper(config_path)


def get_config_wrapper():
    """Return the module-wide ConfigWrapper.

    Raises:
        RuntimeError: If ``initialize_config`` has not been called yet.
    """
    if config_wrapper is not None:
        return config_wrapper
    raise RuntimeError(
        'ConfigWrapper not initialized. Call initialize_config first.')
|
|||
|
|
|||
|
|
|||
|
# Manual smoke test: when the module is run as a script, load
# 'config/config.yaml' and print the id that ConfigWrapper.get_id derives
# for one hard-coded sample record.
# NOTE: this guard sits in the middle of the module, so the functions
# defined below it do not exist yet when it executes; it only relies on
# initialize_config and config_wrapper, which are defined above.
if __name__ == '__main__':
    config_path = 'config/config.yaml'
    initialize_config(config_path)
    # Example record used only to exercise get_id.
    data = {
        'idx':
        '50',
        'step':
        21,
        'question':
        ('Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"\n\n'
         'Please provide the decrypted answer, encapsulated in double '
         'square brackets. '
         'For example, the format should be: [[decrypted answer]].'),
        'answer':
        '[[P]]',
        'category':
        'Decryption',
        'rule_id':
        '23',
        'input':
        'Ciphertext: "17,156,4,54,213,17,23,84,228,54,281"',
        'steps_num':
        23,
        'description':
        ('For a number c=228 in the ciphertext:\n'
         'Calculate z = c^e mod n. Here ^ means multiplication.\n'
         'z is 80.\nBased on the decimal number represented by z, '
         'use the ascii code to find the corresponding letter '
         'as the plaintext letter p.\n'
         'Please give the letter p in [[...]] format.\n'),
        'atom':
        80
    }
    print(config_wrapper.get_id(data))
|
|||
|
|
|||
|
|
|||
|
def read_yaml(config='default'):
    """Load a prompt YAML file by name or by explicit path.

    Args:
        config: Either the stem of a file under ``config/prompt/`` (so that
            ``config/prompt/<config>.yaml`` exists) or, failing that, a path
            that is opened as given.

    Returns:
        The object produced by ``yaml.safe_load`` for the file contents.

    Raises:
        FileNotFoundError: If neither location can be opened.
    """
    named_path = f'config/prompt/{config}.yaml'
    yaml_path = named_path if os.path.exists(named_path) else config
    # Explicit UTF-8 keeps decoding independent of the platform default,
    # and the handle no longer rebinds the path variable.
    with open(yaml_path, 'r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)
|
|||
|
|
|||
|
|
|||
|
def write_jsonl_lines(file, data):
    """Append ``data`` to ``file`` as one JSON line and flush.

    When the active config has ``save_prompt`` disabled, the entry stored
    under the configured ``prompt_key`` is removed before writing.

    Args:
        file: Writable text file object.
        data: JSON-serialisable dict. NOTE: the prompt entry is removed from
            this dict in place when prompts are not saved.

    Raises:
        RuntimeError: If the config has not been initialised.
    """
    config = get_config_wrapper()
    if not config.save_prompt:
        # A record without a prompt entry is written unchanged rather than
        # aborting the whole run with a KeyError.
        data.pop(config.prompt_key, None)
    json.dump(data, file, ensure_ascii=False)
    file.write('\n')
    file.flush()
|
|||
|
|
|||
|
|
|||
|
def print_info(info):
    """Print a banner summarising the run settings in ``info``.

    Args:
        info: Mapping with the keys ``model_name``, ``splits``, ``modes``,
            ``output_dir``, ``infer_limit``, ``num_workers``,
            ``batch_size`` and ``use_accel``. An ``infer_limit`` of ``None``
            is shown as ``No limit``.
    """
    divider = '-' * 100
    leading = (
        ('[INFO] model_name:', 'model_name'),
        ('[INFO] splits:', 'splits'),
        ('[INFO] modes:', 'modes'),
        ('[INFO] output_dir:', 'output_dir'),
    )
    trailing = (
        ('[INFO] Number of Workers:', 'num_workers'),
        ('[INFO] Batch Size:', 'batch_size'),
        ('[INFO] Use Accel:', 'use_accel'),
    )

    print(divider)
    for label, key in leading:
        print(label, info[key])
    limit = info['infer_limit']
    print('[INFO] Infer Limit:', 'No limit' if limit is None else limit)
    for label, key in trailing:
        print(label, info[key])
    print(divider)
|
|||
|
|
|||
|
|
|||
|
def read_json_or_jsonl(data_path, split='', mapping_key=None):
    """Read a JSON/JSONL file, optionally indexing records by a key.

    The file is resolved from ``os.path.join(data_path, split)``: first with
    a ``.json`` suffix appended, then ``.jsonl``, and finally as-is when the
    joined path already ends in one of those extensions.

    Args:
        data_path: Base directory (or path prefix) of the data.
        split: Path component joined onto ``data_path``.
        mapping_key: When truthy, return a dict mapping each record's
            ``mapping_key`` value to the record; records lacking the key
            are dropped.

    Returns:
        The decoded data, or the ``mapping_key``-indexed dict.

    Raises:
        FileNotFoundError: If no matching JSON/JSONL file can be resolved.
    """
    base_path = os.path.join(data_path, split)
    if os.path.exists(f'{base_path}.json'):
        file_path = f'{base_path}.json'
    elif os.path.exists(f'{base_path}.jsonl'):
        file_path = f'{base_path}.jsonl'
    elif base_path.endswith(('.json', '.jsonl')):
        file_path = base_path
    else:
        raise FileNotFoundError('No JSON or JSONL file found.')

    # Explicit UTF-8 so non-ASCII content decodes the same on every
    # platform, matching the other loaders in this module.
    with open(file_path, 'r', encoding='utf-8') as file:
        if file_path.endswith('.jsonl'):
            # Skip blank lines; json.loads would raise on them.
            data = [json.loads(line) for line in file if line.strip()]
        else:
            data = json.load(file)

    if mapping_key:
        return {
            item[mapping_key]: item
            for item in data if mapping_key in item
        }
    return data
|
|||
|
|
|||
|
|
|||
|
def read_json_or_jsonl_with_idx(data_path, split='', idx=None):
    """Read a JSON/JSONL file and optionally pick the record with ``idx``.

    The file is resolved from ``os.path.join(data_path, split)``: first with
    a ``.json`` suffix appended, then ``.jsonl``, and finally as-is when the
    joined path already ends in one of those extensions.

    Args:
        data_path: Base directory (or path prefix) of the data.
        split: Path component joined onto ``data_path``.
        idx: When not ``None``, the value to match against each record's
            ``'idx'`` field.

    Returns:
        The whole decoded data set when ``idx`` is ``None``; otherwise the
        first record whose ``'idx'`` equals ``idx``.

    Raises:
        FileNotFoundError: If no matching JSON/JSONL file can be resolved.
        ValueError: If ``idx`` is given but no record carries it.
    """
    base_path = os.path.join(data_path, split)
    if os.path.exists(f'{base_path}.json'):
        file_path = f'{base_path}.json'
    elif os.path.exists(f'{base_path}.jsonl'):
        file_path = f'{base_path}.jsonl'
    elif base_path.endswith(('.json', '.jsonl')):
        file_path = base_path
    else:
        raise FileNotFoundError('No JSON or JSONL file found.')

    with open(file_path, 'r', encoding='utf-8') as file:
        if file_path.endswith('.jsonl'):
            # Skip blank lines; json.loads would raise on them.
            data = [json.loads(line) for line in file if line.strip()]
        else:
            data = json.load(file)

    if idx is None:
        return data
    for item in data:
        if item.get('idx') == idx:
            return item
    # Raised outside any except block so the traceback is not chained to an
    # internal StopIteration.
    raise ValueError(f'No entry found for idx {idx}')
|
|||
|
|
|||
|
|
|||
|
# Groups of sample indices. evaluate_response_vs_answer passes this table to
# is_in_idx_ranges to decide which 'operation' questions are compared with
# compare_math_expressions.
idx_ranges = [
    [18],
    [73, 74, 77],
    [94],
    [115, 116, 117],
    [121, 122, 123, 125],
    [131, 132, 134, 135, 136],
    [141, 143, 149],
    [145, 146, 147],
    [151, 152, 153, 154, 155, 156],
    [160, 161, 162],
    [164, 165, 166],
    [170],
    [206, 209],
    [211, 212, 213, 214, 215],
    [217, 218],
]
|
|||
|
|
|||
|
|
|||
|
def clean_json_string(json_str):
    """Strip ASCII control characters (0x00-0x1F and 0x7F) from a string.

    Note that this also removes newlines and tabs.
    """
    return re.sub(r'[\x00-\x1F\x7F]', '', json_str)
|
|||
|
|
|||
|
|
|||
|
def is_in_idx_ranges(idx, idx_ranges):
    """Tell whether ``idx`` belongs to any of the index groups.

    Args:
        idx: Index to look up; anything accepted by ``int()`` (for example
            ``'18'`` or ``18``).
        idx_ranges: Iterable of containers of integer indices.

    Returns:
        ``True`` if ``int(idx)`` is contained in at least one group. An
        ``idx`` that cannot be converted to an integer is in no group, so
        ``False`` is returned instead of raising.
    """
    try:
        # Convert once up front instead of once per group.
        number = int(idx)
    except (TypeError, ValueError):
        return False
    return any(number in group for group in idx_ranges)
|
|||
|
|
|||
|
|
|||
|
def extract_json(text):
    """Decode the JSON object embedded in ``text``.

    The span from the first ``{`` to the last ``}`` is taken, stripped of
    control characters and parsed.

    Args:
        text: Raw text that may contain a JSON object.

    Returns:
        The decoded object, or the string ``'NULL'`` when no braces are
        found or decoding fails (the decode error is printed).
    """
    found = re.findall(r'{.*}', text, re.DOTALL)
    if not found:
        return 'NULL'
    candidate = clean_json_string(found[-1])
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as e:
        print(f'Error decoding JSON: {e}')
    return 'NULL'
|
|||
|
|
|||
|
|
|||
|
def extract_all_responses_from_json(response_json):
    """Return every value of ``response_json`` as a string, in dict order."""
    return [str(value) for value in response_json.values()]
|
|||
|
|
|||
|
|
|||
|
def clean_latex(latex_expr):
    """Reduce a LaTeX answer to the bare expression to be parsed.

    Keeps only what follows the last ``=``, then drops ``\\(``/``\\)``/
    ``\\[``/``\\]`` delimiters, ``\\text{...}`` groups and the ``\\left``,
    ``\\right`` and ``\\displaystyle`` commands, and finally collapses
    doubled backslashes into single ones.
    """
    # rpartition yields the whole string as the tail when '=' is absent.
    cleaned = latex_expr.rpartition('=')[2]
    removal_patterns = (
        r'\\[()\[\]]',
        r'\\text\{.*?\}',
        r'\\(left|right|displaystyle)',
    )
    for pattern in removal_patterns:
        cleaned = re.sub(pattern, '', cleaned)
    return cleaned.replace('\\\\', '\\')
|
|||
|
|
|||
|
|
|||
|
def extract_text_from_brackets(text, clean_level='basic'):
    """Pull the first bracketed answer out of ``text`` and normalise it.

    Answers are searched for as ``[[...]]`` first, then ``$\\boxed{...}$``,
    then ``[...]``; the first pattern that matches anything wins and its
    first match is used.

    Args:
        text: Text to search.
        clean_level: ``'clean'`` removes quotes, newlines, spaces and square
            brackets; ``'logic'`` removes quotes, newlines, spaces and
            periods; ``'math'`` removes quotes, newlines, square brackets
            and dollar signs and then applies ``clean_latex``. Any other
            value only strips surrounding whitespace.

    Returns:
        ``'NULL'`` when nothing matches; the cleaned LaTeX string for
        ``'math'``; otherwise the cleaned text wrapped as ``[[...]]``.
    """
    patterns = (
        r'\[\[\s*(.*?)\s*\]\]',
        r'\$\\boxed\{(.*?)\}\$',
        r'\[\s*(.*?)\s*\]',
    )
    found = []
    for pattern in patterns:
        found = re.findall(pattern, text, re.DOTALL)
        if found:
            break
    if not found:
        return 'NULL'

    inner = found[0].strip()
    if clean_level == 'math':
        for junk in ('"', '\n', '[', ']', '$'):
            inner = inner.replace(junk, '')
        return f'{clean_latex(inner)}'

    if clean_level == 'clean':
        to_drop = ('"', '\n', ' ', '[', ']')
    elif clean_level == 'logic':
        to_drop = ('"', '\n', ' ', '.')
    else:
        to_drop = ()
    for junk in to_drop:
        inner = inner.replace(junk, '')
    return f'[[{inner}]]'
|
|||
|
|
|||
|
|
|||
|
def extract_inner_text_from_brackets(text):
    """Return the raw text inside the first ``[[...]]`` of ``text``.

    Args:
        text: Text to search. Non-string input is reported on stdout.

    Returns:
        The unstripped inner text of the first ``[[...]]`` pair, or
        ``'NULL'`` when ``text`` is not a string or has no such pair.
    """
    if isinstance(text, str):
        hit = re.search(r'\[\[(.*?)\]\]', text, re.DOTALL)
        if hit:
            return hit.group(1)
        return 'NULL'
    print(f'text type: {type(text)}, text value: {text}')
    return 'NULL'
|
|||
|
|
|||
|
|
|||
|
def extract_numbers(str):
    """Return every run of digits in ``str`` as an int, in order.

    Signs and decimal points are not part of a run, so ``'-3.5'`` yields
    ``[3, 5]``.
    """
    return [int(digits) for digits in re.findall(r'\d+', str)]
|
|||
|
|
|||
|
|
|||
|
def extract_and_sort_inequalities(latex_expr):
    """Collect ``≥``/``≤`` bounds from ``latex_expr`` as sorted strings.

    Each bound is the comparison sign immediately followed by its (possibly
    negative, possibly decimal) number, with whitespace between them
    dropped, e.g. ``'≥3'``.

    Returns:
        The bounds sorted as plain strings.
    """
    found = re.findall(r'(≥|≤)\s*([-]?\d+\.?\d*)', latex_expr)
    bounds = [sign + number for sign, number in found]
    bounds.sort()
    return bounds
|
|||
|
|
|||
|
|
|||
|
def rule5_normalize_content(content):
    """Split ``content`` on ``;`` and return the pieces in sorted order."""
    return sorted(content.split(';'))
|
|||
|
|
|||
|
|
|||
|
def normalize_string(s):
    """Return the comma-separated digit groups of ``s`` in sorted order.

    Every character other than a digit or a comma is discarded, the rest is
    split on commas and the pieces are sorted as strings.

    Args:
        s: Input string, e.g. ``'(3, 1), (2)'``.

    Returns:
        A sorted list of digit strings, e.g. ``['1', '2', '3']``.
    """
    # Commas must survive the clean-up: they are the separators split on
    # below. Stripping them too left the split with nothing to do.
    s = re.sub(r'[^0-9,]', '', s)
    pairs = s.split(',')
    pairs.sort()
    return pairs
|
|||
|
|
|||
|
|
|||
|
def remove_commas_and_spaces(s):
    """Delete all commas, whitespace and square brackets from ``s``."""
    stripped = re.sub(r'[,\s\[\]]+', '', s)
    return stripped
|
|||
|
|
|||
|
|
|||
|
def remove_non_alphanumeric(s):
    """Delete every non-word character from ``s``.

    "Word" follows the regex ``\\w`` class, so underscores are kept along
    with letters and digits.
    """
    word_chars_only = re.sub(r'\W+', '', s)
    return word_chars_only
|
|||
|
|
|||
|
|
|||
|
def contains_or(answer):
    """Tell whether ``'or'`` occurs in ``answer``.

    This is a plain containment test, so for strings any occurrence counts,
    including inside longer words.
    """
    has_or = 'or' in answer
    return has_or
|
|||
|
|
|||
|
|
|||
|
def compare_multi_results(response, answer):
    """Compare answers that list several alternatives joined by ``or``.

    Both texts are extracted with the ``'clean'`` level, split on ``or``
    and compared as sorted lists, so the order of alternatives does not
    matter. ``\\text{or}`` in the response is treated as a plain ``or``.

    Args:
        response: Model output containing a bracketed answer.
        answer: Reference answer containing a bracketed answer.

    Returns:
        ``True`` when both sides list the same alternatives; ``False`` when
        they differ, when the response has no bracketed answer, or when any
        error occurs (the error is printed).
    """
    try:
        extracted = extract_text_from_brackets(response, 'clean')
        extracted = re.sub(r'\\text\{or\}', 'or', extracted)
        if extracted == 'NULL':
            return False
        expected = extract_text_from_brackets(answer, 'clean')

        got_parts = sorted(
            piece.strip() for piece in extracted.strip('[[]]').split('or'))
        want_parts = sorted(
            piece.strip() for piece in expected.strip('[[]]').split('or'))
        return got_parts == want_parts
    except Exception as e:
        print(f'Error during comparison: {e}')
        return False
|
|||
|
|
|||
|
|
|||
|
def split_or_expression(expression):
    """Split ``expression`` on ``'or'`` and strip each resulting piece."""
    pieces = expression.split('or')
    return list(map(str.strip, pieces))
|
|||
|
|
|||
|
|
|||
|
def compare_math_expressions(response, answer):
    """Compare two LaTeX answers symbolically.

    Both texts are extracted with the ``'math'`` level. When the reference
    contains ``or``, each side is split into alternatives and the sets of
    simplified expressions are compared; otherwise the two simplified
    expressions are compared directly.

    Args:
        response: Model output containing a bracketed answer.
        answer: Reference answer containing a bracketed answer.

    Returns:
        ``False`` when the response has no bracketed answer; otherwise the
        result of the symbolic comparison, or of a plain string comparison
        of the extracted texts if parsing/simplification raises (the error
        is printed).
    """
    response_text = extract_text_from_brackets(response, 'math')
    answer_text = extract_text_from_brackets(answer, 'math')
    if response_text == 'NULL':
        return False

    def canonical(latex):
        # Parse the LaTeX source and simplify it for comparison.
        return sp.simplify(parse_latex(latex))

    has_alternatives = contains_or(answer_text)
    if has_alternatives:
        response_parts = split_or_expression(response_text)
        answer_parts = split_or_expression(answer_text)

    try:
        if has_alternatives:
            response_exprs = {canonical(part) for part in response_parts}
            answer_exprs = {canonical(part) for part in answer_parts}
            return response_exprs == answer_exprs
        response_expr = canonical(response_text)
        answer_expr = canonical(answer_text)
        return response_expr == answer_expr
    except Exception as e:
        print(f'Error during simplification or parsing: {e}')
        return response_text == answer_text
|
|||
|
|
|||
|
|
|||
|
def method_equal(response_text, answer):
    """Return whether ``response_text`` equals ``answer`` exactly."""
    is_match = response_text == answer
    return is_match
|
|||
|
|
|||
|
|
|||
|
def method_1(response_text, answer):
    """Compare two texts by their ASCII letters only, ignoring case."""
    normalized = [
        re.sub(r'[^A-Za-z]', '', raw).lower()
        for raw in (response_text, answer)
    ]
    return normalized[0] == normalized[1]
|
|||
|
|
|||
|
|
|||
|
def method_2(response_text, answer):
    """Check the response against a comma-separated list of answers.

    The response is reduced to its lower-cased ASCII letters and must equal
    one of the comma-separated entries of ``answer`` exactly (entries are
    neither stripped nor lower-cased).
    """
    letters = re.sub(r'[^A-Za-z]', '', response_text).lower()
    accepted = answer.split(',')
    return letters in accepted
|
|||
|
|
|||
|
|
|||
|
def method_3(response_text, answer):
    """Compare the words of the response and the answer, ignoring order.

    The response is lower-cased and split on runs of non-word characters
    (empty pieces dropped); the answer is split on single spaces and used
    as-is.
    """
    response_words = sorted(
        word for word in re.split(r'\W+', response_text.lower()) if word)
    answer_words = sorted(answer.split(' '))
    return response_words == answer_words
|
|||
|
|
|||
|
|
|||
|
def method_4(response_text, answer):
    """Check whether the response's letters occur inside ``answer``.

    The response is reduced to its lower-cased ASCII letters and tested as
    a substring of ``answer``.

    Returns:
        ``True`` when the non-empty cleaned response is contained in
        ``answer``; ``False`` otherwise.
    """
    cleaned_string = re.sub(r'[^A-Za-z]', '', response_text).lower()
    if not cleaned_string:
        # The empty string is a substring of everything, so a response
        # without any letters would otherwise always count as correct.
        return False
    return cleaned_string in answer
|
|||
|
|
|||
|
|
|||
|
def method_5(response_text, answer):
    """Compare comma-separated items of response and answer, any order.

    All whitespace is removed from the response before splitting; the
    answer is split as-is.
    """
    response_items = sorted(re.sub(r'\s+', '', response_text).split(','))
    answer_items = sorted(answer.split(','))
    return response_items == answer_items
|
|||
|
|
|||
|
|
|||
|
def method_9(response_text, answer):
    """Check an arithmetic equation against the reference equation.

    ``×`` and ``−`` are normalised to ``*`` and ``-`` on both sides. The
    response passes when the operators left of its ``=`` match, in order,
    those left of the answer's ``=``, and its left-hand side evaluates to
    the integer found after ``=`` in the answer.

    Args:
        response_text: Candidate equation, e.g. ``'2+3=5'``.
        answer: Reference equation, e.g. ``'2+3=5'``.

    Returns:
        ``True`` when operators and evaluated result match; ``False``
        otherwise, including when the answer has no ``= <integer>`` part or
        the response cannot be evaluated.
    """
    response_text = response_text.replace('×', '*').replace('−', '-')
    answer = answer.replace('×', '*').replace('−', '-')

    def extract_operators(s):
        return re.findall(r'[+\-*/]', s)

    left_side = response_text.split('=')[0]
    if extract_operators(left_side) != extract_operators(
            answer.split('=')[0]):
        return False

    match = re.search(r'=\s*(-?\d+)', answer)
    if match is None:
        # No "= <integer>" in the reference: nothing to compare against.
        return False
    expected_result = int(match.group(1))

    try:
        # SECURITY NOTE(review): eval() runs arbitrary Python taken from
        # the model response. Only evaluate trusted output, or replace this
        # with a restricted arithmetic evaluator.
        result = eval(left_side)
    except Exception as e:
        print(f'Error during evaluation: {e}')
        return False
    return result == expected_result
|
|||
|
|
|||
|
|
|||
|
def method_10(response_text, answer):
    """Check a response expression that must evaluate to 24.

    The part of the response left of ``=`` (with ``×``/``−`` normalised) is
    accepted when it uses the same multiset of word characters as the first
    line of ``answer`` left of ``=``, and evaluates to 24.

    Returns:
        ``True`` when both conditions hold; ``False`` otherwise, including
        when evaluation raises (the error is printed).
    """
    expression = response_text.replace('×', '*').replace('−', '-')
    expression = expression.split('=')[0]
    reference = answer.split('\n')[0].split('=')[0]

    used_chars = sorted(remove_non_alphanumeric(expression))
    allowed_chars = sorted(remove_non_alphanumeric(reference))
    if used_chars != allowed_chars:
        return False

    try:
        # NOTE(review): eval() executes text taken from the model response.
        value = eval(expression)
    except Exception as e:
        print(f'Error during evaluation: {e}')
        return False
    return value == 24
|
|||
|
|
|||
|
|
|||
|
def method_18(response_text, answer):
    """Compare both texts after ``remove_commas_and_spaces`` is applied."""
    return (remove_commas_and_spaces(response_text) ==
            remove_commas_and_spaces(answer))
|
|||
|
|
|||
|
|
|||
|
def method_general(response_text, answer):
    """Compare both texts after ``remove_non_alphanumeric`` is applied."""
    return (remove_non_alphanumeric(response_text) ==
            remove_non_alphanumeric(answer))
|
|||
|
|
|||
|
|
|||
|
# Maps a puzzle rule_id (as a string) to its comparison function.
# evaluate_response_vs_answer looks rule ids up here for 'puzzle' questions
# and falls back to method_general when a rule id has no entry.
question_methods = {
    '1': method_1,
    '2': method_2,
    '3': method_3,
    '4': method_4,
    '5': method_5,
    '9': method_9,
    '10': method_10,
    '18': method_18,
}
|
|||
|
|
|||
|
|
|||
|
def evaluate_response_vs_answer(response, answer, question_type, rule_id, idx):
    """Decide whether ``response`` matches ``answer`` for one question.

    The comparison strategy depends on the question type, the rule id and,
    for some 'operation' questions, the sample idx. Branches are tried in
    the order written below.

    Args:
        response: Raw model output.
        answer: Reference answer.
        question_type: Category such as ``'logic'``, ``'operation'`` or
            ``'puzzle'``; anything else uses a plain cleaned comparison.
        rule_id: Rule identifier as a string.
        idx: Sample identifier used for the idx-specific branches.

    Returns:
        ``True`` when the response is judged correct, else ``False``.
    """
    if question_type == 'logic' and rule_id == '5':
        response_text = extract_text_from_brackets(response, 'logic')
        # The extractor signals "no bracketed answer" with the string
        # 'NULL' (never None), so that is what must be tested here.
        if response_text == 'NULL':
            return False
        # NOTE(review): the response is cleaned at 'logic' level but the
        # reference is compared raw; kept as-is to preserve scoring, please
        # confirm this is intended.
        normalized_response = rule5_normalize_content(response_text)
        normalized_answer = rule5_normalize_content(answer)
        return normalized_response == normalized_answer
    elif question_type == 'logic':
        response_text = extract_text_from_brackets(response, 'logic')
        answer_text = extract_text_from_brackets(answer, 'logic')
        return response_text == answer_text
    elif question_type == 'operation' and (idx == '178' or idx == '179'):
        response_text = extract_text_from_brackets(response, 'clean')
        response_text = extract_and_sort_inequalities(response_text)
        answer_text = extract_and_sort_inequalities(answer)
        return response_text == answer_text
    elif question_type == 'operation' and rule_id == '18':
        response_text = extract_text_from_brackets(response, 'clean')
        answer = extract_inner_text_from_brackets(answer)
        # Compare as sorted character sequences so ordering is ignored.
        response_text = ''.join(sorted(re.sub(r'\W+', '', response_text)))
        answer = ''.join(sorted(re.sub(r'\W+', '', answer)))
        return response_text == answer
    elif question_type == 'operation' and rule_id in {'23', '24', '25'}:
        response_text = extract_text_from_brackets(response, 'clean')
        # Without this guard a response with no bracketed answer yields an
        # empty number list, which would equal a digit-free reference.
        if response_text == 'NULL':
            return False
        response_text = extract_numbers(response_text)
        answer_text = extract_numbers(answer)
        return response_text == answer_text
    elif question_type == 'operation' and is_in_idx_ranges(idx, idx_ranges):
        return compare_math_expressions(response, answer)
    elif question_type == 'operation' and contains_or(answer):
        return compare_multi_results(response, answer)
    elif question_type == 'puzzle':
        response_text = extract_inner_text_from_brackets(response)
        answer = extract_inner_text_from_brackets(answer)
        method = question_methods.get(rule_id)
        if method:
            return method(response_text, answer)
        return method_general(response_text, answer)
    else:
        response_text = extract_text_from_brackets(response, 'clean')
        return response_text == answer
|
|||
|
|
|||
|
|
|||
|
def compute_one_mixed_question_pass_rate(idx,
                                         question_list,
                                         response_json,
                                         base_path=None):
    """Score one 'mixed' item that bundles several questions.

    Each entry of ``question_list`` has the form ``<category>_<idx>``. Its
    reference answer and rule id are read from the ``sample`` file of that
    category under ``base_path``, and the i-th value of ``response_json``
    is evaluated against it.

    Args:
        idx: Identifier of the mixed item, copied into the result.
        question_list: List of ``<category>_<idx>`` question names.
        response_json: Decoded JSON object with one value per question, or
            the string ``'NULL'`` when no JSON could be extracted.
        base_path: Root directory of the data set.

    Returns:
        A dict with keys ``idx``, ``response``, ``details`` (per-question
        results, ``None`` for a ``'NULL'`` response), ``pass_rate``
        (correct answers divided by the number of questions) and
        ``is_correct`` (``True`` only when every question is correct).
    """
    if response_json == 'NULL':
        return {
            'idx': idx,
            'response': response_json,
            'details': None,
            'pass_rate': 0,
            'is_correct': False
        }
    if not question_list:
        # Nothing to score; also avoids dividing by zero below.
        return {
            'idx': idx,
            'response': response_json,
            'details': [],
            'pass_rate': 0,
            'is_correct': False
        }

    response_list = extract_all_responses_from_json(response_json)
    correct_num = 0
    results = []
    for q_idx, question in enumerate(question_list):
        # Stop before touching the data files once responses run out.
        if q_idx >= len(response_list):
            break
        category, question_idx = question.rsplit('_', 1)
        question_content = load_json_or_jsonl_with_idx(
            base_path, os.path.join(category, 'sample'), idx=question_idx)
        answer = question_content['answer']
        rule_id = question_content['rule_id']
        response = response_list[q_idx]
        response_text = extract_text_from_brackets(response)
        # Pass the sample's own idx, not its position in this list: the
        # idx-specific rules of the evaluator refer to sample identifiers.
        is_correct = evaluate_response_vs_answer(response, answer, category,
                                                 rule_id, question_idx)
        if is_correct:
            correct_num += 1
        results.append({
            'question': question,
            'response_text': response_text,
            'answer': answer,
            'is_correct': is_correct
        })

    pass_rate = correct_num / len(question_list)
    return {
        'idx': idx,
        'response': response_json,
        'details': results,
        'pass_rate': pass_rate,
        'is_correct': pass_rate == 1.0
    }
|
|||
|
|
|||
|
|
|||
|
def evaluate_responses(data, mode, base_path=None):
    """Evaluate every prediction record in ``data``.

    Args:
        data: Mapping from record key to a record dict. The key is used as
            the record's ``idx``; the record supplies ``prediction``,
            ``category`` and, depending on the mode, ``question_list`` or
            ``gold``/``rule_id``.
        mode: Evaluation mode. ``'mixed'`` scores bundled questions via
            ``compute_one_mixed_question_pass_rate``; any other mode scores
            a single answer per record.
        base_path: Root directory of the data set (used in mixed mode).

    Returns:
        A list with one result dict per record, in iteration order.
    """
    evaluated = []

    # The dictionary key doubles as the record's "idx".
    for idx, record in data.items():
        response = record.get('prediction', '')
        question_type = record.get('category', '')

        if mode == 'mixed':
            evaluated.append(
                compute_one_mixed_question_pass_rate(
                    idx, record.get('question_list'),
                    extract_json(response), base_path))
            continue

        response_text = extract_text_from_brackets(response)
        answer = record.get('gold', '')
        rule_id = record.get('rule_id', '')
        outcome = {
            'idx': idx,
            'response': response,
            'response_text': response_text,
            'answer': answer,
            'is_correct': evaluate_response_vs_answer(
                response, answer, question_type, rule_id, idx)
        }
        if question_type == 'counterfactual':
            # Also record whether the response matches the real-life answer.
            real_life_answer = record.get('real_life_answer', '')
            matches_real_life = evaluate_response_vs_answer(
                response, real_life_answer, question_type, rule_id, idx)
            outcome['real_life_answer'] = real_life_answer
            outcome['is_real_life'] = matches_real_life
        if question_type == 'cipher' and mode == 'subquestions':
            outcome['type'] = record.get('type', '')
        evaluated.append(outcome)
    return evaluated
|