diff --git a/opencompass/datasets/base.py b/opencompass/datasets/base.py index 7099b0c6..ac6c4570 100644 --- a/opencompass/datasets/base.py +++ b/opencompass/datasets/base.py @@ -1,7 +1,6 @@ -from copy import deepcopy from typing import Dict, List, Optional, Union -from datasets import Dataset, DatasetDict +from datasets import Dataset, DatasetDict, concatenate_datasets from opencompass.openicl import DatasetReader @@ -19,28 +18,25 @@ class BaseDataset: assert (max(k) if isinstance(k, List) else k) <= n, 'Maximum value of `k` must less than or equal to `n`' if isinstance(dataset, Dataset): - examples = [] - for idx, example in enumerate(dataset): - if 'subdivision' not in example: - example['subdivision'] = abbr - if 'idx' not in example: - example['idx'] = idx - examples.append(example) - examples = sum([deepcopy(examples) for _ in range(n)], []) - self.dataset = Dataset.from_list(examples) + dataset = dataset.map(lambda x, idx: { + 'subdivision': abbr, + 'idx': idx + }, + with_indices=True, + writer_batch_size=16) + dataset = concatenate_datasets([dataset] * n) + self.dataset = dataset else: self.dataset = DatasetDict() for key in dataset: - examples = [] - for idx, example in enumerate(dataset[key]): - if 'subdivision' not in example: - example['subdivision'] = f'{abbr}_{key}' - if 'idx' not in example: - example['idx'] = idx - examples.append(example) - print(abbr, key, len(examples)) - examples = sum([deepcopy(examples) for _ in range(n)], []) - self.dataset[key] = Dataset.from_list(examples) + dataset[key] = dataset[key].map(lambda x, idx: { + 'subdivision': f'{abbr}_{key}', + 'idx': idx + }, + with_indices=True, + writer_batch_size=16) + dataset[key] = concatenate_datasets([dataset[key]] * n) + self.dataset[key] = dataset[key] self._init_reader(**reader_cfg) def _init_reader(self, **kwargs):