OpenCompass/opencompass/datasets/base.py
from typing import Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, concatenate_datasets

from opencompass.openicl import DatasetReader
from opencompass.utils import get_logger

logger = get_logger()


class BaseDataset:

    def __init__(self,
                 reader_cfg: Optional[Dict] = {},
                 k: Union[int, List[int]] = 1,
                 n: int = 1,
                 **kwargs):
        abbr = kwargs.pop('abbr', 'dataset')
        dataset = self.load(**kwargs)
        # `k` may be an int or a list of ints; its maximum must not exceed
        # the replica count `n` used below.
        assert (max(k) if isinstance(k, List) else
                k) <= n, 'Maximum value of `k` must be less than or equal to `n`'
        if isinstance(dataset, Dataset):
            # Tag each example with its subdivision name and original index,
            # then duplicate the whole dataset `n` times (data replicas).
            dataset = dataset.map(lambda x, idx: {
                'subdivision': abbr,
                'idx': idx
            },
                                  with_indices=True,
                                  writer_batch_size=16,
                                  load_from_cache_file=False)
            dataset = concatenate_datasets([dataset] * n)
            self.dataset = dataset
        else:
            # For a DatasetDict, tag and replicate every split, qualifying
            # the subdivision name with the split name.
            self.dataset = DatasetDict()
            for key in dataset:
                dataset[key] = dataset[key].map(lambda x, idx: {
                    'subdivision': f'{abbr}_{key}',
                    'idx': idx
                },
                                                with_indices=True,
                                                writer_batch_size=16,
                                                load_from_cache_file=False)
                dataset[key] = concatenate_datasets([dataset[key]] * n)
                self.dataset[key] = dataset[key]
        self._init_reader(**reader_cfg)

    def _init_reader(self, **kwargs):
        self.reader = DatasetReader(self.dataset, **kwargs)

    @property
    def train(self):
        return self.reader.dataset['train']

    @property
    def test(self):
        return self.reader.dataset['test']

    @staticmethod
    def load(**kwargs) -> Union[Dataset, DatasetDict]:
        pass
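

# ----------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream file). `DemoDataset`
# and its two rows are hypothetical; the `reader_cfg` keys follow the usual
# OpenCompass `DatasetReader` convention. It shows how the data replicas are
# materialized: each row gains `subdivision`/`idx` columns and the dataset
# is repeated `n` times, which is what a k-pass evaluator such as
# CascadeEvaluator consumes.


class DemoDataset(BaseDataset):

    @staticmethod
    def load(**kwargs) -> Dataset:
        return Dataset.from_list([
            {'question': '1 + 1 = ?', 'answer': '2'},
            {'question': '2 + 2 = ?', 'answer': '4'},
        ])


if __name__ == '__main__':
    demo = DemoDataset(
        abbr='demo',
        reader_cfg=dict(input_columns=['question'], output_column='answer'),
        k=2,  # score pass@2 ...
        n=2,  # ... so each example needs at least 2 replicas (k <= n)
    )
    rows = demo.dataset  # a `datasets.Dataset` holding the replicated rows
    print(len(rows))            # 2 originals * n=2 -> 4
    print(rows['idx'])          # [0, 1, 0, 1]: replicas share an `idx`
    print(rows['subdivision'])  # ['demo', 'demo', 'demo', 'demo']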