OpenCompass/opencompass/datasets/xiezhi.py

89 lines
3.4 KiB
Python
Raw Normal View History

import json
import os.path as osp
from typing import Optional
from datasets import Dataset, DatasetDict
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from .base import BaseDataset
class XiezhiDataset(BaseDataset):
    """Loader for the Xiezhi multiple-choice question files.

    Every data file is read line by line; each non-blank line is parsed as
    one JSON object with the keys ``question``, ``options`` (the choices
    joined by newlines), ``answer`` (the text of the correct choice) and
    ``labels`` / ``label`` (a list of category names).
    """

    @staticmethod
    def load(path: str, name: str):
        """Build the ``train`` and ``test`` splits for one Xiezhi subset.

        Args:
            path: Root directory containing the subset directories.
            name: Directory (under ``path``) holding the test file. If it
                contains ``'chn'`` the Chinese training file and the
                ``labels`` key are used, otherwise the English training
                file and the ``label`` key.

        Returns:
            A ``DatasetDict`` with a ``'train'`` and a ``'test'`` split.
            Each row has ``question``, ``A``-``D``, ``labels`` (longest
            label first) and ``answer`` (one of ``'A'``-``'D'``).

        Records that do not have exactly four options, or whose answer text
        is not one of the options, are skipped.
        """
        # These depend only on ``name``, so compute them once up front.
        is_chinese = 'chn' in name
        train_dir = 'xiezhi_train_chn' if is_chinese else 'xiezhi_train_eng'
        label_key = 'labels' if is_chinese else 'label'
        split_files = [
            ('train', osp.join(path, train_dir, 'xiezhi.v1.json')),
            ('test', osp.join(path, name, 'xiezhi.v1.json')),
        ]

        dataset = DatasetDict()
        for split, split_file in split_files:
            raw_data = []
            with open(split_file, encoding='utf-8') as f:
                for line in f:
                    # A blank line is not a record; json.loads would raise.
                    if not line.strip():
                        continue
                    data = json.loads(line)

                    # Some records carry a stray trailing '"\n' after the
                    # last option; drop it before splitting.
                    options_text = data['options']
                    if options_text.endswith('"\n'):
                        options_text = options_text[:-2]
                    options = options_text.split('\n')

                    # Skip malformed records instead of crashing on them:
                    # wrong number of choices, or an answer that matches
                    # none of the choices.
                    if len(options) != 4 or data['answer'] not in options:
                        continue
                    answer = 'ABCD'[options.index(data['answer'])]

                    # The longer the label, the more fine-grained the concept
                    labels = sorted(data[label_key], key=len, reverse=True)

                    raw_data.append({
                        'question': data['question'],
                        'A': options[0],
                        'B': options[1],
                        'C': options[2],
                        'D': options[3],
                        'labels': labels,
                        'answer': answer,
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
class XiezhiRetriever(BaseRetriever):
    """Retriever that picks in-context examples by shared labels."""

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        """Forward all arguments unchanged to ``BaseRetriever.__init__``."""
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

    def retrieve(self):
        """Retrieve in-context examples for each test case.

        Every in-context example and every test case carries a list of
        labels naming the categories it relates to. For each test case the
        labels are visited in their stored order, and the indices of all
        in-context examples carrying the current label are collected until
        at least ``self.ice_num`` indices have been gathered. The result is
        then truncated to ``self.ice_num``.

        Returns:
            A list with one entry per test case. Each entry is a list of at
            most ``self.ice_num`` indices into ``self.index_ds``. The entry
            is empty when no in-context example shares a label with the
            test case.
        """
        # Map each label to the indices of the in-context examples having it.
        label2indice = {}
        for index, item in enumerate(self.index_ds):
            for label in item['labels']:
                label2indice.setdefault(label, []).append(index)

        rtr_idx_list = []
        for index in trange(len(self.test_ds),
                            disable=not self.is_main_process):
            id_list = []
            for label in self.test_ds[index]['labels']:
                if len(id_list) >= self.ice_num:
                    break
                # A test label may never occur in the index split; treat it
                # as "no candidates" rather than raising KeyError.
                # NOTE: an example sharing several labels with the test case
                # can be appended more than once.
                id_list.extend(label2indice.get(label, []))
            rtr_idx_list.append(id_list[:self.ice_num])
        return rtr_idx_list