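"""Loader and label-based in-context example retriever for the Xiezhi
benchmark in OpenCompass."""
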
import json
import os.path as osp
from typing import Optional

from datasets import Dataset, DatasetDict
from tqdm import trange

from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.utils import get_data_path

from .base import BaseDataset

class XiezhiDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
        """Load the train and test splits of the Xiezhi benchmark.

        Each line of ``xiezhi.v1.json`` is a JSON record describing one
        four-option multiple-choice question; records whose option block
        does not split into exactly four options are skipped.
        """
        path = get_data_path(path, local_mode=True)
        dataset = DatasetDict()
        test_filename = osp.join(path, name, 'xiezhi.v1.json')
        # In-context examples come from the train split that matches the
        # language of the requested subset.
        if 'chn' in name:
            train_filename = osp.join(path, 'xiezhi_train_chn',
                                      'xiezhi.v1.json')
        else:
            train_filename = osp.join(path, 'xiezhi_train_eng',
                                      'xiezhi.v1.json')

        for split, filename in [['train', train_filename],
                                ['test', test_filename]]:
            raw_data = []
            with open(filename, encoding='utf-8') as f:
                for line in f:
                    data = json.loads(line)
                    # Drop a stray trailing quote-and-newline artifact.
                    if data['options'].endswith('"\n'):
                        data['options'] = data['options'][:-2]
                    options = data['options'].split('\n')
                    if len(options) != 4:
                        continue
                    # The answer field repeats the full text of the correct
                    # option; map it to its letter.
                    answer = 'ABCD'[options.index(data['answer'])]
                    # The longer the label, the more fine-grained the concept
                    labels = sorted(
                        data['labels' if 'chn' in name else 'label'],
                        key=len,
                        reverse=True)
                    raw_data.append({
                        'question': data['question'],
                        'A': options[0],
                        'B': options[1],
                        'C': options[2],
                        'D': options[3],
                        'labels': labels,
                        'answer': answer,
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
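

# A sketch of one input line as ``load`` above expects it. The field names
# come from the parsing code; the values are illustrative, not taken from
# the real dataset:
#
#   {"question": "...", "options": "opt A\nopt B\nopt C\nopt D",
#    "answer": "opt B", "labels": ["science", "physics"]}
#
# ``options`` is one newline-separated string, ``answer`` repeats the full
# text of the correct option, and ``labels`` (``label`` for the English
# subsets) holds the concept tags consumed by XiezhiRetriever below.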


class XiezhiRetriever(BaseRetriever):

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

    def retrieve(self):
        """Retrieve in-context examples for each test case.

        Each in-context example comes with a list of labels indicating
        the categories it is related to, and each test case carries such
        a list as well. This retriever picks in-context examples that
        share at least one label with the test case.
        """
        # Map each label to the indices of all train examples tagged with it.
        label2indice = {}
        for index, item in enumerate(self.index_ds):
            for label in item['labels']:
                if label not in label2indice:
                    label2indice[label] = []
                label2indice[label].append(index)
        rtr_idx_list = []
        for index in trange(len(self.test_ds),
                            disable=not self.is_main_process):
            id_list = []
            # Collect candidates label by label until ice_num indices are
            # gathered; labels absent from the train split yield nothing
            # (the .get default avoids a KeyError on unseen labels).
            for label in self.test_ds[index]['labels']:
                if len(id_list) < self.ice_num:
                    id_list += label2indice.get(label, [])
                else:
                    break
            rtr_idx_list.append(id_list[:self.ice_num])
        return rtr_idx_list
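

if __name__ == '__main__':
    # Hedged demo of the label-matching idea used by XiezhiRetriever, run
    # on made-up data and independent of the BaseRetriever machinery
    # (which normally supplies index_ds, test_ds and ice_num):
    train = [{'labels': ['physics', 'science']},
             {'labels': ['history']},
             {'labels': ['science']}]
    test_labels = ['science', 'history']
    ice_num = 2

    # Same inverted index as in retrieve(): label -> train indices.
    label2indice = {}
    for index, item in enumerate(train):
        for label in item['labels']:
            label2indice.setdefault(label, []).append(index)

    # Gather candidates label by label until ice_num indices are found.
    id_list = []
    for label in test_labels:
        if len(id_list) < ice_num:
            id_list += label2indice.get(label, [])
        else:
            break
    print(id_list[:ice_num])  # -> [0, 2]: train items sharing 'science'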