OpenCompass/opencompass/datasets/mathbench.py
Xingjun.Wang edab1c07ba
[Feature] Support ModelScope datasets (#1289)
* add ceval, gsm8k modelscope surpport

* update race, mmlu, arc, cmmlu, commonsenseqa, humaneval and unittest

* update bbh, flores, obqa, siqa, storycloze, summedits, winogrande, xsum datasets

* format file

* format file

* update dataset format

* support ms_dataset

* udpate dataset for modelscope support

* merge myl_dev and update test_ms_dataset

* udpate dataset for modelscope support

* update readme

* update eval_api_zhipu_v2

* remove unused code

* add get_data_path function

* update readme

* remove tydiqa japanese subset

* add ceval, gsm8k modelscope surpport

* update race, mmlu, arc, cmmlu, commonsenseqa, humaneval and unittest

* update bbh, flores, obqa, siqa, storycloze, summedits, winogrande, xsum datasets

* format file

* format file

* update dataset format

* support ms_dataset

* udpate dataset for modelscope support

* merge myl_dev and update test_ms_dataset

* update readme

* udpate dataset for modelscope support

* update eval_api_zhipu_v2

* remove unused code

* add get_data_path function

* remove tydiqa japanese subset

* update util

* remove .DS_Store

* fix md format

* move util into package

* update docs/get_started.md

* restore eval_api_zhipu_v2.py, add environment setting

* Update dataset

* Update

* Update

* Update

* Update

---------

Co-authored-by: Yun lin <yunlin@U-Q9X2K4QV-1904.local>
Co-authored-by: Yunnglin <mao.looper@qq.com>
Co-authored-by: Yun lin <yunlin@laptop.local>
Co-authored-by: Yunnglin <maoyl@smail.nju.edu.cn>
Co-authored-by: zhangsongyang <zhangsongyang@pjlab.org.cn>
2024-07-29 13:48:32 +08:00

384 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import json
import os.path as osp
import re
from datasets import Dataset
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.utils import get_data_path
from .base import BaseDataset
def get_number(options):
result_string = ''
for i, option in enumerate(options, start=ord('A')):
result_string += f'{chr(i)}. {option}\n'
return result_string
def get_circular_example(entry, id):
"""For given example, generate four circular examples."""
# Only 4 options is supported for current circular eval.
circular_patterns = ['ABCD', 'BCDA', 'CDAB', 'DABC']
data = []
for c in circular_patterns:
line = copy.deepcopy(entry)
options = []
for i in range(4):
options.append(line['options'][ord(c[i]) - ord('A')])
line['options'] = options
line['answer'] = {
c[0]: 'A',
c[1]: 'B',
c[2]: 'C',
c[3]: 'D'
}[line['answer']]
line['answer'] = str(id) + '--' + line['answer'] + '--' + c
line['question'] = line['question'].strip() + '\n' + get_number(
line['options'])
data.append(line)
return data
@LOAD_DATASET.register_module()
class MathBenchDataset(BaseDataset):
@staticmethod
def load(path: str, name: str, with_circular: bool = True):
"""MathBenth Dataset.
Args:
path (str): Path of the mathbench dataset.
name (str): Name of the target subset.
with_circular (bool): Whether to create circular dataset for
single choice question. Defaults to True.
"""
path = get_data_path(path, local_mode=True)
data = []
filename = osp.join(path, f'{name}.jsonl')
with open(filename, 'r', encoding='utf-8') as infile:
for id, line in enumerate(infile):
entry = json.loads(line)
if 'cloze' in name:
data.append({
'question': entry['question'].strip(),
'answer': entry['answer'].strip()
})
else:
if with_circular:
data.extend(get_circular_example(entry, id))
else:
question = entry['question'].strip(
) + '\n' + get_number(entry['options'])
info = {
'question': question,
'answer': entry['answer'].strip()
}
# # For PPL evaluation
# for i in range(4):
# info[chr(ord('A') +
# i)] = entry['options'][i].strip()
data.append(info)
dataset = Dataset.from_list(data)
return dataset
@TEXT_POSTPROCESSORS.register_module()
def mathbench_postprocess(text: str, name: str) -> str:
split = False
ans = text
if '_cn' in name:
ans_line = ans.split('答案是')
else:
ans_line = ans.split('The answer is')
if len(ans_line) != 1:
ans = ans_line[1].strip()
split = True
output = re.sub(r'(\d),(\d)', r'\1\2', ans)
numbers = re.findall(r'-?\d*\.?/?\d+|\d+', output)
if numbers:
return numbers[0] if split else numbers[-1]
return ans
@LOAD_DATASET.register_module()
class MathBenchBuggyDataset(BaseDataset):
@staticmethod
def load(path: str, name: str, with_circular: bool = True):
data = []
filename = osp.join(path, f'{name}.jsonl')
with open(filename, 'r', encoding='utf-8') as infile:
for id, line in enumerate(infile):
entry = json.loads(line)
if 'cloze' in name:
data.append({
'question': entry['question'].strip(),
'answer': entry['answer'].strip()
})
else:
if with_circular:
data.extend(get_circular_example(entry, id))
else:
question = entry['question'].strip(
) + '\n' + get_number(entry['options'])
info = {
'question': question,
'answer': entry['answer'].strip()
}
# For PPL evaluation
for i in range(4):
info[chr(ord('A') +
i)] = entry['options'][i].strip()
data.append(info)
if 'cloze' not in name:
data = data[:(len(data) // 4 + 7) // 8 * 8]
dataset = Dataset.from_list(data)
return dataset
import collections
from ..openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
from ..registry import ICL_EVALUATORS
def first_option_postprocess(text: str, options: str, cushion=True) -> str:
"""Find first valid option for text."""
# yapf: disable
# flake8: noqa: W605
patterns = [
f'答案是?\s*([{options}])',
f'答案是?\s*\s*([{options}])',
f'答案是?\s*:\s*([{options}])',
f'答案应该?是\s*([{options}])',
f'答案应该?选\s*([{options}])',
f'答案为\s*([{options}])',
f'答案选\s*([{options}])',
f'选择?\s*([{options}])',
f'故选?\s*([{options}])'
f'只有选?项?\s?([{options}])\s?是?对',
f'只有选?项?\s?([{options}])\s?是?错',
f'只有选?项?\s?([{options}])\s?不?正确',
f'只有选?项?\s?([{options}])\s?错误',
f'说法不?对选?项?的?是\s?([{options}])',
f'说法不?正确选?项?的?是\s?([{options}])',
f'说法错误选?项?的?是\s?([{options}])',
f'([{options}])\s?是正确的',
f'([{options}])\s?是正确答案',
f'选项\s?([{options}])\s?正确',
f'所以答\s?([{options}])',
f'所以\s?([{options}][.。$]?$)',
f'所有\s?([{options}][.。$]?$)',
f'[\s:,]([{options}])[。,,\.]?$',
f'[\s,:][故即]([{options}])[。\.]?$',
f'[\s,:]因此([{options}])[。\.]?$',
f'[是为。]\s?([{options}])[。\.]?$',
f'因此\s?([{options}])[。\.]?$',
f'显然\s?([{options}])[。\.]?$',
f'回答[\s:]\s?([{options}])',
f'Answer[\s:]\s?([{options}])',
f'答案是\s?(\S+)(?:。|$)',
f'答案应该是\s?(\S+)(?:。|$)',
f'答案为\s?(\S+)(?:。|$)',
f'[Tt]he answer is:?\s+\(?([{options}])\)?',
f'[Tt]he answer is option:?\s+\(?([{options}])\)?',
f'[Tt]he correct answer is:?\s+\(?([{options}])\)?',
f'[Tt]he correct answer is option:?\s+\(?([{options}])\)?',
f'[Tt]he answer to the question is:?\s+\(?([{options}])\)?',
]
cushion_patterns = [
f'^选项\s?([{options}])',
f'^([{options}])\s?选?项',
# f'[\s|^]([{options}])[\s。,:\.$]',
f'[\s|^]([{options}])[。,,:\.$]',
f'1.\s?([{options}])[.。$]?$',
f'([{options}]):',
f'([{options}])',
]
# flake8: noqa
# yapf: enable
for pattern in patterns:
match = re.search(pattern, text, re.DOTALL)
if match:
outputs = match.group(0)
for i in options:
if i in outputs:
return i, pattern
if cushion:
for pattern in cushion_patterns:
outputs = []
current_text = text
while True:
match = re.search(pattern, current_text, re.DOTALL)
if match:
outputs.append(match.group(0))
current_text = current_text[match.end():]
else:
break
# if len(outputs) >= 2:
# from IPython import embed; embed(); exit()
if outputs:
outputs = outputs[-1]
for i in options:
if i in outputs:
return i, pattern
return '', None
def remove_invisible_chars(text: str) -> str:
"""Remove invisible characters."""
text = re.sub(r'\s+', '', text)
text = re.sub(r'\u200b', '', text)
return text
@ICL_EVALUATORS.register_module()
class MathBenchCircularEvaluator(BaseEvaluator):
"""Robust circular evaluator for multi-choice questions."""
def __init__(self) -> None:
super().__init__()
self.cp4 = ['ABCD', 'BCDA', 'CDAB', 'DABC']
self.cp1 = ['ABCD']
def score(self, predictions, references, test_set):
"""Calculate the accuracy of predictions.
Args:
predictions (list): List of predictions.
references (list): List of references.
Returns:
dict: A dict of evaluation results.
"""
if len(predictions) != len(references):
return {'error': 'preds and refrs have different length'}
extract_details = {}
extracted_predictions = []
for index, p in enumerate(predictions):
extracted_p = None
matched_pattern = None
if '\\boxed' in p:
match = re.findall(r'\\boxed\{(.*)\}', p)
if match:
for m in match:
for j in range(4):
m = remove_invisible_chars(m)
o = remove_invisible_chars(
test_set[index]['options'][j])
if m == o:
extracted_p = chr(ord('A') + j)
matched_pattern = 'boxed_answer'
break
else:
if m in ['A', 'B', 'C', 'D']:
extracted_p = m
matched_pattern = 'boxed_ABCD'
else:
continue
break
if extracted_p is None:
extracted_p, matched_pattern = first_option_postprocess(
p, 'ABCD')
extracted_predictions.append(extracted_p)
extract_details[str(index)] = {
'question': test_set[index]['question'],
'options': test_set[index]['options'],
'origin_pred': p,
'extracted_pred': extracted_p,
'matched_pattern': matched_pattern,
'ref': references[index],
}
predictions = extracted_predictions
results = {}
results.update({'acc_4': 0, 'acc_1': 0})
# Accuracy for patterns with no circular shift / 4 circular shifts
for index, (pred, reference) in enumerate(zip(predictions,
references)):
_, ref, circular_pattern = reference.split('--')
extract_details[str(index)]['is_correct'] = pred == ref
if circular_pattern in self.cp4:
results['acc_4'] += 1 if pred == ref else 0
if circular_pattern in self.cp1:
results['acc_1'] += 1 if pred == ref else 0
for k in ['acc_4', 'acc_1']:
results[k] = results[k] / len(predictions) * 4 / int(
k.split('_')[-1]) * 100
# Accuracy for patterns with no circular shift / 4 circular shifts
details = {4: {}, 1: {}}
for pred, reference in zip(predictions, references):
index, ref, circular_pattern = reference.split('--')
if index not in details[4]:
details[4][index] = []
details[1][index] = []
if circular_pattern in self.cp4:
details[4][index].append(True if pred == ref else False)
if circular_pattern in self.cp1:
details[1][index].append(True if pred == ref else False)
# Calculate accuracy for having at least j correct out of i total
for i in [1, 4]:
for j in range(0, i + 1):
count, total = 0, 0
for index in details[i]:
if sum(details[i][index]) >= j:
count += 1
total += 1
results[f'more_{i}_{j}'] = count / total * 100
# Consider fully correct as correct
for i in [1, 4]:
results[f'perf_{i}'] = results[f'more_{i}_{i}']
# Calculate voting accuracy
voting = {'vote_4': {}, 'vote_1': {}}
refs = {}
for pred, reference in zip(predictions, references):
index, ref, circular_pattern = reference.split('--')
c = circular_pattern
back_map = {'A': c[0], 'B': c[1], 'C': c[2], 'D': c[3]}
ref = back_map[ref]
if pred not in ['A', 'B', 'C', 'D']:
pred = '-'
else:
pred = back_map[pred]
if index not in voting['vote_4']:
voting['vote_4'][index] = collections.Counter()
voting['vote_1'][index] = collections.Counter()
refs[index] = ref
if c in self.cp4:
voting['vote_4'][index][pred] += 1
if c in self.cp1:
voting['vote_1'][index][pred] += 1
for k in ['vote_4', 'vote_1']:
voting_count = 0
for index in voting[k]:
if refs[index] == voting[k][index].most_common(1)[0][0]:
voting_count += 1
results[k] = voting_count / len(voting[k]) * 100
# Calculate the frequency of ABCD in model predictions
prior_counts = {'A': 0, 'B': 0, 'C': 0, 'D': 0, '-': 0}
for pred, reference in zip(predictions, references):
if pred in ['A', 'B', 'C', 'D']:
prior_counts[pred] += 1
else:
prior_counts['-'] += 1
for k in ['A', 'B', 'C', 'D', '-']:
results[f'prior_{k}'] = prior_counts[k] / len(predictions) * 100
results['details'] = extract_details
return results