[Feat] Add an open-source dataset: TabMWP (#505)
* TabMWP
* TabMWP
* fixed
* fixed
* fixed
* done
* done
* done

---------

Co-authored-by: caomaosong <caomaosong@pjlab.org.cn>
This commit is contained in:
parent 987a711232 · commit f25a980043
configs/datasets/TabMWP/TabMWP_gen.py (new file, +4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .TabMWP_gen_2aef96 import TabMWP_datasets  # noqa: F401, F403
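For context, a minimal sketch of how this dataset config would typically be pulled into a top-level evaluation config. The model import is hypothetical; substitute any model config shipped with opencompass:

from mmengine.config import read_base

with read_base():
    # both paths are relative to the top-level configs/ directory (assumed layout)
    from .datasets.TabMWP.TabMWP_gen import TabMWP_datasets
    from .models.hf_llama_7b import models  # hypothetical model config

datasets = [*TabMWP_datasets]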
configs/datasets/TabMWP/TabMWP_gen_2aef96.py (new file, +53 lines)
@@ -0,0 +1,53 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import TabMWPDataset, TabMWPEvaluator

# None of the TabMWP datasets on HuggingFace are parsed correctly, so we use
# our own dataset reader.
# Please download the dataset from https://github.com/lupantech/PromptPG/tree/main

input_format = 'TQ'
output_format = 'A'
elements = {"Q": "Question: {question}",
            "T": "Table: {table}",
            "S": "Solution: {solution}",
            "A": "Answer: The answer is {answer}.",
            "AS": "Answer: The answer is {answer}. BECAUSE: {solution}",
            "SA": "Answer: {solution} The answer is {answer}."}


TabMWP_reader_cfg = dict(
    input_columns=["question", "table"],
    output_column="test_elements",
    train_split='dev',
)

TabMWP_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role="HUMAN",
                    prompt="\n".join(elements[label] for label in input_format)
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

TabMWP_eval_cfg = dict(
    evaluator=dict(type=TabMWPEvaluator)
)

TabMWP_datasets = [
    dict(
        type=TabMWPDataset,
        path="./data/tabmwp/",
        reader_cfg=TabMWP_reader_cfg,
        infer_cfg=TabMWP_infer_cfg,
        eval_cfg=TabMWP_eval_cfg,
    )
]
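To make the element table concrete: iterating over input_format = 'TQ' selects the 'T' and 'Q' templates character by character, so the HUMAN prompt expands as follows (illustrative expansion, not part of the commit):

"\n".join(elements[label] for label in 'TQ')
# -> 'Table: {table}\nQuestion: {question}'

The {table} and {question} placeholders are later filled from the reader's input_columns. Note that output_format = 'A' is assigned but not referenced anywhere in this config; the answer-side templates live in the dataset loader below.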
opencompass/datasets/__init__.py
@@ -66,6 +66,7 @@ from .storycloze import *  # noqa: F401, F403
 from .strategyqa import *  # noqa: F401, F403
 from .summedits import *  # noqa: F401, F403
 from .summscreen import *  # noqa: F401, F403
+from .tabmwp import *  # noqa: F401, F403
 from .TheoremQA import *  # noqa: F401, F403
 from .tnews import *  # noqa: F401, F403
 from .triviaqa import *  # noqa: F401, F403
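The wildcard import is what makes the new dataset visible: importing opencompass.datasets.tabmwp executes the @LOAD_DATASET.register_module() and @ICL_EVALUATORS.register_module() decorators below, after which configs can reference the classes by name. A minimal sketch, assuming the registries behave like standard mmengine registries (which they subclass):

from opencompass.registry import LOAD_DATASET
import opencompass.datasets  # noqa: F401 -- triggers registration as a side effect

# build the dataset from a plain config dict, by registered name (sketch only;
# the reader_cfg mirrors the config file above)
ds = LOAD_DATASET.build(
    dict(type='TabMWPDataset',
         path='./data/tabmwp/',
         reader_cfg=dict(input_columns=['question', 'table'],
                         output_column='test_elements')))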
opencompass/datasets/tabmwp.py (new file, +245 lines)
@@ -0,0 +1,245 @@
import json
import os.path as osp
import random
import re
from typing import List

import numpy as np
from datasets import Dataset, DatasetDict

from opencompass.openicl.icl_evaluator.icl_hf_evaluator import AccEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET

from .base import BaseDataset

def get_table_text(problem):
    table = problem['table']
    title = problem['table_title']
    if title and len(title) > 0:
        table = f'[TITLE]: {title}\n{table}'
    return table


def get_question_text(problem, option_inds='ABCDEFGH'):
    question = problem['question']

    unit = problem['unit']
    if unit and len(unit) > 0:
        question = f'{question} (Unit: {unit})'

    choices = problem['choices']
    if choices and len(choices) > 0:
        choice_list = []
        for i, c in enumerate(choices):
            choice_list.append('({}) {}'.format(option_inds[i], c))
        options = ' '.join(choice_list)
        question = f'{question}\nOptions: {options}'

    return question
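
# Illustrative behaviour of the helper above, worked by hand against a
# hypothetical problem dict (field names follow the PromptPG JSON schema):
# >>> get_question_text({'question': 'How much is a pencil?',
# ...                    'unit': 'dollars', 'choices': ['0.10', '0.25']})
# 'How much is a pencil? (Unit: dollars)\nOptions: (A) 0.10 (B) 0.25'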


def get_answer(problem):
    return problem['answer']


def get_choices(problem):
    return problem['choices']


def get_unit(problem):
    return problem['unit']


def get_solution_text(problem):
    # replace real newlines with a literal '\n' so the solution stays on one
    # line in the prompt, and GPT-3 can generate the solution with more tokens
    solution = problem['solution'].replace('\n', '\\n')
    return solution


def normalize_answer(text, unit):
    # examples: ["1,000", "123", "3/4", "56.456", "$56.4", "-3", "-10.02", "-3/2"]

    text = re.sub(r'^[\$]', '', text)  # strip a leading dollar sign
    text = re.sub(r'[,./]$', '', text)  # strip trailing punctuation

    result = re.match(r'^[-+]?[\d,./]+$', text)

    if result is not None:
        # the text is numeric
        text = text.replace(',', '')
        result = re.match(r'[-+]?\d+$', text)

        if result is not None:
            number = int(text)
        elif '/' in text:
            nums = text.split('/')
            number = round(float(nums[0]) / float(nums[1]), 3)
        else:
            number = round(float(text), 3)
        number = str(number)
        number = re.sub(r'\.[0]+$', '', number)  # drop a trailing '.0'
        return number
    else:
        # the text is a plain string; strip the unit if present
        if unit:
            text = text.replace(unit, '').strip()
        return text
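
# Illustrative behaviour, worked by hand (not part of the commit):
# >>> normalize_answer('$1,000.00', None)
# '1000'   # leading '$' stripped, commas removed, trailing '.0' trimmed
# >>> normalize_answer('3/4', None)
# '0.75'   # fractions are evaluated and rounded to three decimals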


def score_string_similarity(str1, str2):
    # exact match scores highest
    if str1 == str2:
        return 2.0
    if ' ' in str1 or ' ' in str2:
        # multi-token strings: fraction of shared tokens
        str1_split = str1.split(' ')
        str2_split = str2.split(' ')
        overlap = list(set(str1_split) & set(str2_split))
        return len(overlap) / max(len(str1_split), len(str2_split))
    else:
        # differing single-token strings score nothing
        return 0.0
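
# Illustrative: token-overlap scoring, worked by hand (not part of the commit):
# >>> score_string_similarity('red fox', 'red dog')
# 0.5   # one shared token out of max(2, 2) tokens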


def extract_prediction(output, options=None, option_inds='ABCDEFGH'):

    # $\\frac{16}{95}$ -> 16/95
    output = re.sub(r'\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?', r'\1/\2',
                    output)

    # strip a trailing period, but keep "A.M."/"P.M."
    output = re.sub(r'(?<![AP]\.M)\.$', '', output)
    # pad "=" between digits, e.g. "2=$1.40" -> "2 = $1.40"
    output = re.sub(r'(?<=\d)[\=](?=[\-\$\d])', ' = ', output)
    # Unicode minus -> ASCII hyphen
    output = re.sub(r'\u2212', '-', output)

    # Multi-choice questions
    if options:
        patterns = [
            r'^\(([A-Za-z])\)$',  # "(b)", "(B)"
            r'^([A-Za-z])$',  # "b", "B"
            r'^([A-Za-z]). ',  # "b. ", "B. "
            r'[Tt]he answer is ([A-Z])',  # "The answer is B"
            r'^\(([A-Za-z])\) [\s\S]+$',  # "(A) XXXXX"
            r'[Tt]he answer is \(([A-Za-z])\) [\s\S]+$'
        ]

        # look for an option letter "X" in the output
        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                pred = res[0].upper()  # e.g., "B"
                if pred in option_inds:
                    ind = option_inds.index(pred)  # e.g., "B" -> 1
                    if ind >= len(options):
                        random.seed(123)
                        ind = random.choice(range(len(options)))
                    prediction = options[ind]
                    return prediction

        # otherwise, fall back to the most similar option
        scores = [score_string_similarity(x, output) for x in options]
        max_idx = int(
            np.argmax(scores))  # json does not recognize NumPy data types
        prediction = options[max_idx]
        return prediction

    else:
        # free-text QA problems, numeric answer
        patterns = [
            r'[Tt]he answer is ([\s\S]+)$',  # "The answer is XXXXX."
            r'[Tt]he table shows that ([\d\$\.\,\/\:]+) ',
            r' = ([\d\$\.\,\/\:]+)',  # "= $1.40"
            r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "will be $1.40"
            r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "are $1.40"
            r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "were $1.40"
            r' ([\d\$\.\,\/\:]+ [AP]\.M\.)',  # "7:25 P.M."
            r'([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "14.5"
        ]

        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                prediction = res[-1].strip()
                if prediction.endswith('.') and '.M.' not in prediction:
                    prediction = prediction[:-1]
                return prediction

    return output
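
# Illustrative multi-choice extraction, worked by hand (not part of the commit):
# >>> extract_prediction('The answer is (B) 0.25.', options=['0.10', '0.25'])
# '0.25'   # trailing period stripped, letter B mapped back to options[1]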


@ICL_EVALUATORS.register_module()
class TabMWPEvaluator(AccEvaluator):
    """Accuracy evaluator for the TabMWP dataset."""

    def _preprocess(self, predictions: List, references: List) -> dict:
        """Preprocess the final predictions and references into the needed
        format.

        Args:
            predictions (List): List of predictions for each sample.
            references (List): List of targets for each sample.

        Returns:
            dict: Preprocessed results.
        """
        preds, golds = [], []
        for idx in range(len(references)):
            pred = predictions[idx]
            unit = references[idx]['unit']
            answer = references[idx]['answer']
            choices = references[idx]['choices']
            preds.append(
                normalize_answer(extract_prediction(pred, choices),
                                 unit).lower())
            golds.append(normalize_answer(answer, unit).lower())
        return super()._preprocess(preds, golds)


@LOAD_DATASET.register_module()
class TabMWPDataset(BaseDataset):
    # The TabMWP dataset contains 38,431 tabular math word problems.
    # Each question in TabMWP is aligned with a tabular context,
    # which is presented as an image, semi-structured text, and a
    # structured table. There are two types of questions: free-text
    # and multi-choice, and each problem is annotated with gold
    # solutions to reveal the multi-step reasoning process.
    # To learn more about it, please follow:
    # https://github.com/lupantech/PromptPG/tree/main
    @staticmethod
    def load(path: str):
        dataset = DatasetDict()
        for split in ['dev', 'test', 'train']:
            raw_data = []
            filename = osp.join(path, f'problems_{split}.json')
            with open(filename, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
                for idx in json_data:
                    problem = json_data[idx]
                    question = get_question_text(problem)
                    table = get_table_text(problem)
                    unit = get_unit(problem)
                    answer = get_answer(problem)
                    choices = get_choices(problem)
                    solution = get_solution_text(problem)
                    raw_data.append({
                        'question': question,
                        'table': table,
                        'test_elements': {
                            'answer': answer,
                            'unit': unit,
                            'choices': choices
                        },
                        'answer': f'Answer: The answer is {answer}.',
                        'solution': f'Solution: {solution}',
                        'answer_and_solution':
                        f'Answer: The answer is {answer}. BECAUSE: {solution}',
                        'solution_and_answer':
                        f'Answer: {solution} The answer is {answer}.'
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
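For reference, the loader assumes the PromptPG release layout (an inference from the code above, not stated elsewhere in the commit): path points at a directory such as

./data/tabmwp/
├── problems_dev.json
├── problems_test.json
└── problems_train.json

where each file is a JSON object mapping problem ids to dicts with 'question', 'table', 'table_title', 'unit', 'choices', 'answer' and 'solution' fields, exactly the keys the helper functions read.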