[Feature] add llama-oriented dataset configs (#82)

* add llama-oriented dataset configs

* update

* revert cvalues & update llama_example
Leymore 2023-08-11 12:48:05 +08:00 committed by GitHub
parent e464265cf8
commit 14332e08fd
17 changed files with 628 additions and 5 deletions

View File

@@ -0,0 +1,36 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import ARCDataset
ARC_c_reader_cfg = dict(
input_columns=['question', 'textA', 'textB', 'textC', 'textD'],
output_column='answerKey')
ARC_c_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
opt: dict(
round=[
dict(role="HUMAN", prompt=f"{{question}}\nA. {{textA}}\nB. {{textB}}\nC. {{textC}}\nD. {{textD}}"),
dict(role="BOT", prompt=f"Answer: {opt}"),
]
) for opt in ["A", "B", "C", "D"]
},
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
ARC_c_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
ARC_c_datasets = [
dict(
type=ARCDataset,
abbr='ARC-c',
path='./data/ARC/ARC-c/ARC-Challenge-Dev.jsonl',
reader_cfg=ARC_c_reader_cfg,
infer_cfg=ARC_c_infer_cfg,
eval_cfg=ARC_c_eval_cfg)
]
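A note on how a per-option template like this is consumed: PPLInferencer renders one full prompt per key of the template dict and predicts the key whose rendered text receives the lowest perplexity, which AccEvaluator then compares against answerKey. Below is a minimal sketch of that selection step, with pseudo_ppl standing in for the model's per-candidate score (not an OpenCompass API) and a made-up sample item:

# Illustrative only: materialize the four per-option prompts and pick the cheapest one.
item = dict(question="Which gas do plants take in for photosynthesis?",
            textA="Oxygen", textB="Carbon dioxide", textC="Nitrogen", textD="Helium")

candidates = {
    opt: ("{question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\n".format(**item)
          + f"Answer: {opt}")
    for opt in "ABCD"
}

def pseudo_ppl(text: str) -> float:
    return float(len(text))  # placeholder; a real run scores the text with the LM

prediction = min(candidates, key=lambda opt: pseudo_ppl(candidates[opt]))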

View File

@@ -0,0 +1,36 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import ARCDataset
ARC_e_reader_cfg = dict(
input_columns=['question', 'textA', 'textB', 'textC', 'textD'],
output_column='answerKey')
ARC_e_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
opt: dict(
round=[
dict(role="HUMAN", prompt=f"{{question}}\nA. {{textA}}\nB. {{textB}}\nC. {{textC}}\nD. {{textD}}"),
dict(role="BOT", prompt=f"Answer: {opt}"),
]
) for opt in ["A", "B", "C", "D"]
},
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
ARC_e_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
ARC_e_datasets = [
dict(
type=ARCDataset,
abbr='ARC-e',
path='./data/ARC/ARC-e/ARC-Easy-Dev.jsonl',
reader_cfg=ARC_e_reader_cfg,
infer_cfg=ARC_e_infer_cfg,
eval_cfg=ARC_e_eval_cfg)
]

View File

@@ -0,0 +1,43 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import BoolQDataset_V3
BoolQ_reader_cfg = dict(
input_columns=["question", "passage"],
output_column="label",
test_split="train")
BoolQ_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
'false':
dict(round=[
dict(role="HUMAN", prompt="Passage: {passage}\nQuestion: {question}?"),
dict(role="BOT", prompt="Answer: No"),
]),
'true':
dict(round=[
dict(role="HUMAN", prompt="Passage: {passage}\nQuestion: {question}?"),
dict(role="BOT", prompt="Answer: Yes"),
]),
},
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer),
)
BoolQ_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
BoolQ_datasets = [
dict(
abbr="BoolQ",
type=BoolQDataset_V3,
path="./data/SuperGLUE/BoolQ/val.jsonl",
reader_cfg=BoolQ_reader_cfg,
infer_cfg=BoolQ_infer_cfg,
eval_cfg=BoolQ_eval_cfg,
)
]

View File

@@ -0,0 +1,57 @@
from mmengine.config import read_base
with read_base():
from ..mmlu.mmlu_ppl_ac766d import mmlu_datasets
from ..ceval.ceval_ppl_578f8d import ceval_datasets
from ..agieval.agieval_mixed_2f14ad import agieval_datasets
from ..GaokaoBench.GaokaoBench_mixed_f2038e import GaokaoBench_datasets
from ..bbh.bbh_gen_5b92b0 import bbh_datasets
from ..humaneval.humaneval_gen_a82cae import humaneval_datasets
from ..mbpp.mbpp_gen_1e1056 import mbpp_datasets
from ..CLUE_C3.CLUE_C3_ppl_e24a31 import C3_datasets
from ..CLUE_CMRC.CLUE_CMRC_gen_1bd3c8 import CMRC_datasets
from ..CLUE_DRCD.CLUE_DRCD_gen_1bd3c8 import DRCD_datasets
from ..CLUE_afqmc.CLUE_afqmc_ppl_6507d7 import afqmc_datasets
from ..CLUE_cmnli.CLUE_cmnli_ppl_fdc6de import cmnli_datasets
from ..CLUE_ocnli.CLUE_ocnli_ppl_fdc6de import ocnli_datasets
from ..FewCLUE_bustm.FewCLUE_bustm_ppl_e53034 import bustm_datasets
from ..FewCLUE_chid.FewCLUE_chid_ppl_8f2872 import chid_datasets
from ..FewCLUE_cluewsc.FewCLUE_cluewsc_ppl_4284a0 import cluewsc_datasets
from ..FewCLUE_csl.FewCLUE_csl_ppl_841b62 import csl_datasets
from ..FewCLUE_eprstmt.FewCLUE_eprstmt_ppl_f1e631 import eprstmt_datasets
from ..FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_ppl_c08300 import ocnli_fc_datasets
from ..FewCLUE_tnews.FewCLUE_tnews_ppl_d10e8a import tnews_datasets
from ..lcsts.lcsts_gen_8ee1fe import lcsts_datasets
from ..lambada.lambada_gen_217e11 import lambada_datasets
from ..storycloze.storycloze_ppl_496661 import storycloze_datasets
from ..SuperGLUE_AX_b.SuperGLUE_AX_b_ppl_6db806 import AX_b_datasets
from ..SuperGLUE_AX_g.SuperGLUE_AX_g_ppl_66caf3 import AX_g_datasets
from ..SuperGLUE_BoolQ.SuperGLUE_BoolQ_ppl_314797 import BoolQ_datasets
from ..SuperGLUE_CB.SuperGLUE_CB_ppl_0143fe import CB_datasets
from ..SuperGLUE_COPA.SuperGLUE_COPA_ppl_9f3618 import COPA_datasets
from ..SuperGLUE_MultiRC.SuperGLUE_MultiRC_ppl_ced824 import MultiRC_datasets
from ..SuperGLUE_RTE.SuperGLUE_RTE_ppl_66caf3 import RTE_datasets
from ..SuperGLUE_ReCoRD.SuperGLUE_ReCoRD_gen_30dea0 import ReCoRD_datasets
from ..SuperGLUE_WiC.SuperGLUE_WiC_ppl_312de9 import WiC_datasets
from ..SuperGLUE_WSC.SuperGLUE_WSC_ppl_003529 import WSC_datasets
from ..race.race_ppl_5831a0 import race_datasets
from ..Xsum.Xsum_gen_31397e import Xsum_datasets
from ..gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from ..summedits.summedits_ppl_1fbeb6 import summedits_datasets
from ..math.math_gen_265cce import math_datasets
from ..TheoremQA.TheoremQA_gen_ef26ca import TheoremQA_datasets
from ..hellaswag.hellaswag_ppl_a6e128 import hellaswag_datasets
from ..ARC_e.ARC_e_ppl_2ef631 import ARC_e_datasets
from ..ARC_c.ARC_c_ppl_2ef631 import ARC_c_datasets
from ..commonsenseqa.commonsenseqa_ppl_5545e2 import commonsenseqa_datasets
from ..piqa.piqa_ppl_0cfff2 import piqa_datasets
from ..siqa.siqa_ppl_e8d8c5 import siqa_datasets
from ..strategyqa.strategyqa_gen_1180a7 import strategyqa_datasets
from ..winogrande.winogrande_ppl_55a66e import winogrande_datasets
from ..obqa.obqa_ppl_6aac9e import obqa_datasets
from ..nq.nq_gen_0356ec import nq_datasets
from ..triviaqa.triviaqa_gen_0356ec import triviaqa_datasets
from ..flores.flores_gen_806ede import flores_datasets
from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
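The last line gathers every list defined in this module whose name ends with _datasets into one flat list. The same idiom in isolation (the two variable names below are invented):

# Module-level sketch of the aggregation idiom used above.
foo_datasets = [dict(abbr='foo')]  # hypothetical
bar_datasets = [dict(abbr='bar')]  # hypothetical

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# datasets == [{'abbr': 'foo'}, {'abbr': 'bar'}]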

View File

@@ -0,0 +1,41 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import hellaswagDataset_V3
hellaswag_reader_cfg = dict(
input_columns=['query', 'A', 'B', 'C', 'D'],
output_column='gold')
hellaswag_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
"0": dict(
round=[dict(role="HUMAN", prompt="{query} {A}")]
),
"1": dict(
round=[dict(role="HUMAN", prompt="{query} {B}")]
),
"2": dict(
round=[dict(role="HUMAN", prompt="{query} {C}")]
),
"3": dict(
round=[dict(role="HUMAN", prompt="{query} {D}")]
),
}),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
hellaswag_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
hellaswag_datasets = [
dict(
abbr='hellaswag',
type=hellaswagDataset_V3,
path='./data/hellaswag/hellaswag.jsonl',
reader_cfg=hellaswag_reader_cfg,
infer_cfg=hellaswag_infer_cfg,
eval_cfg=hellaswag_eval_cfg)
]

View File

@@ -0,0 +1,35 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HFDataset, HumanEvaluator, humaneval_postprocess
humaneval_reader_cfg = dict(
input_columns=['prompt'], output_column='task_id', train_split='test')
# TODO: allow empty output-column
humaneval_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{prompt}'),
])),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=512))
humaneval_eval_cfg = dict(
evaluator=dict(type=HumanEvaluator),
pred_role='BOT',
k=[1, 10, 100], # the parameter only for humaneval
pred_postprocessor=dict(type=humaneval_postprocess),
)
humaneval_datasets = [
dict(
type=HFDataset,
path='openai_humaneval',
reader_cfg=humaneval_reader_cfg,
infer_cfg=humaneval_infer_cfg,
eval_cfg=humaneval_eval_cfg)
]
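The k=[1, 10, 100] values are the k of the pass@k metric reported for HumanEval. For reference, a sketch of the standard unbiased estimator from the HumanEval paper follows; whether HumanEvaluator computes it in exactly this form is an assumption here:

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: n completions sampled per problem, c of them pass the tests."""
    if n - c < k:
        return 1.0
    result = 1.0
    for i in range(n - c + 1, n + 1):
        result *= 1.0 - k / i
    return 1.0 - result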

View File

@@ -0,0 +1,61 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import NaturalQuestionDataset, NQEvaluator
nq_datasets = []
for k in [0, 1, 5]:
nq_reader_cfg = dict(
input_columns=['question'], output_column='answer', train_split='dev')
if k == 0:
nq_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='Answer these questions, your answer should be as simple as possible, start your answer with the prompt \'The answer is \'.\nQ: {question}?'),
dict(role='BOT', prompt='A:'),
]
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=50)
)
else:
nq_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='Answer the question, your answer should be as simple as possible, start your answer with the prompt \'The answer is \'.\nQ: {question}?'),
dict(role='BOT', prompt='A: The answer is {answer}.\n'),
]
),
),
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin="</E>",
round=[
dict(role='HUMAN', prompt='Answer the question, your answer should be as simple as possible, start your answer with the prompt \'The answer is \'.\nQ: {question}?'),
dict(role='BOT', prompt='A:'),
]
),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
)
nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")
nq_datasets.append(
dict(
type=NaturalQuestionDataset,
abbr='nq' if k == 0 else f'nq_{k}shot',
path='./data/nq/',
reader_cfg=nq_reader_cfg,
infer_cfg=nq_infer_cfg,
eval_cfg=nq_eval_cfg)
)
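For k > 0, FixKRetriever with fix_id_list=list(range(k)) always takes the first k items of the train split as in-context examples; each is rendered with ice_template and the concatenation is substituted for the </E> ice_token at the begin position of prompt_template. A rough sketch of a 1-shot prompt assembled this way (the example Q/A pair and the query are invented; the real rendering is done by OpenCompass):

instruction = ("Answer the question, your answer should be as simple as possible, "
               "start your answer with the prompt 'The answer is '.\n")

# one in-context example rendered from ice_template
ice = instruction + "Q: who wrote the declaration of independence?\nA: The answer is Thomas Jefferson.\n"

# prompt_template with its </E> ice_token resolved, then filled with the test question
template = "</E>" + instruction + "Q: {question}?\nA:"
prompt = template.replace("</E>", ice).format(question="what is the capital of France")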

View File

@@ -0,0 +1,43 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import OBQADataset_V2
obqa_reader_cfg = dict(
input_columns=['question_stem', 'A', 'B', 'C', 'D', 'fact1'],
output_column="answerKey"
)
obqa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
ans: dict(
round=[
dict(
role="HUMAN",
prompt="We know the fact that {fact1}.\nQuestion: {question_stem}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n"
),
dict(role="BOT", prompt=f"Answer: {ans}"),
], )
for ans in ['A', 'B', 'C', 'D']
}
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer),
)
obqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
obqa_datasets = [
dict(
abbr='openbookqa_fact',
type=OBQADataset_V2,
path='openbookqa',
name='additional',
split='test',
reader_cfg=obqa_reader_cfg,
infer_cfg=obqa_infer_cfg,
eval_cfg=obqa_eval_cfg,
),
]

View File

@@ -0,0 +1,37 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import piqaDataset_V3
piqa_reader_cfg = dict(
input_columns=['goal', 'sol1', 'sol2'],
output_column='label',
test_split='validation')
piqa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
"0": dict(
round=[dict(role="HUMAN", prompt="{goal} {sol1}")]
),
"1": dict(
round=[dict(role="HUMAN", prompt="{goal} {sol2}")]
),
}
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
piqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
piqa_datasets = [
dict(
abbr='piqa',
type=piqaDataset_V3,
path='piqa',
reader_cfg=piqa_reader_cfg,
infer_cfg=piqa_infer_cfg,
eval_cfg=piqa_eval_cfg)
]

View File

@@ -0,0 +1,45 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import RaceDataset
race_reader_cfg = dict(
input_columns=['article', 'question', 'A', 'B', 'C', 'D'],
output_column='answer')
race_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
ans: dict(
round=[
dict(role="HUMAN", prompt="Article:\n{article}\nQuestion:\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}"),
dict(role="BOT", prompt=f'Answer: {ans}'),
]
)
for ans in ['A', 'B', 'C', 'D']
}),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
race_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
race_datasets = [
dict(
type=RaceDataset,
abbr='race-middle',
path='race',
name='middle',
reader_cfg=race_reader_cfg,
infer_cfg=race_infer_cfg,
eval_cfg=race_eval_cfg),
dict(
type=RaceDataset,
abbr='race-high',
path='race',
name='high',
reader_cfg=race_reader_cfg,
infer_cfg=race_infer_cfg,
eval_cfg=race_eval_cfg)
]

View File

@@ -0,0 +1,45 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import HFDataset
siqa_reader_cfg = dict(
input_columns=['context', 'question', 'answerA', 'answerB', 'answerC'],
output_column='label',
test_split='validation')
siqa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
"1":
dict(round=[
dict(role='HUMAN', prompt="{context}\nQuestion: {question}\nA. {answerA}\nB. {answerB}\nC. {answerC}"),
dict(role='BOT', prompt="Answer: A")
]),
"2":
dict(round=[
dict(role='HUMAN', prompt="{context}\nQuestion: {question}\nA. {answerA}\nB. {answerB}\nC. {answerC}"),
dict(role='BOT', prompt="Answer: B")
]),
"3":
dict(round=[
dict(role='HUMAN', prompt="{context}\nQuestion: {question}\nA. {answerA}\nB. {answerB}\nC. {answerC}"),
dict(role='BOT', prompt="Answer: C")
]),
}),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer))
siqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
siqa_datasets = [
dict(
abbr="siqa",
type=HFDataset,
path='social_i_qa',
reader_cfg=siqa_reader_cfg,
infer_cfg=siqa_infer_cfg,
eval_cfg=siqa_eval_cfg)
]

View File

@@ -0,0 +1,62 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import TriviaQADataset, TriviaQAEvaluator
triviaqa_datasets = []
for k in [0, 1, 5]:
triviaqa_reader_cfg = dict(
input_columns=['question'], output_column='answer', train_split='test', test_split='dev')
if k == 0:
triviaqa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='Answer these questions, your answer should be as simple as possible, start your answer with the prompt \'The answer is \'.\nQ: {question}?'),
dict(role='BOT', prompt='A:'),
]
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=50)
)
else:
triviaqa_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='Answer the question, your answer should be as simple as possible, start your answer with the prompt \'The answer is \'.\nQ: {question}?'),
dict(role='BOT', prompt='A: The answer is {answer}.\n'),
]
),
),
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin="</E>",
round=[
dict(role='HUMAN', prompt='Answer the question, your answer should be as simple as possible, start your answer with the prompt \'The answer is \'.\nQ: {question}?'),
dict(role='BOT', prompt='A:'),
]
),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
)
triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT")
triviaqa_datasets.append(
dict(
type=TriviaQADataset,
abbr='triviaqa' if k == 0 else f'triviaqa_{k}shot',
path='./data/triviaqa/',
reader_cfg=triviaqa_reader_cfg,
infer_cfg=triviaqa_infer_cfg,
eval_cfg=triviaqa_eval_cfg)
)

View File

@@ -1,8 +1,7 @@
from mmengine.config import read_base
with read_base():
-    from .datasets.piqa.piqa_ppl import piqa_datasets
-    from .datasets.siqa.siqa_gen import siqa_datasets
+    from .datasets.collections.base_medium_llama import piqa_datasets, siqa_datasets
from .models.hf_llama_7b import models

View File

@@ -12,7 +12,6 @@ class BoolQDataset(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs)
def preprocess(example):
@@ -20,7 +19,6 @@ class BoolQDataset(BaseDataset):
example['answer'] = 1
else:
example['answer'] = 0
return example
dataset = dataset.map(preprocess)
@@ -39,3 +37,20 @@ class BoolQDataset_V2(BaseDataset):
line['label'] = {'true': 'A', 'false': 'B'}[line['label']]
dataset.append(line)
return Dataset.from_list(dataset)
@LOAD_DATASET.register_module()
class BoolQDataset_V3(BaseDataset):
@staticmethod
def load(path):
dataset = []
with open(path, 'r') as f:
for line in f:
line = json.loads(line)
line['passage'] = ' -- '.join(
line['passage'].split(' -- ')[1:])
line['question'] = line['question'][0].upper(
) + line['question'][1:]
dataset.append(line)
return Dataset.from_list(dataset)
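The two transforms above drop the title prefix from the passage (everything up to the first ' -- ') and capitalize the first letter of the question. A hypothetical record and the effect:

raw = {"question": "is the sky blue on a clear day",
       "passage": "Sky -- The daytime sky appears blue because of Rayleigh scattering.",
       "label": "true"}

passage = ' -- '.join(raw['passage'].split(' -- ')[1:])
# 'The daytime sky appears blue because of Rayleigh scattering.'
question = raw['question'][0].upper() + raw['question'][1:]
# 'Is the sky blue on a clear day'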

View File

@@ -1,4 +1,6 @@
from datasets import load_dataset
import json
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
@@ -39,3 +41,24 @@ class hellaswagDataset_V2(BaseDataset):
dataset = dataset.map(preprocess).remove_columns(['endings'])
return dataset
@LOAD_DATASET.register_module()
class hellaswagDataset_V3(BaseDataset):
@staticmethod
def load(path):
dataset = []
with open(path, 'r') as f:
for line in f:
data = json.loads(line)
dataset.append({
'query': data['query'],
'A': data['choices'][0],
'B': data['choices'][1],
'C': data['choices'][2],
'D': data['choices'][3],
'gold': data['gold'],
})
dataset = Dataset.from_list(dataset)
return dataset

View File

@@ -19,3 +19,23 @@ class OBQADataset(BaseDataset):
dataset = dataset.map(pre_process).remove_columns(['id', 'choices'])
return dataset
@LOAD_DATASET.register_module()
class OBQADataset_V2(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs)
def pre_process(example):
example['A'] = example['choices']['text'][0]
example['B'] = example['choices']['text'][1]
example['C'] = example['choices']['text'][2]
example['D'] = example['choices']['text'][3]
if not example['question_stem'].endswith('?'):
example['question_stem'] += ' what?'
return example
dataset = dataset.map(pre_process).remove_columns(['id', 'choices'])
return dataset
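The preprocessing above flattens the four choices into the A/B/C/D columns used by the prompt template and, when question_stem is a sentence fragment rather than a question, appends ' what?' so the prompt reads as a question. A hypothetical stem and the result:

example = {"question_stem": "Frozen water is also called",
           "choices": {"text": ["steam", "ice", "salt", "sand"]}}  # hypothetical item
if not example["question_stem"].endswith("?"):
    example["question_stem"] += " what?"
# 'Frozen water is also called what?'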

View File

@@ -23,3 +23,28 @@ class piqaDataset_V2(BaseDataset):
dataset = dataset.map(preprocess)
return dataset
@LOAD_DATASET.register_module()
class piqaDataset_V3(BaseDataset):
@staticmethod
def load(**kwargs):
dataset = load_dataset(**kwargs)
def preprocess(example):
example['goal'] = example['goal'][0].upper() + example['goal'][1:]
if example['goal'].endswith('?') or example['goal'].endswith('.'):
example['sol1'] = example['sol1'][0].upper(
) + example['sol1'][1:]
example['sol2'] = example['sol2'][0].upper(
) + example['sol2'][1:]
else:
example['sol1'] = example['sol1'][0].lower(
) + example['sol1'][1:]
example['sol2'] = example['sol2'][0].lower(
) + example['sol2'][1:]
return example
dataset = dataset.map(preprocess)
return dataset
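The casing rule above first capitalizes the goal, then adjusts the two solutions to match: if the goal ends with '?' or '.', the solutions are capitalized as standalone sentences; otherwise they are lower-cased so the rendered "{goal} {sol1}" prompt reads as one continuous sentence. A hypothetical record and the effect:

example = {"goal": "to keep a cut apple from browning,",
           "sol1": "Coat the slices with lemon juice.",
           "sol2": "Coat the slices with motor oil."}  # hypothetical item

example["goal"] = example["goal"][0].upper() + example["goal"][1:]
# goal does not end with '?' or '.', so the solutions become continuations:
example["sol1"] = example["sol1"][0].lower() + example["sol1"][1:]
example["sol2"] = example["sol2"][0].lower() + example["sol2"][1:]
# rendered prompt: 'To keep a cut apple from browning, coat the slices with lemon juice.'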