mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Add CMMLU dataset (#91)
* add CMMLU * debug cmmlu * add slurm args `qos` * fix format: space before comment * remove unused variable * change the location of `answer is` --------- Co-authored-by: 李浩楠 <lihaonan@lihaonandeMacBook-Air.local> Co-authored-by: 李浩楠 <haonan.li> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
6e885d668b
commit
e9cdb24ddd
4
configs/datasets/cmmlu/cmmlu_gen.py
Normal file
4
configs/datasets/cmmlu/cmmlu_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .cmmlu_gen_ffe7c0 import cmmlu_datasets # noqa: F401, F403
|
122
configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py
Normal file
122
configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py
Normal file
@ -0,0 +1,122 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import CMMLUDataset
|
||||
from opencompass.utils.text_postprocessors import first_capital_postprocess
|
||||
|
||||
cmmlu_subject_mapping = {
|
||||
'agronomy': '农学',
|
||||
'anatomy': '解剖学',
|
||||
'ancient_chinese': '古汉语',
|
||||
'arts': '艺术学',
|
||||
'astronomy': '天文学',
|
||||
'business_ethics': '商业伦理',
|
||||
'chinese_civil_service_exam': '中国公务员考试',
|
||||
'chinese_driving_rule': '中国驾驶规则',
|
||||
'chinese_food_culture': '中国饮食文化',
|
||||
'chinese_foreign_policy': '中国外交政策',
|
||||
'chinese_history': '中国历史',
|
||||
'chinese_literature': '中国文学',
|
||||
'chinese_teacher_qualification': '中国教师资格',
|
||||
'clinical_knowledge': '临床知识',
|
||||
'college_actuarial_science': '大学精算学',
|
||||
'college_education': '大学教育学',
|
||||
'college_engineering_hydrology': '大学工程水文学',
|
||||
'college_law': '大学法律',
|
||||
'college_mathematics': '大学数学',
|
||||
'college_medical_statistics': '大学医学统计',
|
||||
'college_medicine': '大学医学',
|
||||
'computer_science': '计算机科学',
|
||||
'computer_security': '计算机安全',
|
||||
'conceptual_physics': '概念物理学',
|
||||
'construction_project_management': '建设工程管理',
|
||||
'economics': '经济学',
|
||||
'education': '教育学',
|
||||
'electrical_engineering': '电气工程',
|
||||
'elementary_chinese': '小学语文',
|
||||
'elementary_commonsense': '小学常识',
|
||||
'elementary_information_and_technology': '小学信息技术',
|
||||
'elementary_mathematics': '初等数学',
|
||||
'ethnology': '民族学',
|
||||
'food_science': '食品科学',
|
||||
'genetics': '遗传学',
|
||||
'global_facts': '全球事实',
|
||||
'high_school_biology': '高中生物',
|
||||
'high_school_chemistry': '高中化学',
|
||||
'high_school_geography': '高中地理',
|
||||
'high_school_mathematics': '高中数学',
|
||||
'high_school_physics': '高中物理学',
|
||||
'high_school_politics': '高中政治',
|
||||
'human_sexuality': '人类性行为',
|
||||
'international_law': '国际法学',
|
||||
'journalism': '新闻学',
|
||||
'jurisprudence': '法理学',
|
||||
'legal_and_moral_basis': '法律与道德基础',
|
||||
'logical': '逻辑学',
|
||||
'machine_learning': '机器学习',
|
||||
'management': '管理学',
|
||||
'marketing': '市场营销',
|
||||
'marxist_theory': '马克思主义理论',
|
||||
'modern_chinese': '现代汉语',
|
||||
'nutrition': '营养学',
|
||||
'philosophy': '哲学',
|
||||
'professional_accounting': '专业会计',
|
||||
'professional_law': '专业法学',
|
||||
'professional_medicine': '专业医学',
|
||||
'professional_psychology': '专业心理学',
|
||||
'public_relations': '公共关系',
|
||||
'security_study': '安全研究',
|
||||
'sociology': '社会学',
|
||||
'sports_science': '体育学',
|
||||
'traditional_chinese_medicine': '中医中药',
|
||||
'virology': '病毒学',
|
||||
'world_history': '世界历史',
|
||||
'world_religions': '世界宗教'
|
||||
}
|
||||
|
||||
|
||||
cmmlu_all_sets = list(cmmlu_subject_mapping.keys())
|
||||
|
||||
cmmlu_datasets = []
|
||||
for _name in cmmlu_all_sets:
|
||||
_ch_name = cmmlu_subject_mapping[_name]
|
||||
cmmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
begin="</E>",
|
||||
round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
f"以下是关于{_ch_name}的单项选择题,请直接给出正确答案的选项。\n题目:{{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}"
|
||||
),
|
||||
dict(role="BOT", prompt='答案是: {answer}'),
|
||||
]),
|
||||
ice_token="</E>",
|
||||
),
|
||||
retriever=dict(type=FixKRetriever),
|
||||
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
|
||||
)
|
||||
|
||||
cmmlu_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_postprocessor=dict(type=first_capital_postprocess))
|
||||
|
||||
cmmlu_datasets.append(
|
||||
dict(
|
||||
type=CMMLUDataset,
|
||||
path="./data/cmmlu/",
|
||||
name=_name,
|
||||
abbr=f"cmmlu-{_name}",
|
||||
reader_cfg=dict(
|
||||
input_columns=["question", "A", "B", "C", "D"],
|
||||
output_column="answer",
|
||||
train_split="dev",
|
||||
test_split='test'),
|
||||
infer_cfg=cmmlu_infer_cfg,
|
||||
eval_cfg=cmmlu_eval_cfg,
|
||||
))
|
||||
|
||||
del _name, _ch_name
|
4
configs/datasets/cmmlu/cmmlu_ppl.py
Normal file
4
configs/datasets/cmmlu/cmmlu_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .cmmlu_ppl_fd1f2f import cmmlu_datasets # noqa: F401, F403
|
122
configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py
Normal file
122
configs/datasets/cmmlu/cmmlu_ppl_fd1f2f.py
Normal file
@ -0,0 +1,122 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import CMMLUDataset
|
||||
from opencompass.utils.text_postprocessors import first_capital_postprocess
|
||||
|
||||
cmmlu_subject_mapping = {
|
||||
'agronomy': '农学',
|
||||
'anatomy': '解剖学',
|
||||
'ancient_chinese': '古汉语',
|
||||
'arts': '艺术学',
|
||||
'astronomy': '天文学',
|
||||
'business_ethics': '商业伦理',
|
||||
'chinese_civil_service_exam': '中国公务员考试',
|
||||
'chinese_driving_rule': '中国驾驶规则',
|
||||
'chinese_food_culture': '中国饮食文化',
|
||||
'chinese_foreign_policy': '中国外交政策',
|
||||
'chinese_history': '中国历史',
|
||||
'chinese_literature': '中国文学',
|
||||
'chinese_teacher_qualification': '中国教师资格',
|
||||
'clinical_knowledge': '临床知识',
|
||||
'college_actuarial_science': '大学精算学',
|
||||
'college_education': '大学教育学',
|
||||
'college_engineering_hydrology': '大学工程水文学',
|
||||
'college_law': '大学法律',
|
||||
'college_mathematics': '大学数学',
|
||||
'college_medical_statistics': '大学医学统计',
|
||||
'college_medicine': '大学医学',
|
||||
'computer_science': '计算机科学',
|
||||
'computer_security': '计算机安全',
|
||||
'conceptual_physics': '概念物理学',
|
||||
'construction_project_management': '建设工程管理',
|
||||
'economics': '经济学',
|
||||
'education': '教育学',
|
||||
'electrical_engineering': '电气工程',
|
||||
'elementary_chinese': '小学语文',
|
||||
'elementary_commonsense': '小学常识',
|
||||
'elementary_information_and_technology': '小学信息技术',
|
||||
'elementary_mathematics': '初等数学',
|
||||
'ethnology': '民族学',
|
||||
'food_science': '食品科学',
|
||||
'genetics': '遗传学',
|
||||
'global_facts': '全球事实',
|
||||
'high_school_biology': '高中生物',
|
||||
'high_school_chemistry': '高中化学',
|
||||
'high_school_geography': '高中地理',
|
||||
'high_school_mathematics': '高中数学',
|
||||
'high_school_physics': '高中物理学',
|
||||
'high_school_politics': '高中政治',
|
||||
'human_sexuality': '人类性行为',
|
||||
'international_law': '国际法学',
|
||||
'journalism': '新闻学',
|
||||
'jurisprudence': '法理学',
|
||||
'legal_and_moral_basis': '法律与道德基础',
|
||||
'logical': '逻辑学',
|
||||
'machine_learning': '机器学习',
|
||||
'management': '管理学',
|
||||
'marketing': '市场营销',
|
||||
'marxist_theory': '马克思主义理论',
|
||||
'modern_chinese': '现代汉语',
|
||||
'nutrition': '营养学',
|
||||
'philosophy': '哲学',
|
||||
'professional_accounting': '专业会计',
|
||||
'professional_law': '专业法学',
|
||||
'professional_medicine': '专业医学',
|
||||
'professional_psychology': '专业心理学',
|
||||
'public_relations': '公共关系',
|
||||
'security_study': '安全研究',
|
||||
'sociology': '社会学',
|
||||
'sports_science': '体育学',
|
||||
'traditional_chinese_medicine': '中医中药',
|
||||
'virology': '病毒学',
|
||||
'world_history': '世界历史',
|
||||
'world_religions': '世界宗教'
|
||||
}
|
||||
|
||||
|
||||
cmmlu_all_sets = list(cmmlu_subject_mapping.keys())
|
||||
|
||||
cmmlu_datasets = []
|
||||
for _name in cmmlu_all_sets:
|
||||
_ch_name = cmmlu_subject_mapping[_name]
|
||||
cmmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
answer: dict(
|
||||
begin="</E>",
|
||||
round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=f"以下是关于{_ch_name}的单项选择题,请直接给出正确答案的选项。\n题目:{{question}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}"
|
||||
),
|
||||
dict(role="BOT", prompt=f'答案是: {answer}'),
|
||||
])
|
||||
for answer in ["A", "B", "C", "D"]
|
||||
},
|
||||
ice_token="</E>",
|
||||
),
|
||||
retriever=dict(type=FixKRetriever),
|
||||
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
|
||||
)
|
||||
|
||||
cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
cmmlu_datasets.append(
|
||||
dict(
|
||||
type=CMMLUDataset,
|
||||
path="./data/cmmlu/",
|
||||
name=_name,
|
||||
abbr=f"cmmlu-{_name}",
|
||||
reader_cfg=dict(
|
||||
input_columns=["question", "A", "B", "C", "D"],
|
||||
output_column="answer",
|
||||
train_split="dev",
|
||||
test_split='test'),
|
||||
infer_cfg=cmmlu_infer_cfg,
|
||||
eval_cfg=cmmlu_eval_cfg,
|
||||
))
|
||||
|
||||
del _name, _ch_name
|
@ -11,6 +11,7 @@ from .ceval import * # noqa: F401, F403
|
||||
from .chid import * # noqa: F401, F403
|
||||
from .civilcomments import * # noqa: F401, F403
|
||||
from .cluewsc import * # noqa: F401, F403
|
||||
from .cmmlu import * # noqa: F401, F403
|
||||
from .cmnli import * # noqa: F401, F403
|
||||
from .cmrc import * # noqa: F401, F403
|
||||
from .commonsenseqa import * # noqa: F401, F403
|
||||
|
34
opencompass/datasets/cmmlu.py
Normal file
34
opencompass/datasets/cmmlu.py
Normal file
@ -0,0 +1,34 @@
|
||||
import csv
|
||||
import os.path as osp
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class CMMLUDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, name: str):
|
||||
dataset = DatasetDict()
|
||||
for split in ['dev', 'test']:
|
||||
raw_data = []
|
||||
filename = osp.join(path, split, f'{name}.csv')
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
reader = csv.reader(f)
|
||||
_ = next(reader) # skip the header
|
||||
for row in reader:
|
||||
assert len(row) == 7
|
||||
raw_data.append({
|
||||
'question': row[1],
|
||||
'A': row[2],
|
||||
'B': row[3],
|
||||
'C': row[4],
|
||||
'D': row[5],
|
||||
'answer': row[6],
|
||||
})
|
||||
dataset[split] = Dataset.from_list(raw_data)
|
||||
return dataset
|
@ -28,6 +28,7 @@ class SlurmRunner(BaseRunner):
|
||||
retry (int): Number of retries if the job failed. Defaults to 2.
|
||||
partition (str): Slurm partition name. Defaults to None.
|
||||
quotatype (str): Slurm quota type. Defaults to None.
|
||||
qos (str): Slurm quality of service. Defaults to None.
|
||||
debug (bool): Whether to run in debug mode. Defaults to False.
|
||||
lark_bot_url (str): Lark bot url. Defaults to None.
|
||||
"""
|
||||
@ -38,6 +39,7 @@ class SlurmRunner(BaseRunner):
|
||||
retry: int = 2,
|
||||
partition: str = None,
|
||||
quotatype: str = None,
|
||||
qos: str = None,
|
||||
debug: bool = False,
|
||||
lark_bot_url: str = None):
|
||||
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
||||
@ -45,6 +47,7 @@ class SlurmRunner(BaseRunner):
|
||||
self.retry = retry
|
||||
self.partition = partition
|
||||
self.quotatype = quotatype
|
||||
self.qos = qos
|
||||
|
||||
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
|
||||
"""Launch multiple tasks.
|
||||
@ -97,6 +100,8 @@ class SlurmRunner(BaseRunner):
|
||||
tmpl += f' -p {self.partition}'
|
||||
if self.quotatype:
|
||||
tmpl += f' --quotatype={self.quotatype}'
|
||||
if self.qos:
|
||||
tmpl += f' --qos={self.qos}'
|
||||
if num_gpus > 0:
|
||||
tmpl += f' --gres=gpu:{num_gpus}'
|
||||
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
|
||||
|
6
run.py
6
run.py
@ -129,6 +129,10 @@ def parse_slurm_args(slurm_parser):
|
||||
help='Slurm quota type',
|
||||
default=None,
|
||||
type=str)
|
||||
slurm_parser.add_argument('--qos',
|
||||
help='Slurm quality of service',
|
||||
default=None,
|
||||
type=str)
|
||||
|
||||
|
||||
def parse_dlc_args(dlc_parser):
|
||||
@ -286,6 +290,7 @@ def exec_infer_runner(tasks, args, cfg):
|
||||
max_num_workers=args.max_num_workers,
|
||||
partition=args.partition,
|
||||
quotatype=args.quotatype,
|
||||
qos=args.qos,
|
||||
retry=args.retry,
|
||||
debug=args.debug,
|
||||
lark_bot_url=cfg['lark_bot_url'])
|
||||
@ -311,6 +316,7 @@ def exec_eval_runner(tasks, args, cfg):
|
||||
max_num_workers=args.max_num_workers,
|
||||
partition=args.partition,
|
||||
quotatype=args.quotatype,
|
||||
qos=args.qos,
|
||||
retry=args.retry,
|
||||
debug=args.debug,
|
||||
lark_bot_url=cfg['lark_bot_url'])
|
||||
|
Loading…
Reference in New Issue
Block a user