# opencompass/datasets/tabmwp.py

import json
import os.path as osp
import random
import re
from typing import List

import numpy as np
from datasets import Dataset, DatasetDict

from opencompass.openicl.icl_evaluator.icl_hf_evaluator import AccEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET

from .base import BaseDataset


def get_table_text(problem):
    table = problem['table']
    title = problem['table_title']
    if title and len(title) > 0:
        table = f'[TITLE]: {title}\n{table}'
    return table


def get_question_text(problem, option_inds='ABCDEFGH'):
    question = problem['question']
    unit = problem['unit']
    if unit and len(unit) > 0:
        question = f'{question} (Unit: {unit})'
    choices = problem['choices']
    if choices and len(choices) > 0:
        choice_list = []
        for i, c in enumerate(choices):
            choice_list.append('({}) {}'.format(option_inds[i], c))
        options = ' '.join(choice_list)
        question = f'{question}\nOptions: {options}'
    return question
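
# Illustrative only: with a hypothetical problem dict such as
# {'question': 'Which is heavier?', 'unit': None,
#  'choices': ['the pen', 'the book']},
# get_question_text returns
# 'Which is heavier?\nOptions: (A) the pen (B) the book'.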


def get_answer(problem):
    return problem['answer']


def get_choices(problem):
    return problem['choices']


def get_unit(problem):
    return problem['unit']


def get_solution_text(problem):
    # \\n: GPT-3 can generate the solution with more tokens
    solution = problem['solution'].replace('\n', '\\n')
    return solution


def normalize_answer(text, unit):
    # ["1,000", "123", "3/4", "56.456", "$56.4", "-3", "-10.02", "-3/2"]
    text = re.sub(r'^[\$]', '', text)
    text = re.sub(r'[\,\.\,\/]$', '', text)
    result = re.match(r'^[-+]?[\d,./]+$', text)
    if result is not None:
        # is number?
        text = text.replace(',', '')
        result = re.match(r'[-+]?\d+$', text)
        if result is not None:
            number = int(text)
        elif '/' in text:
            nums = text.split('/')
            number = round(float(nums[0]) / float(nums[1]), 3)
        else:
            number = round(float(text), 3)
        number = str(number)
        number = re.sub(r'\.[0]+$', '', number)
        return number
    else:
        # is text
        if unit:
            text = text.replace(unit, '').strip()
        return text
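
# A few traced examples (hypothetical inputs, not real dataset records):
#   normalize_answer('$1,000', None)               -> '1000'
#   normalize_answer('3/4', None)                  -> '0.75'
#   normalize_answer('5.0', None)                  -> '5'
#   normalize_answer('20 kilograms', 'kilograms')  -> '20'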


def score_string_similarity(str1, str2):
    if str1 == str2:
        return 2.0
    if ' ' in str1 or ' ' in str2:
        str1_split = str1.split(' ')
        str2_split = str2.split(' ')
        overlap = list(set(str1_split) & set(str2_split))
        return len(overlap) / max(len(str1_split), len(str2_split))
    else:
        # Unequal single tokens share no overlap; the exact-match case
        # already returned above.
        return 0.0
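
# Traced examples: identical strings score 2.0; token overlap is scored as
#   score_string_similarity('red apple', 'green apple') -> 0.5
#   score_string_similarity('red', 'blue')              -> 0.0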


def extract_prediction(output, options=None, option_inds='ABCDEFGH'):
    # $\\frac{16}{95}$ -> 16/95
    output = re.sub(r'\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?', r'\1/\2',
                    output)
    output = re.sub(r'(?<![AP]\.M)\.$', '', output)
    output = re.sub(r'(?<=\d)[\=](?=[\-\$\d])', ' = ', output)
    output = re.sub(r'\u2212', '-', output)

    # Multi-choice questions
    if options:
        patterns = [
            r'^\(([A-Za-z])\)$',  # "(b)", "(B)"
            r'^([A-Za-z])$',  # "b", "B"
            r'^([A-Za-z])\. ',  # "b. ...", "B. ..."
            r'[Tt]he answer is ([A-Z])',  # "The answer is B"
            r'^\(([A-Za-z])\) [\s\S]+$',  # "(A) XXXXX"
            r'[Tt]he answer is \(([A-Za-z])\) [\s\S]+$'
        ]

        # have "X" in the output
        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                pred = res[0].upper()  # e.g., "B"
                if pred in option_inds:
                    ind = option_inds.index(pred)  # 1
                    if ind >= len(options):
                        random.seed(123)
                        ind = random.choice(range(len(options)))
                    prediction = options[ind]
                    return prediction

        # find the most similar options
        scores = [score_string_similarity(x, output) for x in options]
        max_idx = int(
            np.argmax(scores))  # json does not recognize NumPy data types
        prediction = options[max_idx]
        return prediction
    else:
        # free_text QA problems, numeric answer
        patterns = [
            r'[Tt]he answer is ([\s\S]+)$',  # "The answer is XXXXX."
            r'[Tt]he table shows that ([\d\$\.\,\/\:]+) ',
            r' = ([\d\$\.\,\/\:]+)',  # "= $1.40"
            r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "will be $1.40"
            r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "are $1.40"
            r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "were $1.40"
            r' ([\d\$\.\,\/\:]+ [AP]\.M\.)',  # "7:25 P.M."
            r'([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "14.5"
        ]

        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                prediction = res[-1].strip()
                if prediction.endswith('.') and '.M.' not in prediction:
                    prediction = prediction[:-1]
                return prediction

    return output
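
# Traced examples (hypothetical model outputs):
#   extract_prediction('(A)', options=['yes', 'no'])  -> 'yes'
#   extract_prediction('The answer is 1.40.')         -> '1.40'
#   extract_prediction('Steven will spend $9.00.')    -> '$9.00'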


@ICL_EVALUATORS.register_module()
class TabMWPEvaluator(AccEvaluator):
    """Accuracy evaluator for the TabMWP dataset."""

    def _preprocess(self, predictions: List, references: List) -> dict:
        """Preprocess the final predictions and references to needed format.

        Args:
            predictions (List): List of predictions of each sample.
            references (List): List of targets for each sample.

        Returns:
            dict: Preprocessed results.
        """
        preds, golds = [], []
        for idx in range(len(references)):
            pred = predictions[idx]
            unit = references[idx]['unit']
            answer = references[idx]['answer']
            choices = references[idx]['choices']
            preds.append(
                normalize_answer(extract_prediction(pred, choices),
                                 unit).lower())
            golds.append(normalize_answer(answer, unit).lower())
        return super()._preprocess(preds, golds)
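
# End-to-end sketch of the normalization pipeline (hypothetical sample):
# a prediction 'The answer is $2.00.' and a reference
# {'answer': '2', 'unit': None, 'choices': None} both normalize to '2',
# so the inherited AccEvaluator scores them as a match.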


@LOAD_DATASET.register_module()
class TabMWPDataset(BaseDataset):
    # The TabMWP dataset contains 38,431 tabular math word problems.
    # Each question in TabMWP is aligned with a tabular context, which is
    # presented as an image, semi-structured text, and a structured table.
    # There are two types of questions: free-text and multi-choice, and
    # each problem is annotated with gold solutions to reveal the
    # multi-step reasoning process.
    # To learn more about it, please follow:
    # https://github.com/lupantech/PromptPG/tree/main

    @staticmethod
    def load(path: str):
        dataset = DatasetDict()
        for split in ['dev', 'test', 'train']:
            raw_data = []
            filename = osp.join(path, f'problems_{split}.json')
            with open(filename, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
                for idx in json_data:
                    problem = json_data[idx]
                    question = get_question_text(problem)
                    table = get_table_text(problem)
                    unit = get_unit(problem)
                    answer = get_answer(problem)
                    choices = get_choices(problem)
                    solution = get_solution_text(problem)
                    raw_data.append({
                        'question': question,
                        'table': table,
                        'test_elements': {
                            'answer': answer,
                            'unit': unit,
                            'choices': choices
                        },
                        'answer': f'Answer: The answer is {answer}.',
                        'solution': f'Solution: {solution}',
                        'answer_and_solution':
                        f'Answer: The answer is {answer}. BECAUSE: {solution}',
                        'solution_and_answer':
                        f'Answer: {solution} The answer is {answer}.'
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
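

# Minimal smoke-test sketch. Assumes a local TabMWP copy whose directory
# contains problems_dev.json / problems_test.json / problems_train.json;
# 'data/tabmwp' below is a placeholder path, and the HF 'accuracy' metric
# used by AccEvaluator may need to be downloaded on first run.
if __name__ == '__main__':
    tabmwp = TabMWPDataset.load('data/tabmwp')
    sample = tabmwp['dev'][0]
    print(sample['question'])
    print(sample['table'])
    evaluator = TabMWPEvaluator()
    print(
        evaluator.score(predictions=['The answer is 2.'],
                        references=[sample['test_elements']]))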