Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
fix: fix load function for SeedBenchDataset
This commit is contained in:
parent 8000375069
commit c9ea024c67
@@ -2,4 +2,4 @@ from mmengine.config import read_base
|
|||||||
|
|
||||||
with read_base():
|
with read_base():
|
||||||
# Default use LLM as a judge
|
# Default use LLM as a judge
|
||||||
from .seedbench_gen_44868b import seedbench_datasets # noqa: F401, F403
|
from .seedbench_gen_5d5ea1 import seedbench_datasets # noqa: F401, F403
|
||||||
|
@@ -6,8 +6,6 @@ from opencompass.datasets.SeedBench import SeedBenchDataset, F1ScoreEvaluator, m
|
|||||||
from opencompass.utils.text_postprocessors import first_option_postprocess
|
from opencompass.utils.text_postprocessors import first_option_postprocess
|
||||||
|
|
||||||
|
|
||||||
agri_data_dir = './data/SeedBench'
|
|
||||||
|
|
||||||
agri_reader_cfg = dict(
|
agri_reader_cfg = dict(
|
||||||
input_columns=['instruction', 'question'],
|
input_columns=['instruction', 'question'],
|
||||||
output_column='answer'
|
output_column='answer'
|
||||||
@@ -61,14 +59,14 @@ for stage in ['zero-shot','one-shot']:
|
|||||||
)
|
)
|
||||||
if 'pred_postprocessor' in config:
|
if 'pred_postprocessor' in config:
|
||||||
eval_cfg['pred_postprocessor'] = config['pred_postprocessor']
|
eval_cfg['pred_postprocessor'] = config['pred_postprocessor']
|
||||||
data_file = f"{agri_data_dir}/{stage}/{config['data_file']}"
|
data_file = f"{stage}/{config['data_file']}"
|
||||||
abbr_name = f"{config['abbr']}_{stage}"
|
abbr_name = f"{config['abbr']}_{stage}"
|
||||||
seedbench_datasets.append(
|
seedbench_datasets.append(
|
||||||
dict(
|
dict(
|
||||||
type=SeedBenchDataset,
|
type=SeedBenchDataset,
|
||||||
abbr=abbr_name,
|
abbr=abbr_name,
|
||||||
data_files=data_file,
|
data_files=data_file,
|
||||||
path='json',
|
path='opencompass/seedbench',
|
||||||
reader_cfg=agri_reader_cfg,
|
reader_cfg=agri_reader_cfg,
|
||||||
infer_cfg=agri_infer_cfg,
|
infer_cfg=agri_infer_cfg,
|
||||||
eval_cfg=eval_cfg
|
eval_cfg=eval_cfg
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
from os import environ
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -8,23 +9,37 @@ import numpy as np
|
|||||||
from rouge_chinese import Rouge
|
from rouge_chinese import Rouge
|
||||||
|
|
||||||
from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
|
from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
|
||||||
from opencompass.registry import ICL_EVALUATORS, TEXT_POSTPROCESSORS
|
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||||
|
TEXT_POSTPROCESSORS)
|
||||||
|
from opencompass.utils import get_data_path
|
||||||
|
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
@LOAD_DATASET.register_module()
|
||||||
class SeedBenchDataset(BaseDataset):
|
class SeedBenchDataset(BaseDataset):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(data_files: str,
|
def load(data_files: str,
|
||||||
path: str = 'json',
|
path: str,
|
||||||
split: str = None,
|
split: str = None,
|
||||||
**kwargs) -> datasets.Dataset:
|
**kwargs) -> datasets.Dataset:
|
||||||
dataset = datasets.load_dataset(path, data_files=data_files, **kwargs)
|
|
||||||
|
path = get_data_path(path)
|
||||||
|
if environ.get('DATASET_SOURCE', None) == 'ModelScope':
|
||||||
|
from modelscope import MsDataset
|
||||||
|
dataset = MsDataset.load(path,
|
||||||
|
subset_name='default',
|
||||||
|
split=split,
|
||||||
|
data_files=data_files,
|
||||||
|
**kwargs)
|
||||||
|
else:
|
||||||
|
dataset = datasets.load_dataset(path,
|
||||||
|
data_files=data_files,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
if split is None:
|
if split is None:
|
||||||
split = list(dataset.keys())[0]
|
split = list(dataset.keys())[0]
|
||||||
print(f'my datasets split : {split}')
|
|
||||||
|
|
||||||
if split not in dataset:
|
if split not in dataset:
|
||||||
raise ValueError(f"Split '{split}' not found. \
|
raise ValueError(f"Split '{split}' not found. \
|
||||||
|
@@ -231,7 +231,7 @@ DATASETS_MAPPING = {
|
|||||||
},
|
},
|
||||||
# SeedBench
|
# SeedBench
|
||||||
"opencompass/seedbench": {
|
"opencompass/seedbench": {
|
||||||
"ms_id": "",
|
"ms_id": "y12869741/SeedBench",
|
||||||
"hf_id": "y12869741/SeedBench",
|
"hf_id": "y12869741/SeedBench",
|
||||||
"local": "./data/SeedBench",
|
"local": "./data/SeedBench",
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user