This commit is contained in:
tcheng 2025-05-13 17:26:13 +08:00 committed by GitHub
commit 9bde347000
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 235 additions and 0 deletions

View File

@ -719,6 +719,12 @@
paper: https://arxiv.org/pdf/2009.03300
configpath: opencompass/configs/datasets/mmlu/mmlu_gen.py
configpath_llmjudge: opencompass/configs/datasets/mmlu/mmlu_llm_judge_gen.py
- PromptCBLUE:
name: PromptCBLUE
category: Understanding
paper: https://arxiv.org/pdf/2310.14151
configpath: opencompass/configs/datasets/PromptCBLUE/PromptCBLUE_gen.py
configpath_llmjudge: opencompass/configs/datasets/PromptCBLUE/PromptCBLUE_llmjudge_gen.py
- mmlu_cf:
name: MMLU-CF
category: Understanding

View File

@ -0,0 +1,64 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.utils.text_postprocessors import first_capital_postprocess
from opencompass.datasets import PromptCBLUEDataset
# 1. Names of the PromptCBLUE life-science sub-datasets evaluated here.
PromptCBLUE_lifescience_sets = [
    'CHIP-CDN',
    'CHIP-CTC',
    'KUAKE-QIC',
    'IMCS-V2-DAC',
    'CHIP-STS',
    'KUAKE-QQR',
    'KUAKE-IR',
    'KUAKE-QTR',
]
# 2. Reader configuration: columns that feed the prompt, and the answer column.
# Only the validation split is read (the loader exposes it as train_split too).
reader_cfg = {
    'input_columns': ['input', 'answer_choices', 'options_str'],
    'output_column': 'target',
    'train_split': 'validation',
}
# 3. Prompt 模板:末行固定 ANSWER: $LETTER
_HINT = 'Given the ICD-10 candidate terms below, choose the normalized term(s) matching the original diagnosis.'
query_template = f"""{_HINT}
Original diagnosis: {{input}}
Options:
{{options_str}}
The last line of your response must be exactly:
ANSWER: $LETTER
""".strip()
# Shared inference settings: zero-shot retrieval plus plain generation.
_query_round = [dict(role='HUMAN', prompt=query_template)]
infer_cfg_common = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=_query_round),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)
# 4. Evaluation: exact-match accuracy over the first capital letter
#    extracted from the model's reply.
eval_cfg_common = {
    'evaluator': {'type': AccEvaluator},
    'pred_postprocessor': {'type': first_capital_postprocess},
}
# 5. Assemble one dataset config per sub-dataset; all of them share the same
#    reader / inference / evaluation settings defined above.
promptcblue_datasets = [
    dict(
        abbr='promptcblue_{}_norm'.format(sub.lower().replace('-', '_')),
        type=PromptCBLUEDataset,
        path='tchenglv/PromptCBLUE',
        name=sub,
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg_common,
        eval_cfg=eval_cfg_common,
    )
    for sub in PromptCBLUE_lifescience_sets
]

# Export variable that OpenCompass discovers when loading this config.
datasets = promptcblue_datasets

View File

