[Feature] Add Xiezhi, SQuAD2.0, and ANLI (#101)

* add Xiezhi, SQuAD2.0, ANLI; update WSC

* update

* update

* update doc string
Leymore 2023-08-10 14:04:18 +08:00 committed by GitHub
parent a205629ff3
commit e7fc54baf1
16 changed files with 468 additions and 1 deletion

View File

@@ -1,4 +1,4 @@
from mmengine.config import read_base
with read_base():
-from .SuperGLUE_WSC_ppl_d0f531 import WSC_datasets # noqa: F401, F403
+from .SuperGLUE_WSC_ppl_cbf31c import WSC_datasets # noqa: F401, F403

View File

@@ -0,0 +1,49 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import WSCDataset_V3
WSC_reader_cfg = dict(
input_columns=["span1", "span2", "text"],
output_column="label",
)
WSC_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
'A':
dict(round=[
dict(
role="HUMAN",
prompt="Passage: {text}\nDoes the pronoun # {span2} # refer to * {span1} *?\nA. Yes\nB. No\nAnseer: "
),
dict(role='BOT', prompt='A'),
]),
'B':
dict(round=[
dict(
role="HUMAN",
prompt="Passage: {text}\nDoes the pronoun # {span2} # refer to * {span1} *?\nA. Yes\nB. No\nAnseer: "
),
dict(role='BOT', prompt='B'),
]),
},
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer),
)
WSC_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
WSC_datasets = [
dict(
abbr="WSC",
type=WSCDataset_V3,
path="./data/SuperGLUE/WSC/val.jsonl",
reader_cfg=WSC_reader_cfg,
infer_cfg=WSC_infer_cfg,
eval_cfg=WSC_eval_cfg,
)
]
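Each key in the template above ('A' / 'B') maps a candidate answer to a complete dialogue; PPLInferencer scores every candidate with the model and AccEvaluator compares the lowest-perplexity key against the gold label. A minimal standalone sketch of that selection step, with ppl_of as a hypothetical stand-in for the model's perplexity computation:

item = dict(text="The trophy would not fit in the suitcase because it was too big.",
            span1="the trophy", span2="it")
question = ("Passage: {text}\nDoes the pronoun # {span2} # refer to * {span1} *?\n"
            "A. Yes\nB. No\nAnswer: ").format(**item)
candidates = {label: question + label for label in ("A", "B")}

def ppl_of(prompt: str) -> float:
    """Hypothetical helper: the perplexity the evaluated model assigns to `prompt`."""
    raise NotImplementedError

# The label whose fully rendered prompt gets the lowest perplexity wins:
# prediction = min(candidates, key=lambda label: ppl_of(candidates[label]))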

View File

@@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .anli_gen_fc7328 import anli_datasets # noqa: F401, F403

View File

@@ -0,0 +1,42 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AnliDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess
anli_datasets = []
for _split in ['R1', 'R2', 'R3']:
anli_reader_cfg = dict(
input_columns=["context", "hypothesis"],
output_column="label",
)
anli_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role="HUMAN", prompt="{context}\n{hypothesis}\nQuestion: What is the relation between the two sentences?\nA. Contradiction\nB. Entailment\nC. Neutral\nAnswer: "),
dict(role="BOT", prompt="{label}"),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
anli_eval_cfg = dict(evaluator=dict(type=AccEvaluator),
pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess))
anli_datasets.append(
dict(
type=AnliDataset,
abbr=f"anli-{_split}",
path=f"data/anli/anli_v1.0/{_split}/dev.jsonl",
reader_cfg=anli_reader_cfg,
infer_cfg=anli_infer_cfg,
eval_cfg=anli_eval_cfg,
)
)
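With this file in place, the three ANLI splits can be pulled into an evaluation config in the usual read_base style. A minimal usage sketch (the relative import path is an assumption and depends on where the eval config sits):

from mmengine.config import read_base

with read_base():
    # assumed path relative to the eval config; adjust as needed
    from .datasets.anli.anli_gen import anli_datasets

datasets = [*anli_datasets]  # anli-R1, anli-R2, anli-R3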

View File

@@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .anli_ppl_1d290e import anli_datasets # noqa: F401, F403

View File

@@ -0,0 +1,50 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AnliDataset
anli_datasets = []
for _split in ['R1', 'R2', 'R3']:
anli_reader_cfg = dict(
input_columns=["context", "hypothesis"],
output_column="label",
)
anli_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template={
"A":
dict(round=[
dict(role="HUMAN", prompt="{context}\n{hypothesis}\What is the relation between the two sentences?"),
dict(role="BOT", prompt="Contradiction"),
]),
"B":
dict(round=[
dict(role="HUMAN", prompt="{context}\n{hypothesis}\What is the relation between the two sentences?"),
dict(role="BOT", prompt="Entailment"),
]),
"C":
dict(round=[
dict(role="HUMAN", prompt="{context}\n{hypothesis}\What is the relation between the two sentences?"),
dict(role="BOT", prompt="Neutral"),
]),
},
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer),
)
anli_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
anli_datasets.append(
dict(
type=AnliDataset,
abbr=f"anli-{_split}",
path=f"data/anli/anli_v1.0/{_split}/dev.jsonl",
reader_cfg=anli_reader_cfg,
infer_cfg=anli_infer_cfg,
eval_cfg=anli_eval_cfg,
)
)

View File

@@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .squad20_gen_1710bc import squad20_datasets # noqa: F401, F403

View File

@@ -0,0 +1,32 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import SQuAD20Dataset, SQuAD20Evaluator
squad20_reader_cfg = dict(
input_columns=['context', 'question'],
output_column='answers')
squad20_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{context}\nAccording to the above passage, answer the following question. If it is impossible to answer according to the passage, answer `impossible to answer`:\nQuestion: {question}'),
dict(role='BOT', prompt='Answer:'),
], )),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=50))
squad20_eval_cfg = dict(
evaluator=dict(type=SQuAD20Evaluator), pred_role='BOT')
squad20_datasets = [
dict(
type=SQuAD20Dataset,
abbr='squad2.0',
path='./data/SQuAD2.0/dev-v2.0.json',
reader_cfg=squad20_reader_cfg,
infer_cfg=squad20_infer_cfg,
eval_cfg=squad20_eval_cfg)
]

View File

@@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .xiezhi_gen_b86cf5 import xiezhi_datasets # noqa: F401, F403

View File

@@ -0,0 +1,50 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import XiezhiDataset, XiezhiRetriever
from opencompass.utils.text_postprocessors import first_capital_postprocess
xiezhi_datasets = []
for split in ["spec_eng", "spec_chn", "inter_eng", "inter_chn"]:
if 'chn' in split:
q_hint, a_hint = "题目", "答案"
else:
q_hint, a_hint = "Question", "Answer"
xiezhi_reader_cfg = dict(
input_columns=["question", "A", "B", "C", "D", "labels"],
output_column="answer",
train_split="train",
test_split='test',
)
xiezhi_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template=dict(
begin="</E>",
round=[
dict(role="HUMAN", prompt=f"{q_hint}: {{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\n{a_hint}: "),
dict(role="BOT", prompt="{answer}"),
]
),
ice_token="</E>",
),
retriever=dict(type=XiezhiRetriever, ice_num=3),
inferencer=dict(type=GenInferencer),
)
xiezhi_eval_cfg = dict(evaluator=dict(type=AccEvaluator),
pred_role="BOT",
pred_postprocessor=dict(type=first_capital_postprocess))
xiezhi_datasets.append(
dict(
type=XiezhiDataset,
abbr=f"xiezhi-{split}",
path="./data/xiezhi/",
name="xiezhi_" + split,
reader_cfg=xiezhi_reader_cfg,
infer_cfg=xiezhi_infer_cfg,
eval_cfg=xiezhi_eval_cfg,
))
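The ice_template / ice_token pair controls the few-shot assembly: XiezhiRetriever picks up to three label-matched training items, each is rendered with the same round template, and the concatenation replaces the `</E>` token ahead of the test question. A rough layout sketch for an English split (illustrative only; the actual rendering is handled by OpenCompass' prompt machinery):

ice = (
    "Question: <train q1>\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer: C\n"
    "Question: <train q2>\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer: A\n"
    "Question: <train q3>\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer: B\n"
)
test_turn = "Question: <test q>\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer: "
full_prompt = ice + test_turn  # `</E>` marks where the in-context examples are spliced in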

View File

@@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .xiezhi_ppl_ea6bd7 import xiezhi_datasets # noqa: F401, F403

View File

@@ -0,0 +1,49 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import XiezhiDataset, XiezhiRetriever
xiezhi_datasets = []
for split in ["spec_eng", "spec_chn", "inter_eng", "inter_chn"]:
if 'chn' in split:
q_hint, a_hint = "题目", "答案"
else:
q_hint, a_hint = "Question", "Answer"
xiezhi_reader_cfg = dict(
input_columns=["question", "A", "B", "C", "D", "labels"],
output_column="answer",
train_split="train",
test_split='test',
)
xiezhi_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template={
answer: dict(
begin="</E>",
round=[
dict(role="HUMAN", prompt=f"{q_hint}: {{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}"),
dict(role="BOT", prompt=f"{a_hint}: {answer}"),
])
for answer in ["A", "B", "C", "D"]
},
ice_token="</E>",
),
retriever=dict(type=XiezhiRetriever, ice_num=3),
inferencer=dict(type=PPLInferencer),
)
xiezhi_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
xiezhi_datasets.append(
dict(
type=XiezhiDataset,
abbr=f"xiezhi-{split}",
path="./data/xiezhi/",
name="xiezhi_" + split,
reader_cfg=xiezhi_reader_cfg,
infer_cfg=xiezhi_infer_cfg,
eval_cfg=xiezhi_eval_cfg,
))

View File

@@ -1,5 +1,6 @@
from .afqmcd import * # noqa: F401, F403
from .agieval import * # noqa: F401, F403
from .anli import AnliDataset # noqa: F401, F403
from .arc import * # noqa: F401, F403
from .ax import * # noqa: F401, F403
from .bbh import * # noqa: F401, F403
@@ -48,6 +49,7 @@ from .realtoxicprompts import * # noqa: F401, F403
from .record import * # noqa: F401, F403
from .safety import * # noqa: F401, F403
from .siqa import * # noqa: F401, F403
from .squad20 import SQuAD20Dataset, SQuAD20Evaluator # noqa: F401, F403
from .storycloze import * # noqa: F401, F403
from .strategyqa import * # noqa: F401, F403
from .summedits import * # noqa: F401, F403
@@ -63,5 +65,6 @@ from .winograd import * # noqa: F401, F403
from .winogrande import * # noqa: F401, F403
from .wsc import * # noqa: F401, F403
from .xcopa import * # noqa: F401, F403
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403
from .xlsum import * # noqa: F401, F403
from .xsum import * # noqa: F401, F403

View File

@@ -0,0 +1,18 @@
import json
from datasets import Dataset
from .base import BaseDataset
class AnliDataset(BaseDataset):
@staticmethod
def load(path: str):
dataset = []
with open(path, 'r') as f:
for line in f:
line = json.loads(line)
line['label'] = {'c': 'A', 'e': 'B', 'n': 'C'}[line['label']]
dataset.append(line)
return Dataset.from_list(dataset)
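A quick round-trip sketch of what load() expects and produces (the example line is made up; the field names follow the code above):

import json
import tempfile

line = {"context": "A cat sat on the mat.",
        "hypothesis": "There was no cat.",
        "label": "c"}
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write(json.dumps(line) + "\n")
    tmp_path = f.name

ds = AnliDataset.load(tmp_path)
print(ds[0]["label"])  # "c" (contradiction) is remapped to the choice letter "A"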

View File

@@ -0,0 +1,66 @@
import json
from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.utils.text_postprocessors import general_postprocess
from .base import BaseDataset
class SQuAD20Dataset(BaseDataset):
@staticmethod
def load(path: str):
with open(path, 'r') as f:
data = json.load(f)
data = data['data']
dataset = []
for article in data:
for paragraph in article['paragraphs']:
for qa in paragraph['qas']:
is_impossible = qa['is_impossible']
if not is_impossible:
answers = list(
set([answer['text'] for answer in qa['answers']]))
else:
answers = list(
set([
answer['text']
for answer in qa['plausible_answers']
]))
answers += ['impossible to answer']
item = {
'context': paragraph['context'],
'question': qa['question'],
'answers': answers,
}
dataset.append(item)
dataset = Dataset.from_list(dataset)
return dataset
class SQuAD20Evaluator(BaseEvaluator):
def score(self, predictions, references):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
processed_predictions = []
for prediction in predictions:
prediction = prediction.split('\n')[0].lower()
if 'answer is' in prediction:
prediction = prediction.split('answer is')[-1]
prediction = general_postprocess(prediction)
processed_predictions.append(prediction)
processed_answers = [[general_postprocess(j).lower() for j in i]
for i in references]
cnt = 0
for pred, cand_ans in zip(processed_predictions, processed_answers):
cnt += int(any([cand == pred for cand in cand_ans]))
score = cnt / len(predictions) * 100
return {'score': score}
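A minimal usage sketch of the evaluator (assuming general_postprocess performs SQuAD-style normalization, i.e. lowercasing plus punctuation and article stripping):

evaluator = SQuAD20Evaluator()
result = evaluator.score(
    predictions=["The answer is Paris.\nIgnored follow-up text"],
    references=[["Paris", "impossible to answer"]],
)
print(result)  # expected {'score': 100.0}: text after "answer is" is kept and normalized to "paris"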

View File

@@ -0,0 +1,88 @@
import json
import os.path as osp
from typing import Optional
from datasets import Dataset, DatasetDict
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from .base import BaseDataset
class XiezhiDataset(BaseDataset):
@staticmethod
def load(path: str, name: str):
dataset = DatasetDict()
filename = osp.join(path, name, 'xiezhi.v1.json')
if 'chn' in name:
train_filename = osp.join(path, 'xiezhi_train_chn',
'xiezhi.v1.json')
else:
train_filename = osp.join(path, 'xiezhi_train_eng',
'xiezhi.v1.json')
for split, filename in [['train', train_filename], ['test', filename]]:
raw_data = []
with open(filename, encoding='utf-8') as f:
for line in f:
data = json.loads(line)
if data['options'].endswith("\"\n"):
data['options'] = data['options'][:-2]
options = data['options'].split('\n')
if len(options) != 4:
continue
answer = 'ABCD'[options.index(data['answer'])]
# The longer the label, the more fine-grained the concept
labels = sorted(
data['labels' if 'chn' in name else 'label'],
key=lambda x: len(x),
reverse=True)
raw_data.append({
'question': data['question'],
'A': options[0],
'B': options[1],
'C': options[2],
'D': options[3],
'labels': labels,
'answer': answer,
})
dataset[split] = Dataset.from_list(raw_data)
return dataset
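# Illustration of one raw line load() can parse (a made-up example; the field
# names follow the parsing code above -- the '*_eng' splits use 'label'
# instead of 'labels'):
#   {"question": "Which planet is largest?",
#    "options": "Earth\nJupiter\nMars\nVenus",
#    "answer": "Jupiter",
#    "labels": ["astronomy", "science"]}
# The resulting dataset row:
#   {"question": "Which planet is largest?", "A": "Earth", "B": "Jupiter",
#    "C": "Mars", "D": "Venus", "labels": ["astronomy", "science"], "answer": "B"}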
class XiezhiRetriever(BaseRetriever):
def __init__(self,
dataset,
ice_separator: Optional[str] = '\n',
ice_eos_token: Optional[str] = '\n',
ice_num: Optional[int] = 1) -> None:
super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
def retrieve(self):
"""Retrieve in-context examples for each test case.
For each in-context example there is a list of labels indicating the
categories the example belongs to, and each test case carries such a
list as well. This retriever retrieves the in-context examples that
share at least one label with the test case.
"""
label2indice = {}
for index, item in enumerate(self.index_ds):
for label in item['labels']:
if label not in label2indice:
label2indice[label] = []
label2indice[label].append(index)
rtr_idx_list = []
for index in trange(len(self.test_ds),
disable=not self.is_main_process):
id_list = []
for label in self.test_ds[index]['labels']:
if len(id_list) < self.ice_num:
id_list += label2indice[label]
else:
break
rtr_idx_list.append(id_list[:self.ice_num])
return rtr_idx_list
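A standalone toy run of the same retrieval logic, to make the label2indice / id_list flow concrete (toy data, not taken from the dataset):

index_ds = [
    {"labels": ["physics", "science"]},   # index 0
    {"labels": ["history"]},              # index 1
    {"labels": ["science"]},              # index 2
]

label2indice = {}
for index, item in enumerate(index_ds):
    for label in item["labels"]:
        label2indice.setdefault(label, []).append(index)
# label2indice == {"physics": [0], "science": [0, 2], "history": [1]}

ice_num = 2
test_labels = ["science", "history"]  # labels of one test case
id_list = []
for label in test_labels:
    if len(id_list) < ice_num:
        id_list += label2indice[label]
    else:
        break
print(id_list[:ice_num])  # [0, 2]: both retrieved items share the "science" label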