mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

[Feature] Add Xiezhi SQuAD2.0 ANLI (#101)

* add Xiezhi SQuAD2.0 ANLI; update WSC
* update
* update
* update doc string

This commit is contained in:
parent a205629ff3
commit e7fc54baf1
configs/datasets/SuperGLUE_WSC/SuperGLUE_WSC_ppl.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base

 with read_base():
-    from .SuperGLUE_WSC_ppl_d0f531 import WSC_datasets  # noqa: F401, F403
+    from .SuperGLUE_WSC_ppl_cbf31c import WSC_datasets  # noqa: F401, F403
configs/datasets/SuperGLUE_WSC/SuperGLUE_WSC_ppl_cbf31c.py  (new file, 49 lines)
@@ -0,0 +1,49 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import WSCDataset_V3

WSC_reader_cfg = dict(
    input_columns=["span1", "span2", "text"],
    output_column="label",
)

WSC_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template={
            'A':
            dict(round=[
                dict(
                    role="HUMAN",
                    prompt="Passage: {text}\nDoes the pronoun # {span2} # refer to * {span1} *?\nA. Yes\nB. No\nAnswer: "
                ),
                dict(role='BOT', prompt='A'),
            ]),
            'B':
            dict(round=[
                dict(
                    role="HUMAN",
                    prompt="Passage: {text}\nDoes the pronoun # {span2} # refer to * {span1} *?\nA. Yes\nB. No\nAnswer: "
                ),
                dict(role='BOT', prompt='B'),
            ]),
        },
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=PPLInferencer),
)

WSC_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

WSC_datasets = [
    dict(
        abbr="WSC",
        type=WSCDataset_V3,
        path="./data/SuperGLUE/WSC/val.jsonl",
        reader_cfg=WSC_reader_cfg,
        infer_cfg=WSC_infer_cfg,
        eval_cfg=WSC_eval_cfg,
    )
]
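For orientation: with PPLInferencer nothing is generated. Both templates above are filled with the same example, the model scores each completed prompt, and the key of the lowest-perplexity candidate becomes the prediction. A minimal sketch of that selection step, where `avg_nll` is a hypothetical stand-in for the model's average negative log-likelihood call:

import math

def pick_label(candidates, avg_nll):
    # candidates: {'A': rendered_prompt_a, 'B': rendered_prompt_b}
    # avg_nll: hypothetical callable returning the average negative
    # log-likelihood of a rendered prompt under the model
    ppl = {label: math.exp(avg_nll(text)) for label, text in candidates.items()}
    return min(ppl, key=ppl.get)  # lowest perplexity wins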
configs/datasets/anli/anli_gen.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .anli_gen_fc7328 import anli_datasets  # noqa: F401, F403
configs/datasets/anli/anli_gen_fc7328.py  (new file, 42 lines)
@@ -0,0 +1,42 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AnliDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess

anli_datasets = []
for _split in ['R1', 'R2', 'R3']:
    anli_reader_cfg = dict(
        input_columns=["context", "hypothesis"],
        output_column="label",
    )

    anli_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[
                    dict(role="HUMAN", prompt="{context}\n{hypothesis}\nQuestion: What is the relation between the two sentences?\nA. Contradiction\nB. Entailment\nC. Neutral\nAnswer: "),
                    dict(role="BOT", prompt="{label}"),
                ]
            ),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )

    anli_eval_cfg = dict(evaluator=dict(type=AccEvaluator),
                         pred_role="BOT",
                         pred_postprocessor=dict(type=first_capital_postprocess))

    anli_datasets.append(
        dict(
            type=AnliDataset,
            abbr=f"anli-{_split}",
            path=f"data/anli/anli_v1.0/{_split}/dev.jsonl",
            reader_cfg=anli_reader_cfg,
            infer_cfg=anli_infer_cfg,
            eval_cfg=anli_eval_cfg,
        )
    )
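Since GenInferencer produces free text, this config routes predictions through `first_capital_postprocess` before accuracy is computed. Its effective behaviour is roughly this sketch (an approximation for illustration, not the library's exact implementation):

import re

def first_capital(text: str) -> str:
    # first_capital('answer: B. Entailment') -> 'B'; '' if no capital found
    match = re.search(r'[A-Z]', text)
    return match.group(0) if match else ''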
configs/datasets/anli/anli_ppl.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .anli_ppl_1d290e import anli_datasets  # noqa: F401, F403
configs/datasets/anli/anli_ppl_1d290e.py  (new file, 50 lines)
@@ -0,0 +1,50 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import AnliDataset

anli_datasets = []
for _split in ['R1', 'R2', 'R3']:
    anli_reader_cfg = dict(
        input_columns=["context", "hypothesis"],
        output_column="label",
    )

    anli_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template={
                "A":
                dict(round=[
                    dict(role="HUMAN", prompt="{context}\n{hypothesis}\nWhat is the relation between the two sentences?"),
                    dict(role="BOT", prompt="Contradiction"),
                ]),
                "B":
                dict(round=[
                    dict(role="HUMAN", prompt="{context}\n{hypothesis}\nWhat is the relation between the two sentences?"),
                    dict(role="BOT", prompt="Entailment"),
                ]),
                "C":
                dict(round=[
                    dict(role="HUMAN", prompt="{context}\n{hypothesis}\nWhat is the relation between the two sentences?"),
                    dict(role="BOT", prompt="Neutral"),
                ]),
            },
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=PPLInferencer),
    )

    anli_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

    anli_datasets.append(
        dict(
            type=AnliDataset,
            abbr=f"anli-{_split}",
            path=f"data/anli/anli_v1.0/{_split}/dev.jsonl",
            reader_cfg=anli_reader_cfg,
            infer_cfg=anli_infer_cfg,
            eval_cfg=anli_eval_cfg,
        )
    )
configs/datasets/squad20/squad20_gen.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .squad20_gen_1710bc import squad20_datasets  # noqa: F401, F403
configs/datasets/squad20/squad20_gen_1710bc.py  (new file, 32 lines)
@@ -0,0 +1,32 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import SQuAD20Dataset, SQuAD20Evaluator

squad20_reader_cfg = dict(
    input_columns=['context', 'question'],
    output_column='answers')

squad20_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='{context}\nAccording to the above passage, answer the following question. If it is impossible to answer according to the passage, answer `impossible to answer`:\nQuestion: {question}'),
                dict(role='BOT', prompt='Answer:'),
            ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=50))

squad20_eval_cfg = dict(
    evaluator=dict(type=SQuAD20Evaluator), pred_role='BOT')

squad20_datasets = [
    dict(
        type=SQuAD20Dataset,
        abbr='squad2.0',
        path='./data/SQuAD2.0/dev-v2.0.json',
        reader_cfg=squad20_reader_cfg,
        infer_cfg=squad20_infer_cfg,
        eval_cfg=squad20_eval_cfg)
]
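To make the template concrete, the HUMAN turn for one record can be rendered directly with `str.format` (the example record below is made up; real rows come from `SQuAD20Dataset.load`):

template = ('{context}\nAccording to the above passage, answer the following '
            'question. If it is impossible to answer according to the passage, '
            'answer `impossible to answer`:\nQuestion: {question}')

example = {  # hypothetical record in the shape the reader_cfg expects
    'context': 'Normandy is a region in northern France.',
    'question': 'In what country is Normandy located?',
}
print(template.format(**example))
# The BOT turn then supplies the 'Answer:' prefix, and the model
# continues from there, capped at max_out_len=50 tokens.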
configs/datasets/xiezhi/xiezhi_gen.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .xiezhi_gen_b86cf5 import xiezhi_datasets  # noqa: F401, F403
configs/datasets/xiezhi/xiezhi_gen_b86cf5.py  (new file, 50 lines)
@@ -0,0 +1,50 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import XiezhiDataset, XiezhiRetriever
from opencompass.utils.text_postprocessors import first_capital_postprocess

xiezhi_datasets = []

for split in ["spec_eng", "spec_chn", "inter_eng", "inter_chn"]:
    if 'chn' in split:
        q_hint, a_hint = "题目", "答案"
    else:
        q_hint, a_hint = "Question", "Answer"

    xiezhi_reader_cfg = dict(
        input_columns=["question", "A", "B", "C", "D", "labels"],
        output_column="answer",
        train_split="train",
        test_split='test',
    )
    xiezhi_infer_cfg = dict(
        ice_template=dict(
            type=PromptTemplate,
            template=dict(
                begin="</E>",
                round=[
                    dict(role="HUMAN", prompt=f"{q_hint}: {{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\n{a_hint}: "),
                    dict(role="BOT", prompt="{answer}"),
                ]
            ),
            ice_token="</E>",
        ),
        retriever=dict(type=XiezhiRetriever, ice_num=3),
        inferencer=dict(type=GenInferencer),
    )

    xiezhi_eval_cfg = dict(evaluator=dict(type=AccEvaluator),
                           pred_role="BOT",
                           pred_postprocessor=dict(type=first_capital_postprocess))

    xiezhi_datasets.append(
        dict(
            type=XiezhiDataset,
            abbr=f"xiezhi-{split}",
            path="./data/xiezhi/",
            name="xiezhi_" + split,
            reader_cfg=xiezhi_reader_cfg,
            infer_cfg=xiezhi_infer_cfg,
            eval_cfg=xiezhi_eval_cfg,
        ))
configs/datasets/xiezhi/xiezhi_ppl.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .xiezhi_ppl_ea6bd7 import xiezhi_datasets  # noqa: F401, F403
configs/datasets/xiezhi/xiezhi_ppl_ea6bd7.py  (new file, 49 lines)
@@ -0,0 +1,49 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import XiezhiDataset, XiezhiRetriever

xiezhi_datasets = []

for split in ["spec_eng", "spec_chn", "inter_eng", "inter_chn"]:
    if 'chn' in split:
        q_hint, a_hint = "题目", "答案"
    else:
        q_hint, a_hint = "Question", "Answer"

    xiezhi_reader_cfg = dict(
        input_columns=["question", "A", "B", "C", "D", "labels"],
        output_column="answer",
        train_split="train",
        test_split='test',
    )
    xiezhi_infer_cfg = dict(
        ice_template=dict(
            type=PromptTemplate,
            template={
                answer: dict(
                    begin="</E>",
                    round=[
                        dict(role="HUMAN", prompt=f"{q_hint}: {{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}"),
                        dict(role="BOT", prompt=f"{a_hint}: {answer}"),
                    ])
                for answer in ["A", "B", "C", "D"]
            },
            ice_token="</E>",
        ),
        retriever=dict(type=XiezhiRetriever, ice_num=3),
        inferencer=dict(type=PPLInferencer),
    )

    xiezhi_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

    xiezhi_datasets.append(
        dict(
            type=XiezhiDataset,
            abbr=f"xiezhi-{split}",
            path="./data/xiezhi/",
            name="xiezhi_" + split,
            reader_cfg=xiezhi_reader_cfg,
            infer_cfg=xiezhi_infer_cfg,
            eval_cfg=xiezhi_eval_cfg,
        ))
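Note the `template` above is a dict comprehension, so it expands into four candidate templates keyed by the answer letter; the snippet below shows the expansion in isolation (hints fixed to the English pair for brevity):

q_hint, a_hint = 'Question', 'Answer'
template = {
    answer: dict(
        begin='</E>',
        round=[
            dict(role='HUMAN',
                 prompt=f'{q_hint}: {{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}'),
            dict(role='BOT', prompt=f'{a_hint}: {answer}'),
        ])
    for answer in ['A', 'B', 'C', 'D']
}
print(sorted(template))            # ['A', 'B', 'C', 'D']
print(template['B']['round'][1])   # {'role': 'BOT', 'prompt': 'Answer: B'}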
opencompass/datasets/__init__.py
@@ -1,5 +1,6 @@
 from .afqmcd import *  # noqa: F401, F403
 from .agieval import *  # noqa: F401, F403
+from .anli import AnliDataset  # noqa: F401, F403
 from .arc import *  # noqa: F401, F403
 from .ax import *  # noqa: F401, F403
 from .bbh import *  # noqa: F401, F403
@@ -48,6 +49,7 @@ from .realtoxicprompts import *  # noqa: F401, F403
 from .record import *  # noqa: F401, F403
 from .safety import *  # noqa: F401, F403
 from .siqa import *  # noqa: F401, F403
+from .squad20 import SQuAD20Dataset, SQuAD20Evaluator  # noqa: F401, F403
 from .storycloze import *  # noqa: F401, F403
 from .strategyqa import *  # noqa: F401, F403
 from .summedits import *  # noqa: F401, F403
@@ -63,5 +65,6 @@ from .winograd import *  # noqa: F401, F403
 from .winogrande import *  # noqa: F401, F403
 from .wsc import *  # noqa: F401, F403
 from .xcopa import *  # noqa: F401, F403
+from .xiezhi import XiezhiDataset, XiezhiRetriever  # noqa: F401, F403
 from .xlsum import *  # noqa: F401, F403
 from .xsum import *  # noqa: F401, F403
opencompass/datasets/anli.py  (new file, 18 lines)
@@ -0,0 +1,18 @@
import json

from datasets import Dataset

from .base import BaseDataset


class AnliDataset(BaseDataset):

    @staticmethod
    def load(path: str):
        dataset = []
        with open(path, 'r') as f:
            for line in f:
                line = json.loads(line)
                # Map ANLI's c/e/n labels to the A/B/C choice letters
                # used by the prompt templates.
                line['label'] = {'c': 'A', 'e': 'B', 'n': 'C'}[line['label']]
                dataset.append(line)
        return Dataset.from_list(dataset)
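For example, a pared-down ANLI dev line goes through `load` like this (real v1.0 records carry extra fields, which `Dataset.from_list` keeps as extra columns):

import json

line = json.loads('{"context": "A is taller than B.", '
                  '"hypothesis": "B is taller than A.", "label": "c"}')
line['label'] = {'c': 'A', 'e': 'B', 'n': 'C'}[line['label']]
print(line['label'])  # 'A', matching the 'A. Contradiction' option in the gen prompt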
opencompass/datasets/squad20.py  (new file, 66 lines)
@@ -0,0 +1,66 @@
import json

from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.utils.text_postprocessors import general_postprocess

from .base import BaseDataset


class SQuAD20Dataset(BaseDataset):

    @staticmethod
    def load(path: str):
        with open(path, 'r') as f:
            data = json.load(f)
        data = data['data']
        dataset = []
        for article in data:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    is_impossible = qa['is_impossible']
                    if not is_impossible:
                        answers = list(
                            set([answer['text'] for answer in qa['answers']]))
                    else:
                        # Unanswerable questions accept the plausible answers
                        # as well as the literal 'impossible to answer' reply
                        # requested by the prompt.
                        answers = list(
                            set([
                                answer['text']
                                for answer in qa['plausible_answers']
                            ]))
                        answers += ['impossible to answer']
                    item = {
                        'context': paragraph['context'],
                        'question': qa['question'],
                        'answers': answers,
                    }
                    dataset.append(item)
        dataset = Dataset.from_list(dataset)
        return dataset


class SQuAD20Evaluator(BaseEvaluator):

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        processed_predictions = []
        for prediction in predictions:
            prediction = prediction.split('\n')[0].lower()
            if 'answer is' in prediction:
                prediction = prediction.split('answer is')[-1]
            prediction = general_postprocess(prediction)
            processed_predictions.append(prediction)
        processed_answers = [[general_postprocess(j).lower() for j in i]
                             for i in references]

        cnt = 0
        for pred, cand_ans in zip(processed_predictions, processed_answers):
            cnt += int(any([cand == pred for cand in cand_ans]))
        score = cnt / len(predictions) * 100

        return {'score': score}
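A small worked example of the scorer's matching behaviour (hypothetical inputs; since both sides pass through the same normalisation, identical surface forms should match):

evaluator = SQuAD20Evaluator()
predictions = [
    'Paris\nIrrelevant second line',       # only the first line is kept
    'the answer is impossible to answer',  # text after 'answer is' is kept
]
references = [['Paris'], ['impossible to answer', 'some plausible answer']]
print(evaluator.score(predictions, references))  # expected: {'score': 100.0}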
opencompass/datasets/xiezhi.py  (new file, 88 lines)
@@ -0,0 +1,88 @@
import json
import os.path as osp
from typing import Optional

from datasets import Dataset, DatasetDict
from tqdm import trange

from opencompass.openicl.icl_retriever import BaseRetriever

from .base import BaseDataset


class XiezhiDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
        dataset = DatasetDict()
        filename = osp.join(path, name, 'xiezhi.v1.json')
        if 'chn' in name:
            train_filename = osp.join(path, 'xiezhi_train_chn',
                                      'xiezhi.v1.json')
        else:
            train_filename = osp.join(path, 'xiezhi_train_eng',
                                      'xiezhi.v1.json')
        for split, filename in [['train', train_filename], ['test', filename]]:
            raw_data = []
            with open(filename, encoding='utf-8') as f:
                for line in f:
                    data = json.loads(line)
                    if data['options'].endswith("\"\n"):
                        data['options'] = data['options'][:-2]
                    options = data['options'].split('\n')
                    if len(options) != 4:
                        continue
                    answer = 'ABCD'[options.index(data['answer'])]
                    # The longer the label, the more fine-grained the concept
                    labels = sorted(
                        data['labels' if 'chn' in name else 'label'],
                        key=lambda x: len(x),
                        reverse=True)
                    raw_data.append({
                        'question': data['question'],
                        'A': options[0],
                        'B': options[1],
                        'C': options[2],
                        'D': options[3],
                        'labels': labels,
                        'answer': answer,
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset


class XiezhiRetriever(BaseRetriever):

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

    def retrieve(self):
        """Retrieve in-context examples for each test case.

        Each in-context example carries a list of labels indicating the
        categories it is related to, and each test case carries such a
        list as well. This retriever returns the in-context examples that
        share at least one label with the test case.
        """
        label2indice = {}
        for index, item in enumerate(self.index_ds):
            for label in item['labels']:
                if label not in label2indice:
                    label2indice[label] = []
                label2indice[label].append(index)
        rtr_idx_list = []
        for index in trange(len(self.test_ds),
                            disable=not self.is_main_process):
            id_list = []
            for label in self.test_ds[index]['labels']:
                if len(id_list) < self.ice_num:
                    id_list += label2indice[label]
                else:
                    break
            rtr_idx_list.append(id_list[:self.ice_num])
        return rtr_idx_list
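The retrieval logic reduces to a label-to-index lookup; a toy illustration of the same idea outside the OpenCompass plumbing:

# Toy version of the label-overlap retrieval in XiezhiRetriever.retrieve:
train = [
    {'labels': ['math']},
    {'labels': ['physics']},
    {'labels': ['math', 'logic']},
]
test_labels = ['math']
ice_num = 2

label2indice = {}
for index, item in enumerate(train):
    for label in item['labels']:
        label2indice.setdefault(label, []).append(index)

id_list = []
for label in test_labels:
    if len(id_list) < ice_num:
        id_list += label2indice[label]
    else:
        break
print(id_list[:ice_num])  # [0, 2]: the train items sharing the 'math' label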