@ -0,0 +1,102 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.datasets import PromptCBLUEDataset
# PromptCBLUE life-science sub-datasets covered by this LLM-judge config.
PromptCBLUE_lifescience_sets = [
    'CHIP-CDN',
    'CHIP-CTC',
    'KUAKE-QIC',
    'IMCS-V2-DAC',
    'CHIP-STS',
    'KUAKE-QQR',
    'KUAKE-IR',
    'KUAKE-QTR',
]
# Query prompt shown to the evaluated model (wording kept byte-identical).
# {input} and {options_str} are filled in per-sample by the prompt engine.
QUERY_TEMPLATE = '\n'.join([
    ('Given a medical diagnosis description and labeled ICD-10 candidate '
     'terms below, select the matching normalized term(s).'),
    'Original diagnosis: {input}',
    'Options:',
    '{options_str}',
    'The last line of your response must be exactly in the format:',
    'ANSWER: <LETTER(S)>',
])
# Grading prompt sent to the judge model (wording kept byte-identical).
# The judge must answer with a bare 'A' (correct) or 'B' (incorrect).
GRADER_TEMPLATE = '\n'.join([
    ("As an expert evaluator, judge whether the candidate's answer matches "
     'the gold standard below.'),
    "Return 'A' for CORRECT or 'B' for INCORRECT, with no additional text.",
    'Original diagnosis: {input}',
    'Options:',
    '{options_str}',
    'Gold answer: {target}',
    'Candidate answer: {prediction}',
])
# Common reader config shared by every sub-dataset in this file.
reader_cfg = {
    'input_columns': ['input', 'answer_choices', 'options_str'],
    'output_column': 'target',
    'train_split': 'validation',
}
# Assemble one LLM-as-judge dataset config per sub-dataset.
#
# Bug fix: the original file was missing the comma between the
# ``prompt_template=dict(...)`` argument and ``dataset_cfg=dict(...)``
# inside the evaluator dict, which made this module a SyntaxError and
# prevented the whole config from being imported.
promptcblue_llm_datasets = []
for name in PromptCBLUE_lifescience_sets:
    # Inference: zero-shot generation driven by the shared query prompt.
    infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(round=[
                dict(role='HUMAN', prompt=QUERY_TEMPLATE),
            ]),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer),
    )
    # Evaluation: a judge model grades each prediction with GRADER_TEMPLATE;
    # its A/B verdicts are turned into scores by generic_llmjudge_postprocess.
    eval_cfg = dict(
        evaluator=dict(
            type=GenericLLMEvaluator,
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(
                    begin=[
                        dict(
                            role='SYSTEM',
                            fallback_role='HUMAN',
                            prompt='You are an expert judge for medical term normalization tasks.',
                        )
                    ],
                    round=[
                        dict(role='HUMAN', prompt=GRADER_TEMPLATE),
                    ],
                ),
            ),  # comma restored here (was a SyntaxError)
            # The judge re-reads the dataset so it can see inputs/targets.
            dataset_cfg=dict(
                type=PromptCBLUEDataset,
                path='tchenglv/PromptCBLUE',
                name=name,
                reader_cfg=reader_cfg,
            ),
            # Judge model is injected at runtime; left empty here on purpose.
            judge_cfg=dict(),
            dict_postprocessor=dict(type=generic_llmjudge_postprocess),
        ),
        pred_role='BOT',
    )
    promptcblue_llm_datasets.append(
        dict(
            abbr=f"promptcblue_{name.lower().replace('-', '_')}_norm_llm",
            type=PromptCBLUEDataset,
            path='tchenglv/PromptCBLUE',
            name=name,
            reader_cfg=reader_cfg,
            infer_cfg=infer_cfg,
            eval_cfg=eval_cfg,
            mode='singlescore',
        )
    )

View File

@ -0,0 +1,62 @@
from datasets import Dataset, DatasetDict, load_dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset # 保持与 MMLUDataset 同级的导包风格
@LOAD_DATASET.register_module()
class PromptCBLUEDataset(BaseDataset):
    """Loader for PromptCBLUE life-science tasks (CHIP-CDN, CHIP-CTC, ...).

    - Reads only the ``validation`` split from the HuggingFace hub.
    - Keeps every record whose ``task_dataset`` equals the requested name.
    - If ``target`` is not among ``answer_choices`` it is appended, and an
      ``options_str`` of lettered options (``A. ...``) is generated.
    - Returns a ``DatasetDict`` exposing both ``validation`` and ``test``
      (the same data) so the evaluation pipeline finds the split it needs.
    """

    @staticmethod
    def load(path: str, name: str, **kwargs):
        # 1) Fetch the validation split from the HuggingFace hub.
        raw = load_dataset(path, split='validation', **kwargs)

        # 2) Keep only the requested sub-dataset and build flat records.
        columns = ['input', 'answer_choices', 'options_str', 'target']
        rows = []
        for sample in raw:
            if sample.get('task_dataset') != name:
                continue
            # Copy so the appended gold answer never mutates the raw record.
            options = list(sample.get('answer_choices', []))
            gold = sample.get('target')
            if gold not in options:
                options.append(gold)
            # Label options 'A.', 'B.', ... one per line.
            labelled = '\n'.join(
                f'{chr(ord("A") + idx)}. {choice}'
                for idx, choice in enumerate(options))
            rows.append({
                'input': sample['input'],
                'answer_choices': options,
                'options_str': labelled,
                'target': gold,
            })

        # 3) Materialize the split (empty-but-typed when nothing matched).
        if rows:
            split = Dataset.from_list(rows)
        else:
            split = Dataset.from_dict({col: [] for col in columns})

        # 4) Expose the same data under both split names.
        return DatasetDict(validation=split, test=split)

View File

@ -125,6 +125,7 @@ from .OlympiadBench import * # noqa: F401, F403
from .OpenFinData import * # noqa: F401, F403
from .physics import * # noqa: F401, F403
from .piqa import * # noqa: F401, F403
from .PromptCBLUE import * # noqa: F401, F403
from .ProteinLMBench import * # noqa: F401, F403
from .py150 import * # noqa: F401, F403
from .qasper import * # noqa: F401, F403