[Update] Update dataset repeat concatenation (#2039)

This commit is contained in:
Junnan Liu 2025-04-23 16:16:28 +08:00 committed by GitHub
parent dcbf899369
commit 97010dc4ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,6 @@
from copy import deepcopy
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from datasets import Dataset, DatasetDict from datasets import Dataset, DatasetDict, concatenate_datasets
from opencompass.openicl import DatasetReader from opencompass.openicl import DatasetReader
@ -19,28 +18,25 @@ class BaseDataset:
assert (max(k) if isinstance(k, List) else assert (max(k) if isinstance(k, List) else
k) <= n, 'Maximum value of `k` must less than or equal to `n`' k) <= n, 'Maximum value of `k` must less than or equal to `n`'
if isinstance(dataset, Dataset): if isinstance(dataset, Dataset):
examples = [] dataset = dataset.map(lambda x, idx: {
for idx, example in enumerate(dataset): 'subdivision': abbr,
if 'subdivision' not in example: 'idx': idx
example['subdivision'] = abbr },
if 'idx' not in example: with_indices=True,
example['idx'] = idx writer_batch_size=16)
examples.append(example) dataset = concatenate_datasets([dataset] * n)
examples = sum([deepcopy(examples) for _ in range(n)], []) self.dataset = dataset
self.dataset = Dataset.from_list(examples)
else: else:
self.dataset = DatasetDict() self.dataset = DatasetDict()
for key in dataset: for key in dataset:
examples = [] dataset[key] = dataset[key].map(lambda x, idx: {
for idx, example in enumerate(dataset[key]): 'subdivision': f'{abbr}_{key}',
if 'subdivision' not in example: 'idx': idx
example['subdivision'] = f'{abbr}_{key}' },
if 'idx' not in example: with_indices=True,
example['idx'] = idx writer_batch_size=16)
examples.append(example) dataset[key] = concatenate_datasets([dataset[key]] * n)
print(abbr, key, len(examples)) self.dataset[key] = dataset[key]
examples = sum([deepcopy(examples) for _ in range(n)], [])
self.dataset[key] = Dataset.from_list(examples)
self._init_reader(**reader_cfg) self._init_reader(**reader_cfg)
def _init_reader(self, **kwargs): def _init_reader(self, **kwargs):