2023-11-13 13:00:37 +08:00
|
|
|
import json
|
|
|
|
|
|
|
|
from datasets import Dataset, DatasetDict
|
2023-07-05 09:01:25 +08:00
|
|
|
|
|
|
|
from opencompass.registry import LOAD_DATASET
|
|
|
|
|
|
|
|
from .base import BaseDataset
|
|
|
|
|
|
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class dropDataset(BaseDataset):
    """Dataset loader registered with ``LOAD_DATASET``.

    Reads a JSON file whose top-level object maps keys to entries holding a
    ``passage`` string and a list of ``qa_pairs``; each QA pair carries a
    ``question`` and a list of ``validated_answers``.
    (Presumably the DROP benchmark, judging by the class name.)
    """

    @staticmethod
    def get_answers(validated_answers):
        """Collect the unique answers found in ``validated_answers``.

        Each item contributes exactly one kind of answer, chosen in this
        priority order:

        1. its ``number`` field, when truthy;
        2. otherwise its ``date`` field, when any of ``day``/``month``/
           ``year`` is truthy, rendered as ``"day month year"`` with
           surrounding whitespace stripped;
        3. otherwise every entry of its ``spans`` list.

        Args:
            validated_answers: Iterable of dicts with the keys ``number``,
                ``date`` (a dict with ``day``, ``month``, ``year``) and
                ``spans``.

        Returns:
            list: The distinct answers, in the order they were first seen.
        """
        date_fields = ('day', 'month', 'year')
        answers = []
        for answer_item in validated_answers:
            if answer_item['number']:
                answers.append(answer_item['number'])
            elif any(answer_item['date'][key] for key in date_fields):
                date_parts = [answer_item['date'][key] for key in date_fields]
                answers.append(' '.join(date_parts).strip())
            else:
                answers.extend(answer_item['spans'])
        # De-duplicate while keeping first-seen order. A plain set() would
        # make the order depend on per-process string hash randomization,
        # so the same file could yield differently ordered answers per run.
        return list(dict.fromkeys(answers))

    @staticmethod
    def load(path, only_number=True):
        """Load the JSON file at ``path`` into a ``DatasetDict``.

        Args:
            path: Path to a UTF-8 encoded JSON file.
            only_number: When true, skip every QA pair in which none of the
                validated answers has a truthy ``number`` field.

        Returns:
            DatasetDict: A single ``'validation'`` split whose rows have the
            columns ``prompt`` (the passage), ``question`` and ``answers``
            (the list produced by :meth:`get_answers`).
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = json.load(f)

        dataset_list = []
        for line in lines.values():
            for qa_pair in line['qa_pairs']:
                validated_answers = qa_pair['validated_answers']
                if only_number and not any(i['number']
                                           for i in validated_answers):
                    continue
                item = {
                    'prompt': line['passage'],
                    'question': qa_pair['question'],
                    'answers': dropDataset.get_answers(validated_answers),
                }
                dataset_list.append(item)

        dataset_list = Dataset.from_list(dataset_list)
        return DatasetDict({'validation': dataset_list})
|