OpenCompass/opencompass/datasets/inference_ppl.py
Que Haoran a244453d9e
[Feature] Support inference ppl datasets (#1315)
* commit inference ppl datasets

* revised format

* revise

* revise

* revise

* revise

* revise

* revise
2024-07-22 17:59:30 +08:00

38 lines
1021 B
Python

import os.path as osp
from typing import List, Optional

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset
@LOAD_DATASET.register_module()
class InferencePPLDataset(BaseDataset):
    """Text dataset read from a local JSON Lines file.

    Records are read from ``<path>/<name>.jsonl`` into a single ``test``
    split. Records whose ``text`` field is empty or null are dropped.
    """

    @staticmethod
    def load(path: str,
             name: Optional[str] = None,
             samples: Optional[int] = None):
        """Load the ``test`` split of data subset ``name`` from ``path``.

        Args:
            path: Directory that holds the data files.
            name: Data subset name. The file looked up is
                ``<path>/<name>.<ext>`` for each supported extension.
                It defaults to ``None``, but a real name is effectively
                required: otherwise the file ``None.jsonl`` is looked up.
            samples: If given, only the first ``samples`` rows are loaded.
                This uses the ``test[:samples]`` split-slicing syntax of
                ``load_dataset``. If ``None``, the whole split is loaded.

        Returns:
            The dataset returned by ``load_dataset`` for the chosen split,
            with empty/null ``text`` records filtered out.

        Raises:
            FileNotFoundError: If no file with a supported extension
                exists for ``name`` under ``path``.
        """
        # Probe each supported extension in order and stop at the first
        # file that exists. The ``else`` branch runs only when the loop
        # finishes without ``break``, i.e. no candidate file was found.
        supported_extensions = ['jsonl']
        for ext in supported_extensions:
            # name refers to data subset name
            filename = osp.join(path, f'{name}.{ext}')
            if osp.exists(filename):
                break
        else:
            raise FileNotFoundError(f'{filename} not found.')

        split = 'test' if samples is None else f'test[:{samples}]'
        data_files = {'test': filename}
        dataset = load_dataset('json', data_files=data_files, split=split)

        # Filter out empty samples. ``bool`` treats ``''`` and ``None``
        # alike. The original ``len(...) > 0`` raised TypeError on a null
        # ``text`` value instead of dropping the record.
        dataset = dataset.filter(lambda example: bool(example['text']))
        return dataset