OpenCompass/opencompass/datasets/xiezhi.py
Xingjun.Wang edab1c07ba
[Feature] Support ModelScope datasets (#1289)
* add ceval, gsm8k modelscope support

* update race, mmlu, arc, cmmlu, commonsenseqa, humaneval and unittest

* update bbh, flores, obqa, siqa, storycloze, summedits, winogrande, xsum datasets

* format file

* format file

* update dataset format

* support ms_dataset

* update dataset for modelscope support

* merge myl_dev and update test_ms_dataset

* update dataset for modelscope support

* update readme

* update eval_api_zhipu_v2

* remove unused code

* add get_data_path function

* update readme

* remove tydiqa japanese subset

* add ceval, gsm8k modelscope support

* update race, mmlu, arc, cmmlu, commonsenseqa, humaneval and unittest

* update bbh, flores, obqa, siqa, storycloze, summedits, winogrande, xsum datasets

* format file

* format file

* update dataset format

* support ms_dataset

* update dataset for modelscope support

* merge myl_dev and update test_ms_dataset

* update readme

* update dataset for modelscope support

* update eval_api_zhipu_v2

* remove unused code

* add get_data_path function

* remove tydiqa japanese subset

* update util

* remove .DS_Store

* fix md format

* move util into package

* update docs/get_started.md

* restore eval_api_zhipu_v2.py, add environment setting

* Update dataset

* Update

* Update

* Update

* Update

---------

Co-authored-by: Yun lin <yunlin@U-Q9X2K4QV-1904.local>
Co-authored-by: Yunnglin <mao.looper@qq.com>
Co-authored-by: Yun lin <yunlin@laptop.local>
Co-authored-by: Yunnglin <maoyl@smail.nju.edu.cn>
Co-authored-by: zhangsongyang <zhangsongyang@pjlab.org.cn>
2024-07-29 13:48:32 +08:00

91 lines
3.5 KiB
Python

import json
import os.path as osp
from typing import Optional
from datasets import Dataset, DatasetDict
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.utils import get_data_path
from .base import BaseDataset
class XiezhiDataset(BaseDataset):
    """Loader for the Xiezhi multiple-choice benchmark.

    Each split is read from a ``xiezhi.v1.json`` file that is parsed one
    JSON object per line. The test split comes from ``<path>/<name>``; the
    train split comes from ``xiezhi_train_chn`` when ``name`` contains
    ``'chn'`` and from ``xiezhi_train_eng`` otherwise.
    """

    @staticmethod
    def load(path: str, name: str):
        """Load the train and test splits for the subset ``name``.

        Args:
            path: Dataset root directory, resolved via ``get_data_path``.
            name: Subset directory name. If it contains ``'chn'``, the
                Chinese train split is used and category labels are read
                from the ``'labels'`` field; otherwise the English train
                split and the ``'label'`` field are used.

        Returns:
            A ``DatasetDict`` with ``'train'`` and ``'test'`` splits. Each row
            has ``question``, ``A``, ``B``, ``C``, ``D``, ``labels`` and
            ``answer`` (one of ``'A'``-``'D'``) columns.
        """
        path = get_data_path(path, local_mode=True)
        is_chn = 'chn' in name
        train_dir = 'xiezhi_train_chn' if is_chn else 'xiezhi_train_eng'
        # The Chinese and English dumps name the category field differently.
        label_key = 'labels' if is_chn else 'label'
        split_files = {
            'train': osp.join(path, train_dir, 'xiezhi.v1.json'),
            'test': osp.join(path, name, 'xiezhi.v1.json'),
        }

        dataset = DatasetDict()
        for split, split_file in split_files.items():
            rows = XiezhiDataset._read_split(split_file, label_key)
            dataset[split] = Dataset.from_list(rows)
        return dataset

    @staticmethod
    def _read_split(filename: str, label_key: str) -> list:
        """Parse one JSON-lines file into rows, skipping unusable records."""
        rows = []
        with open(filename, encoding='utf-8') as f:
            for line in f:
                # A blank line (e.g. a trailing newline) is not a record.
                if not line.strip():
                    continue
                row = XiezhiDataset._parse_record(json.loads(line),
                                                  label_key)
                if row is not None:
                    rows.append(row)
        return rows

    @staticmethod
    def _parse_record(data: dict, label_key: str) -> Optional[dict]:
        """Convert one raw record into a row, or return None to skip it."""
        options = data['options']
        # Some records end the options string with a stray '"\n'.
        if options.endswith('"\n'):
            options = options[:-2]
        options = options.split('\n')
        # Only 4-option questions fit the A-D format. Records whose answer
        # is not one of the options cannot be mapped to a letter either.
        if len(options) != 4 or data['answer'] not in options:
            return None
        # The longer the label, the more fine-grained the concept.
        labels = sorted(data[label_key], key=len, reverse=True)
        return {
            'question': data['question'],
            'A': options[0],
            'B': options[1],
            'C': options[2],
            'D': options[3],
            'labels': labels,
            'answer': 'ABCD'[options.index(data['answer'])],
        }
class XiezhiRetriever(BaseRetriever):
    """Retriever that picks in-context examples sharing a label with the
    test case.

    Args:
        dataset: Forwarded unchanged to ``BaseRetriever``.
        ice_separator: Forwarded unchanged to ``BaseRetriever``.
        ice_eos_token: Forwarded unchanged to ``BaseRetriever``.
        ice_num: Forwarded to ``BaseRetriever``; ``retrieve`` also uses it as
            the maximum number of in-context examples per test case.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

    @staticmethod
    def _pick_examples(labels, label2indices, ice_num):
        """Collect up to ``ice_num`` distinct pool indices sharing ``labels``.

        Labels are visited in the given order, and the indices for each
        label in pool order. A label with no pool examples contributes
        nothing, and an example sharing several labels is taken only once.
        """
        picked = []
        seen = set()
        for label in labels:
            if len(picked) >= ice_num:
                break
            for idx in label2indices.get(label, ()):
                if idx in seen:
                    continue
                seen.add(idx)
                picked.append(idx)
                if len(picked) >= ice_num:
                    break
        return picked

    def retrieve(self):
        """Retrieve in-context examples for each test case.

        Each in-context example and each test case carries a list of labels
        naming the categories it relates to. For every test case this
        returns the indices of up to ``ice_num`` distinct in-context
        examples that share at least one label with it. Labels are
        processed in the order stored on the test case; the loader sorts
        them longest first. A test case whose labels match nothing gets an
        empty list.

        ``self.index_ds``, ``self.test_ds`` and ``self.is_main_process`` are
        presumably set up by ``BaseRetriever`` (TODO confirm).
        """
        # Map each label to the in-context pool indices that carry it.
        label2indices = {}
        for index, item in enumerate(self.index_ds):
            for label in item['labels']:
                label2indices.setdefault(label, []).append(index)

        rtr_idx_list = []
        for index in trange(len(self.test_ds),
                            disable=not self.is_main_process):
            rtr_idx_list.append(
                self._pick_examples(self.test_ds[index]['labels'],
                                    label2indices, self.ice_num))
        return rtr_idx_list