mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
fix: fix load function for SeedBenchDataset
This commit is contained in:
parent
8000375069
commit
c9ea024c67
@ -2,4 +2,4 @@ from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
# Default use LLM as a judge
|
||||
from .seedbench_gen_44868b import seedbench_datasets # noqa: F401, F403
|
||||
from .seedbench_gen_5d5ea1 import seedbench_datasets # noqa: F401, F403
|
||||
|
@ -6,8 +6,6 @@ from opencompass.datasets.SeedBench import SeedBenchDataset, F1ScoreEvaluator, m
|
||||
from opencompass.utils.text_postprocessors import first_option_postprocess
|
||||
|
||||
|
||||
agri_data_dir = './data/SeedBench'
|
||||
|
||||
agri_reader_cfg = dict(
|
||||
input_columns=['instruction', 'question'],
|
||||
output_column='answer'
|
||||
@ -61,14 +59,14 @@ for stage in ['zero-shot','one-shot']:
|
||||
)
|
||||
if 'pred_postprocessor' in config:
|
||||
eval_cfg['pred_postprocessor'] = config['pred_postprocessor']
|
||||
data_file = f"{agri_data_dir}/{stage}/{config['data_file']}"
|
||||
data_file = f"{stage}/{config['data_file']}"
|
||||
abbr_name = f"{config['abbr']}_{stage}"
|
||||
seedbench_datasets.append(
|
||||
dict(
|
||||
type=SeedBenchDataset,
|
||||
abbr=abbr_name,
|
||||
data_files=data_file,
|
||||
path='json',
|
||||
path='opencompass/seedbench',
|
||||
reader_cfg=agri_reader_cfg,
|
||||
infer_cfg=agri_infer_cfg,
|
||||
eval_cfg=eval_cfg
|
||||
|
@ -1,5 +1,6 @@
|
||||
import random
|
||||
import re
|
||||
from os import environ
|
||||
from typing import List
|
||||
|
||||
import datasets
|
||||
@ -8,23 +9,37 @@ import numpy as np
|
||||
from rouge_chinese import Rouge
|
||||
|
||||
from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATORS, TEXT_POSTPROCESSORS
|
||||
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||
TEXT_POSTPROCESSORS)
|
||||
from opencompass.utils import get_data_path
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class SeedBenchDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(data_files: str,
|
||||
path: str = 'json',
|
||||
path: str,
|
||||
split: str = None,
|
||||
**kwargs) -> datasets.Dataset:
|
||||
dataset = datasets.load_dataset(path, data_files=data_files, **kwargs)
|
||||
|
||||
path = get_data_path(path)
|
||||
if environ.get('DATASET_SOURCE', None) == 'ModelScope':
|
||||
from modelscope import MsDataset
|
||||
dataset = MsDataset.load(path,
|
||||
subset_name='default',
|
||||
split=split,
|
||||
data_files=data_files,
|
||||
**kwargs)
|
||||
else:
|
||||
dataset = datasets.load_dataset(path,
|
||||
data_files=data_files,
|
||||
**kwargs)
|
||||
|
||||
if split is None:
|
||||
split = list(dataset.keys())[0]
|
||||
print(f'my datasets split : {split}')
|
||||
|
||||
if split not in dataset:
|
||||
raise ValueError(f"Split '{split}' not found. \
|
||||
|
@ -231,7 +231,7 @@ DATASETS_MAPPING = {
|
||||
},
|
||||
# SeedBench
|
||||
"opencompass/seedbench": {
|
||||
"ms_id": "",
|
||||
"ms_id": "y12869741/SeedBench",
|
||||
"hf_id": "y12869741/SeedBench",
|
||||
"local": "./data/SeedBench",
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user