Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00

* Update CascadeEvaluator
* Update Config
61 lines · 2.0 KiB · Python
from typing import Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, concatenate_datasets

from opencompass.openicl import DatasetReader
from opencompass.utils import get_logger

logger = get_logger()


class BaseDataset:

    def __init__(self,
                 reader_cfg: Optional[Dict] = {},
                 k: Union[int, List[int]] = 1,
                 n: int = 1,
                 **kwargs):
        abbr = kwargs.pop('abbr', 'dataset')
        dataset = self.load(**kwargs)
        # `k` (e.g. for pass@k) may not exceed `n`, the number of copies of
        # each example created below.
        assert (max(k) if isinstance(k, list) else k) <= n, \
            'Maximum value of `k` must be less than or equal to `n`'
        if isinstance(dataset, Dataset):
            # Tag every example with its source dataset and original index,
            # then repeat the whole dataset n times.
            dataset = dataset.map(lambda x, idx: {
                'subdivision': abbr,
                'idx': idx
            },
                                  with_indices=True,
                                  writer_batch_size=16,
                                  load_from_cache_file=False)
            dataset = concatenate_datasets([dataset] * n)
            self.dataset = dataset
        else:
            # A DatasetDict: tag and repeat each split separately.
            self.dataset = DatasetDict()
            for key in dataset:
                dataset[key] = dataset[key].map(lambda x, idx: {
                    'subdivision': f'{abbr}_{key}',
                    'idx': idx
                },
                                                with_indices=True,
                                                writer_batch_size=16,
                                                load_from_cache_file=False)
                dataset[key] = concatenate_datasets([dataset[key]] * n)
                self.dataset[key] = dataset[key]
        self._init_reader(**reader_cfg)

    def _init_reader(self, **kwargs):
        self.reader = DatasetReader(self.dataset, **kwargs)

    @property
    def train(self):
        return self.reader.dataset['train']

    @property
    def test(self):
        return self.reader.dataset['test']

    @staticmethod
    def load(**kwargs) -> Union[Dataset, DatasetDict]:
        # Hook for subclasses: return a `Dataset` or `DatasetDict`.
        pass
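
For orientation, here is a minimal sketch of how a concrete dataset might subclass this base class. `MyDataset`, its inline rows, and the `question`/`answer` field names are hypothetical, and the `reader_cfg` keys assume `DatasetReader` accepts `input_columns`/`output_column`; none of this is part of the mirrored file.

from datasets import Dataset, DatasetDict


class MyDataset(BaseDataset):

    @staticmethod
    def load(**kwargs) -> DatasetDict:
        # Hypothetical inline data; a real loader would read from disk or
        # the Hugging Face hub using whatever arrives in **kwargs.
        rows = [{'question': '1 + 1 = ?', 'answer': '2'},
                {'question': '2 + 2 = ?', 'answer': '4'}]
        return DatasetDict(train=Dataset.from_list(rows),
                           test=Dataset.from_list(rows))


ds = MyDataset(abbr='toy',
               n=4,        # every example is repeated 4 times
               k=[1, 2],   # max(k) <= n, so the assertion passes
               reader_cfg=dict(input_columns=['question'],
                               output_column='answer'))
print(len(ds.test))               # 8: 2 rows duplicated n=4 times
print(ds.test[0]['subdivision'])  # 'toy_test'

Because `n=4`, each of the two rows appears four times in `ds.test`, which is what allows k-out-of-n style metrics to be computed downstream.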