diff --git a/opencompass/datasets/hle.py b/opencompass/datasets/hle.py index 802c667d..b14507cc 100644 --- a/opencompass/datasets/hle.py +++ b/opencompass/datasets/hle.py @@ -9,7 +9,7 @@ from .base import BaseDataset class HLEDataset(BaseDataset): @staticmethod - def load(path: str, category: str): + def load(path: str, category: str | None = None): dataset = load_dataset(path) ds = dataset['test'].filter(lambda x: x['image'] == '') if category: