mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
![]() |
from datasets import DatasetDict, load_dataset
|
||
|
|
||
|
from opencompass.registry import LOAD_DATASET
|
||
|
|
||
|
from .base import BaseDataset
|
||
|
|
||
|
|
||
|
@LOAD_DATASET.register_module()
|
||
|
class CivilCommentsDataset(BaseDataset):
|
||
|
|
||
|
@staticmethod
|
||
|
def load(**kwargs):
|
||
|
train_dataset = load_dataset(**kwargs, split='train')
|
||
|
test_dataset = load_dataset(**kwargs, split='test')
|
||
|
|
||
|
def pre_process(example):
|
||
|
example['label'] = int(example['toxicity'] >= 0.5)
|
||
|
example['choices'] = ['no', 'yes']
|
||
|
return example
|
||
|
|
||
|
def remove_columns(dataset):
|
||
|
return dataset.remove_columns([
|
||
|
'severe_toxicity', 'obscene', 'threat', 'insult',
|
||
|
'identity_attack', 'sexual_explicit'
|
||
|
])
|
||
|
|
||
|
train_dataset = remove_columns(train_dataset)
|
||
|
test_dataset = remove_columns(test_dataset)
|
||
|
test_dataset = test_dataset.shuffle(seed=42)
|
||
|
test_dataset = test_dataset.select(list(range(10000)))
|
||
|
test_dataset = test_dataset.map(pre_process)
|
||
|
|
||
|
return DatasetDict({
|
||
|
'train': train_dataset,
|
||
|
'test': test_dataset,
|
||
|
})
|