mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Update] Update dataset repeat concatenation (#2039)
This commit is contained in:
parent
dcbf899369
commit
97010dc4ce
@ -1,7 +1,6 @@
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets
|
||||
|
||||
from opencompass.openicl import DatasetReader
|
||||
|
||||
@ -19,28 +18,25 @@ class BaseDataset:
|
||||
assert (max(k) if isinstance(k, List) else
|
||||
k) <= n, 'Maximum value of `k` must less than or equal to `n`'
|
||||
if isinstance(dataset, Dataset):
|
||||
examples = []
|
||||
for idx, example in enumerate(dataset):
|
||||
if 'subdivision' not in example:
|
||||
example['subdivision'] = abbr
|
||||
if 'idx' not in example:
|
||||
example['idx'] = idx
|
||||
examples.append(example)
|
||||
examples = sum([deepcopy(examples) for _ in range(n)], [])
|
||||
self.dataset = Dataset.from_list(examples)
|
||||
dataset = dataset.map(lambda x, idx: {
|
||||
'subdivision': abbr,
|
||||
'idx': idx
|
||||
},
|
||||
with_indices=True,
|
||||
writer_batch_size=16)
|
||||
dataset = concatenate_datasets([dataset] * n)
|
||||
self.dataset = dataset
|
||||
else:
|
||||
self.dataset = DatasetDict()
|
||||
for key in dataset:
|
||||
examples = []
|
||||
for idx, example in enumerate(dataset[key]):
|
||||
if 'subdivision' not in example:
|
||||
example['subdivision'] = f'{abbr}_{key}'
|
||||
if 'idx' not in example:
|
||||
example['idx'] = idx
|
||||
examples.append(example)
|
||||
print(abbr, key, len(examples))
|
||||
examples = sum([deepcopy(examples) for _ in range(n)], [])
|
||||
self.dataset[key] = Dataset.from_list(examples)
|
||||
dataset[key] = dataset[key].map(lambda x, idx: {
|
||||
'subdivision': f'{abbr}_{key}',
|
||||
'idx': idx
|
||||
},
|
||||
with_indices=True,
|
||||
writer_batch_size=16)
|
||||
dataset[key] = concatenate_datasets([dataset[key]] * n)
|
||||
self.dataset[key] = dataset[key]
|
||||
self._init_reader(**reader_cfg)
|
||||
|
||||
def _init_reader(self, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user