OpenCompass/opencompass/datasets/PromptCBLUE.py

63 lines
2.0 KiB
Python
Raw Normal View History

2025-05-09 14:47:26 +08:00
from datasets import Dataset, DatasetDict, load_dataset
2025-05-04 20:00:36 +08:00
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset # 保持与 MMLUDataset 同级的导包风格
@LOAD_DATASET.register_module()
class PromptCBLUEDataset(BaseDataset):
    """Loader for PromptCBLUE life-science tasks (CHIP-CDN, CHIP-CTC, ...).

    - Reads only the ``validation`` split.
    - Keeps every record whose ``task_dataset`` equals the requested name.
    - Appends ``target`` to ``answer_choices`` when it is not already listed
      and renders the choices as a lettered ``options_str``.
    - Returns a ``DatasetDict`` exposing both ``validation`` and ``test`` so
      the evaluation flow finds the splits it asks for.
    """

    @staticmethod
    def load(path: str, name: str, **kwargs) -> DatasetDict:
        """Build the evaluation splits for one PromptCBLUE sub-dataset.

        Args:
            path: Dataset location forwarded to ``load_dataset``.
            name: Value that a record's ``task_dataset`` field must equal
                for the record to be kept.
            **kwargs: Extra keyword arguments forwarded to ``load_dataset``.

        Returns:
            A ``DatasetDict`` whose ``validation`` and ``test`` entries are
            the same dataset with the columns ``input``, ``answer_choices``,
            ``options_str`` and ``target``.
        """
        # 1) Read the validation split.
        hf_ds = load_dataset(path, split='validation', **kwargs)

        # 2) Keep the requested sub-dataset and build the output records.
        records = []
        for rec in hf_ds:
            if rec.get('task_dataset') != name:
                continue
            # ``or []`` covers a field that is present but null as well as an
            # absent key; ``list(...)`` copies so the source row is untouched.
            choices = list(rec.get('answer_choices') or [])
            target = rec.get('target')
            # Make sure the gold answer is selectable, but never add a
            # missing (None) target as if it were a real option.
            if target is not None and target not in choices:
                choices.append(target)
            # One option per line, lettered from 'A' (chr(65)).
            options_str = '\n'.join(
                f'{chr(65 + i)}. {opt}' for i, opt in enumerate(choices))
            records.append({
                'input': rec['input'],
                'answer_choices': choices,
                'options_str': options_str,
                'target': target,
            })

        # 3) Build the Dataset; with no matching records, create the columns
        #    explicitly so the result still carries the expected schema.
        if records:
            validation_ds = Dataset.from_list(records)
        else:
            validation_ds = Dataset.from_dict({
                k: []
                for k in [
                    'input',
                    'answer_choices',
                    'options_str',
                    'target',
                ]
            })

        # 4) Expose the same data under both split names.
        return DatasetDict(
            validation=validation_ds,
            test=validation_ds,
        )