OpenCompass/opencompass/datasets/xiezhi.py

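"""Loader and label-based in-context example retriever for the Xiezhi
benchmark."""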

import json
import os.path as osp
from typing import Optional

from datasets import Dataset, DatasetDict
from tqdm import trange

from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.utils import get_data_path

from .base import BaseDataset


class XiezhiDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
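        """Load the Xiezhi benchmark from local JSONL files.

        The test split is read from ``<path>/<name>/xiezhi.v1.json``; the
        train split comes from ``xiezhi_train_chn`` or ``xiezhi_train_eng``
        depending on whether ``name`` contains ``'chn'``. Records whose
        option string does not split into exactly four options are skipped.
        """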
        path = get_data_path(path, local_mode=True)
        dataset = DatasetDict()
        test_filename = osp.join(path, name, 'xiezhi.v1.json')
        if 'chn' in name:
            train_filename = osp.join(path, 'xiezhi_train_chn',
                                      'xiezhi.v1.json')
        else:
            train_filename = osp.join(path, 'xiezhi_train_eng',
                                      'xiezhi.v1.json')
        for split, filename in [('train', train_filename),
                                ('test', test_filename)]:
            raw_data = []
            with open(filename, encoding='utf-8') as f:
                for line in f:
                    data = json.loads(line)
                    # Strip a stray trailing quote-and-newline left in the
                    # raw dump before splitting the options.
                    if data['options'].endswith('"\n'):
                        data['options'] = data['options'][:-2]
                    options = data['options'].split('\n')
                    if len(options) != 4:
                        continue
                    answer = 'ABCD'[options.index(data['answer'])]
                    # The longer the label, the more fine-grained the concept
                    labels = sorted(
                        data['labels' if 'chn' in name else 'label'],
                        key=lambda x: len(x),
                        reverse=True)
                    raw_data.append({
                        'question': data['question'],
                        'A': options[0],
                        'B': options[1],
                        'C': options[2],
                        'D': options[3],
                        'labels': labels,
                        'answer': answer,
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
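

# A minimal sketch of one input line as ``XiezhiDataset.load`` expects it
# (hypothetical values; the real Xiezhi files define the actual content).
# ``options`` is a newline-joined string of exactly four candidates, and the
# label field is ``'labels'`` in the Chinese subsets and ``'label'`` in the
# English ones:
#
#     {"question": "...", "options": "optA\noptB\noptC\noptD",
#      "answer": "optB", "label": ["history", "ancient history"]}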


class XiezhiRetriever(BaseRetriever):

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

    def retrieve(self):
        """Retrieve in-context examples for each test case.

        Each in-context example carries a list of labels indicating the
        categories it is related to, and each test case carries such a
        list as well. This retriever selects the in-context examples that
        share at least one label with the test case.
        """
        label2indice = {}
        for index, item in enumerate(self.index_ds):
            for label in item['labels']:
                if label not in label2indice:
                    label2indice[label] = []
                label2indice[label].append(index)
        rtr_idx_list = []
        for index in trange(len(self.test_ds),
                            disable=not self.is_main_process):
            id_list = []
            for label in self.test_ds[index]['labels']:
                if len(id_list) < self.ice_num:
                    # A label unseen in the index set contributes nothing
                    # instead of raising KeyError.
                    id_list += label2indice.get(label, [])
                else:
                    break
            rtr_idx_list.append(id_list[:self.ice_num])
        return rtr_idx_list
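

# A minimal sketch of the retrieval behaviour on toy data (hypothetical
# values, not part of the Xiezhi benchmark):
#
#     index_ds = [{'labels': ['math']}, {'labels': ['history']}]
#     test_ds  = [{'labels': ['history']}]
#
# With ice_num=1, the single test case looks up 'history' in label2indice,
# collects index 1, and the retriever returns [[1]].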