mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
Add release contribution
This commit is contained in:
parent
e6b5bdcb87
commit
c94cc94348
4
configs/datasets/ARC_c/ARC_c_gen.py
Normal file
4
configs/datasets/ARC_c/ARC_c_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .ARC_c_gen_3f3039 import ARC_c_datasets # noqa: F401, F403
|
42
configs/datasets/ARC_c/ARC_c_gen_3f3039.py
Normal file
42
configs/datasets/ARC_c/ARC_c_gen_3f3039.py
Normal file
@ -0,0 +1,42 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import ARCDataset
|
||||
|
||||
ARC_c_reader_cfg = dict(
|
||||
input_columns=["question", "textA", "textB", "textC", "textD"],
|
||||
output_column="answerKey")
|
||||
|
||||
ARC_c_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"Question: {question}\nA. {textA}\nB. {textB}\nC. {textC}\nD. {textD}\nAnswer:"
|
||||
)
|
||||
], ),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
ARC_c_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_role="BOT",
|
||||
pred_postprocessor=dict(type="first-capital"),
|
||||
)
|
||||
|
||||
ARC_c_datasets = [
|
||||
dict(
|
||||
abbr="ARC-c",
|
||||
type=ARCDataset,
|
||||
path="./data/ARC/ARC-c/ARC-Challenge-Dev.jsonl",
|
||||
reader_cfg=ARC_c_reader_cfg,
|
||||
infer_cfg=ARC_c_infer_cfg,
|
||||
eval_cfg=ARC_c_eval_cfg,
|
||||
)
|
||||
]
|
53
configs/datasets/ARC_e/ARC_e_ppl_f86898.py
Normal file
53
configs/datasets/ARC_e/ARC_e_ppl_f86898.py
Normal file
@ -0,0 +1,53 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import ARCDataset
|
||||
|
||||
ARC_e_reader_cfg = dict(
|
||||
input_columns=['question', 'textA', 'textB', 'textC', 'textD'],
|
||||
output_column='answerKey')
|
||||
|
||||
ARC_e_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
"A":
|
||||
dict(
|
||||
round=[
|
||||
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
|
||||
dict(role="BOT", prompt="{textA}")
|
||||
], ),
|
||||
"B":
|
||||
dict(
|
||||
round=[
|
||||
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
|
||||
dict(role="BOT", prompt="{textB}")
|
||||
], ),
|
||||
"C":
|
||||
dict(
|
||||
round=[
|
||||
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
|
||||
dict(role="BOT", prompt="{textC}")
|
||||
], ),
|
||||
"D":
|
||||
dict(
|
||||
round=[
|
||||
dict(role="HUMAN", prompt="Question: {question}\nAnswer: "),
|
||||
dict(role="BOT", prompt="{textD}")
|
||||
], ),
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
ARC_e_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
ARC_e_datasets = [
|
||||
dict(
|
||||
type=ARCDataset,
|
||||
abbr='ARC-e',
|
||||
path='./data/ARC/ARC-e/ARC-Easy-Dev.jsonl',
|
||||
reader_cfg=ARC_e_reader_cfg,
|
||||
infer_cfg=ARC_e_infer_cfg,
|
||||
eval_cfg=ARC_e_eval_cfg)
|
||||
]
|
27
configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_220a83.py
Normal file
27
configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_220a83.py
Normal file
@ -0,0 +1,27 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import EMEvaluator
|
||||
from opencompass.datasets import CMRCDataset
|
||||
|
||||
CMRC_reader_cfg = dict(
|
||||
input_columns=['question', 'context'], output_column='answers')
|
||||
|
||||
CMRC_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template="文章:{context}\n根据上文,回答如下问题: {question}\n答:"),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
CMRC_eval_cfg = dict(evaluator=dict(type=EMEvaluator), )
|
||||
|
||||
CMRC_datasets = [
|
||||
dict(
|
||||
type=CMRCDataset,
|
||||
abbr='CMRC_dev',
|
||||
path='./data/CLUE/CMRC/dev.json',
|
||||
reader_cfg=CMRC_reader_cfg,
|
||||
infer_cfg=CMRC_infer_cfg,
|
||||
eval_cfg=CMRC_eval_cfg),
|
||||
]
|
50
configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_c83c36.py
Normal file
50
configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_c83c36.py
Normal file
@ -0,0 +1,50 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
afqmc_reader_cfg = dict(
|
||||
input_columns=['sentence1', 'sentence2'],
|
||||
output_column='label',
|
||||
test_split='train')
|
||||
|
||||
afqmc_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0:
|
||||
dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"语句一:“{sentence1}”\n语句二:“{sentence2}”\n语句一与语句二是关于蚂蚁金融产品的疑问,两者所询问的内容是否完全一致?"
|
||||
),
|
||||
dict(role="BOT", prompt="不完全一致")
|
||||
]),
|
||||
1:
|
||||
dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"语句一:“{sentence1}”\n语句二:“{sentence2}”\n语句一与语句二是关于蚂蚁金融产品的疑问,两者所询问的内容是否完全一致?"
|
||||
),
|
||||
dict(role="BOT", prompt="完全一致")
|
||||
]),
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
afqmc_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
afqmc_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='afqmc-dev',
|
||||
path='json',
|
||||
data_files='./data/CLUE/AFQMC/dev.json',
|
||||
split='train',
|
||||
reader_cfg=afqmc_reader_cfg,
|
||||
infer_cfg=afqmc_infer_cfg,
|
||||
eval_cfg=afqmc_eval_cfg),
|
||||
]
|
4
configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl.py
Normal file
4
configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .CLUE_cmnli_ppl_1c652a import cmnli_datasets # noqa: F401, F403
|
52
configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_991e1b.py
Normal file
52
configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_991e1b.py
Normal file
@ -0,0 +1,52 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
cmnli_reader_cfg = dict(
|
||||
input_columns=['sentence1', 'sentence2'],
|
||||
output_column='label',
|
||||
test_split='train')
|
||||
|
||||
cmnli_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
'contradiction':
|
||||
dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="阅读文章:{sentence1}\n根据上文,回答如下问题:{sentence2}?"),
|
||||
dict(role="BOT", prompt="错")
|
||||
]),
|
||||
'entailment':
|
||||
dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="阅读文章:{sentence1}\n根据上文,回答如下问题:{sentence2}?"),
|
||||
dict(role="BOT", prompt="对")
|
||||
]),
|
||||
'neutral':
|
||||
dict(round=[
|
||||
dict(
|
||||
role="HUMAN", prompt="如果{sentence1}为真,那么{sentence2}也为真吗?"),
|
||||
dict(role="BOT", prompt="可能")
|
||||
]),
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
cmnli_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
cmnli_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='cmnli',
|
||||
path='json',
|
||||
split='train',
|
||||
data_files='./data/CLUE/cmnli/cmnli_public/dev.json',
|
||||
reader_cfg=cmnli_reader_cfg,
|
||||
infer_cfg=cmnli_infer_cfg,
|
||||
eval_cfg=cmnli_eval_cfg)
|
||||
]
|
36
configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_b78ad4.py
Normal file
36
configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_b78ad4.py
Normal file
@ -0,0 +1,36 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
cmnli_reader_cfg = dict(
|
||||
input_columns=['sentence1', 'sentence2'],
|
||||
output_column='label',
|
||||
test_split='train')
|
||||
|
||||
cmnli_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
'contradiction':
|
||||
'阅读文章:{sentence1}\n根据上文,回答如下问题: {sentence2}?\n答:错',
|
||||
'entailment': '阅读文章:{sentence1}\n根据上文,回答如下问题: {sentence2}?\n答:对',
|
||||
'neutral': '如果{sentence1}为真,那么{sentence2}也为真吗?可能'
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
cmnli_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
cmnli_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='cmnli',
|
||||
path='json',
|
||||
split='train',
|
||||
data_files='./data/CLUE/cmnli/cmnli_public/dev.json',
|
||||
reader_cfg=cmnli_reader_cfg,
|
||||
infer_cfg=cmnli_infer_cfg,
|
||||
eval_cfg=cmnli_eval_cfg)
|
||||
]
|
43
configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_01899f.py
Normal file
43
configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_01899f.py
Normal file
@ -0,0 +1,43 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import cmnliDataset_V2
|
||||
|
||||
ocnli_reader_cfg = dict(
|
||||
input_columns=["sentence1", "sentence2"],
|
||||
output_column="label",
|
||||
)
|
||||
|
||||
# TODO: two prompt templates for ocnli
|
||||
ocnli_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"阅读文章:{sentence1}\n根据上文,回答如下问题:{sentence2}\nA. 对\nB. 错\nC. 可能\n请从“A”,“B”,“C”中进行选择。\n答:"
|
||||
),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
ocnli_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_role="BOT",
|
||||
pred_postprocessor=dict(type="first-capital"),
|
||||
)
|
||||
|
||||
ocnli_datasets = [
|
||||
dict(
|
||||
abbr="ocnli",
|
||||
type=cmnliDataset_V2, # ocnli share the same format with cmnli
|
||||
path="./data/CLUE/OCNLI/dev.json",
|
||||
reader_cfg=ocnli_reader_cfg,
|
||||
infer_cfg=ocnli_infer_cfg,
|
||||
eval_cfg=ocnli_eval_cfg,
|
||||
)
|
||||
]
|
43
configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_4f864a.py
Normal file
43
configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_4f864a.py
Normal file
@ -0,0 +1,43 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
bustm_reader_cfg = dict(
|
||||
input_columns=['sentence1', 'sentence2'],
|
||||
output_column='label',
|
||||
test_split='train')
|
||||
|
||||
bustm_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0: "{sentence1}。\n{sentence2}。\n两句话说的毫不相关。",
|
||||
1: "{sentence1}。\n{sentence2}。\n两句话说的一个意思。"
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
bustm_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
bustm_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='bustm-dev',
|
||||
path='json',
|
||||
data_files='./data/FewCLUE/bustm/dev_few_all.json',
|
||||
split='train',
|
||||
reader_cfg=bustm_reader_cfg,
|
||||
infer_cfg=bustm_infer_cfg,
|
||||
eval_cfg=bustm_eval_cfg),
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='bustm-test',
|
||||
path='json',
|
||||
data_files='./data/FewCLUE/bustm/test_public.json',
|
||||
split='train',
|
||||
reader_cfg=bustm_reader_cfg,
|
||||
infer_cfg=bustm_infer_cfg,
|
||||
eval_cfg=bustm_eval_cfg)
|
||||
]
|
4
configs/datasets/FewCLUE_chid/FewCLUE_chid_gen.py
Normal file
4
configs/datasets/FewCLUE_chid/FewCLUE_chid_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .FewCLUE_chid_gen_686c63 import chid_datasets # noqa: F401, F403
|
4
configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl.py
Normal file
4
configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .FewCLUE_cluewsc_ppl_2a9e61 import cluewsc_datasets # noqa: F401, F403
|
50
configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_1b0c02.py
Normal file
50
configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_1b0c02.py
Normal file
@ -0,0 +1,50 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import CslDataset_V2
|
||||
|
||||
csl_reader_cfg = dict(
|
||||
input_columns=["abst", "keywords"],
|
||||
output_column="label",
|
||||
)
|
||||
|
||||
csl_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"摘要:{abst}\n关键词:{keywords}\n上述关键词出现在学术期刊中是否恰当?\nA. 否\nB. 是\n请从”A“,”B“中进行选择。\n答:"
|
||||
)
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
csl_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_role="BOT",
|
||||
pred_postprocessor=dict(type="first-capital"),
|
||||
)
|
||||
|
||||
csl_datasets = [
|
||||
dict(
|
||||
abbr="csl_dev",
|
||||
type=CslDataset_V2,
|
||||
path="./data/FewCLUE/csl/dev_few_all.json",
|
||||
reader_cfg=csl_reader_cfg,
|
||||
infer_cfg=csl_infer_cfg,
|
||||
eval_cfg=csl_eval_cfg,
|
||||
),
|
||||
dict(
|
||||
abbr="csl_test",
|
||||
type=CslDataset_V2,
|
||||
path="./data/FewCLUE/csl/test_public.json",
|
||||
reader_cfg=csl_reader_cfg,
|
||||
infer_cfg=csl_infer_cfg,
|
||||
eval_cfg=csl_eval_cfg,
|
||||
),
|
||||
]
|
@ -0,0 +1,48 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import eprstmtDataset_V2
|
||||
|
||||
eprstmt_reader_cfg = dict(
|
||||
input_columns=["sentence"], output_column="label", test_split="train")
|
||||
|
||||
eprstmt_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
'内容: "{sentence}"。请对上述内容进行情绪分类。\nA. 积极\nB. 消极\n请从”A“,”B“中进行选择。\n答:'
|
||||
),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
eprstmt_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_role="BOT",
|
||||
pred_postprocessor=dict(type="first-capital"),
|
||||
)
|
||||
|
||||
eprstmt_datasets = [
|
||||
dict(
|
||||
abbr="eprstmt-dev",
|
||||
type=eprstmtDataset_V2,
|
||||
path="./data/FewCLUE/eprstmt/dev_few_all.json",
|
||||
reader_cfg=eprstmt_reader_cfg,
|
||||
infer_cfg=eprstmt_infer_cfg,
|
||||
eval_cfg=eprstmt_eval_cfg,
|
||||
),
|
||||
dict(
|
||||
abbr="eprstmt-test",
|
||||
type=eprstmtDataset_V2,
|
||||
path="./data/FewCLUE/eprstmt/test_public.json",
|
||||
reader_cfg=eprstmt_reader_cfg,
|
||||
infer_cfg=eprstmt_infer_cfg,
|
||||
eval_cfg=eprstmt_eval_cfg,
|
||||
),
|
||||
]
|
@ -0,0 +1,49 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
eprstmt_reader_cfg = dict(
|
||||
input_columns=['sentence'], output_column='label', test_split='train')
|
||||
|
||||
eprstmt_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
'Negative':
|
||||
dict(round=[
|
||||
dict(role='HUMAN', prompt='内容: "{sentence}"。情绪分类:'),
|
||||
dict(role='BOT', prompt='消极。')
|
||||
]),
|
||||
'Positive':
|
||||
dict(round=[
|
||||
dict(role='HUMAN', prompt='内容: "{sentence}"。情绪分类:'),
|
||||
dict(role='BOT', prompt='积极。')
|
||||
]),
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
eprstmt_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
eprstmt_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='eprstmt-dev',
|
||||
path='json',
|
||||
data_files='./data/FewCLUE/eprstmt/dev_few_all.json',
|
||||
split='train',
|
||||
reader_cfg=eprstmt_reader_cfg,
|
||||
infer_cfg=eprstmt_infer_cfg,
|
||||
eval_cfg=eprstmt_eval_cfg),
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='eprstmt-test',
|
||||
path='json',
|
||||
data_files='./data/FewCLUE/eprstmt/test_public.json',
|
||||
split='train',
|
||||
reader_cfg=eprstmt_reader_cfg,
|
||||
infer_cfg=eprstmt_infer_cfg,
|
||||
eval_cfg=eprstmt_eval_cfg)
|
||||
]
|
4
configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen.py
Normal file
4
configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .FewCLUE_tnews_gen_8d59ba import tnews_datasets # noqa: F401, F403
|
54
configs/datasets/PJExam/PJExam_gen_785c37.py
Normal file
54
configs/datasets/PJExam/PJExam_gen_785c37.py
Normal file
@ -0,0 +1,54 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import PJExamDataset, PJExamEvaluator
|
||||
|
||||
PJExam_datasets = []
|
||||
for _name in [
|
||||
'gk-2022-v1', 'gk-2022-v1-math', 'gk-2023-v1', 'gk-2023-v1-math',
|
||||
'gk-2023-v2', 'gk-2023-v2-math', 'zk-2022-v1'
|
||||
]:
|
||||
_hint = "请你做一道</major>选择题\n请你一步一步思考并将思考过程写在【解析】和<eoe>之间。你将从A,B,C,D中选出正确的答案,并写在【答案】和<eoa>之间。\n例如:【答案】A<eoa>\n完整的题目回答的格式如下:\n【解析】...<eoe>\n【答案】...<eoa>\n请你严格按照上述格式作答。\n题目如下:\n"
|
||||
_reader_cfg = {
|
||||
"input_columns": ['question'],
|
||||
"output_column": 'std_ans',
|
||||
},
|
||||
_infer_cfg = {
|
||||
"ice_template": {
|
||||
"type": PromptTemplate,
|
||||
"template": {
|
||||
"round": [{
|
||||
"role": "HUMAN",
|
||||
"prompt": _hint + "{question}",
|
||||
}]
|
||||
},
|
||||
"ice_token": "</E>"
|
||||
},
|
||||
"retriever": {
|
||||
"type": ZeroRetriever
|
||||
},
|
||||
"inferencer": {
|
||||
"type": GenInferencer,
|
||||
"max_out_len": 1024,
|
||||
}
|
||||
}
|
||||
_eval_cfg = {
|
||||
"evaluator": {
|
||||
"type": PJExamEvaluator
|
||||
},
|
||||
"pred_role": "BOT",
|
||||
"ds_column": "eval_infos"
|
||||
}
|
||||
_dataset = {
|
||||
"type": PJExamDataset,
|
||||
"abbr": "PJExamDataset-" + _name,
|
||||
"path": './data/PJExam',
|
||||
"name": _name,
|
||||
"reader_cfg": _reader_cfg,
|
||||
"infer_cfg": _infer_cfg,
|
||||
"eval_cfg": _eval_cfg,
|
||||
}
|
||||
|
||||
PJExam_datasets.append(_dataset)
|
||||
|
||||
del _name, _hint, _reader_cfg, _infer_cfg, _eval_cfg, _dataset
|
34
configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl_a65d62.py
Normal file
34
configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl_a65d62.py
Normal file
@ -0,0 +1,34 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
AX_b_reader_cfg = dict(
|
||||
input_columns=['sentence1', 'sentence2'],
|
||||
output_column='label',
|
||||
test_split='train')
|
||||
|
||||
AX_b_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
'entailment': '{sentence1}?entailment, {sentence2}',
|
||||
'not_entailment': '{sentence1}?not_entailment, {sentence2}'
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
AX_b_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
AX_b_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='AX_b',
|
||||
path='json',
|
||||
data_files='./data/SuperGLUE/AX-b/AX-b.jsonl',
|
||||
split='train',
|
||||
reader_cfg=AX_b_reader_cfg,
|
||||
infer_cfg=AX_b_infer_cfg,
|
||||
eval_cfg=AX_b_eval_cfg)
|
||||
]
|
4
configs/datasets/SuperGLUE_AX_g/SuperGLUE_AX_g_ppl.py
Normal file
4
configs/datasets/SuperGLUE_AX_g/SuperGLUE_AX_g_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .SuperGLUE_AX_g_ppl_8d9bf9 import AX_g_datasets # noqa: F401, F403
|
34
configs/datasets/SuperGLUE_AX_g/SuperGLUE_AX_g_ppl_d489ee.py
Normal file
34
configs/datasets/SuperGLUE_AX_g/SuperGLUE_AX_g_ppl_d489ee.py
Normal file
@ -0,0 +1,34 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
AX_g_reader_cfg = dict(
|
||||
input_columns=['hypothesis', 'premise'],
|
||||
output_column='label',
|
||||
test_split='train')
|
||||
|
||||
AX_g_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
'entailment': '{premise}?entailment, {hypothesis}',
|
||||
'not_entailment': '{premise}?not_entailment, {hypothesis}'
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
AX_g_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
AX_g_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='AX_g',
|
||||
path='json',
|
||||
data_files='./data/SuperGLUE/AX-g/AX-g.jsonl',
|
||||
split='train',
|
||||
reader_cfg=AX_g_reader_cfg,
|
||||
infer_cfg=AX_g_infer_cfg,
|
||||
eval_cfg=AX_g_eval_cfg)
|
||||
]
|
@ -0,0 +1,34 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import BoolQDataset
|
||||
|
||||
BoolQ_reader_cfg = dict(
|
||||
input_columns=['question', 'passage'],
|
||||
output_column='answer',
|
||||
test_split='train')
|
||||
|
||||
BoolQ_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0: "Passage:{passage}。\nQuestion:{question}。\nAnswer: No.",
|
||||
1: "Passage:{passage}。\nQuestion:{question}。\nAnswer: Yes.",
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
BoolQ_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
BoolQ_datasets = [
|
||||
dict(
|
||||
type=BoolQDataset,
|
||||
abbr='BoolQ',
|
||||
path='json',
|
||||
data_files='./data/SuperGLUE/BoolQ/val.jsonl',
|
||||
split='train',
|
||||
reader_cfg=BoolQ_reader_cfg,
|
||||
infer_cfg=BoolQ_infer_cfg,
|
||||
eval_cfg=BoolQ_eval_cfg)
|
||||
]
|
45
configs/datasets/SuperGLUE_COPA/SuperGLUE_COPA_ppl_0ef2f8.py
Normal file
45
configs/datasets/SuperGLUE_COPA/SuperGLUE_COPA_ppl_0ef2f8.py
Normal file
@ -0,0 +1,45 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
COPA_reader_cfg = dict(
|
||||
input_columns=["question", "premise", "choice1", "choice2"],
|
||||
output_column="label",
|
||||
test_split="train")
|
||||
|
||||
COPA_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0:
|
||||
dict(round=[
|
||||
dict(role="HUMAN", prompt="{premise}\nQuestion: {question}\nAnswer:"),
|
||||
dict(role="BOT", prompt="{choice1}"),
|
||||
]),
|
||||
1:
|
||||
dict(round=[
|
||||
dict(role="HUMAN", prompt="{premise}\nQuestion: {question}\nAnswer:"),
|
||||
dict(role="BOT", prompt="{choice2}"),
|
||||
]),
|
||||
},
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer),
|
||||
)
|
||||
|
||||
COPA_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
COPA_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr="COPA",
|
||||
path="json",
|
||||
data_files="./data/SuperGLUE/COPA/val.jsonl",
|
||||
split="train",
|
||||
reader_cfg=COPA_reader_cfg,
|
||||
infer_cfg=COPA_infer_cfg,
|
||||
eval_cfg=COPA_eval_cfg,
|
||||
)
|
||||
]
|
@ -0,0 +1,47 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import MultiRCDataset
|
||||
|
||||
MultiRC_reader_cfg = dict(
|
||||
input_columns=["question", "text", "answer"],
|
||||
output_column="label",
|
||||
)
|
||||
|
||||
MultiRC_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0:
|
||||
dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="{text}\nQuestion: {question}\nAnswer: {answer}\nIs it true?"),
|
||||
dict(role="BOT", prompt="No, it is false."),
|
||||
]),
|
||||
1:
|
||||
dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="{text}\nQuestion: {question}\nAnswer: {answer}\nIs it true?"),
|
||||
dict(role="BOT", prompt="Yes, it is true."),
|
||||
]),
|
||||
},
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer),
|
||||
)
|
||||
|
||||
MultiRC_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
MultiRC_datasets = [
|
||||
dict(
|
||||
type=MultiRCDataset,
|
||||
abbr="MultiRC",
|
||||
path="./data/SuperGLUE/MultiRC/val.jsonl",
|
||||
reader_cfg=MultiRC_reader_cfg,
|
||||
infer_cfg=MultiRC_infer_cfg,
|
||||
eval_cfg=MultiRC_eval_cfg,
|
||||
)
|
||||
]
|
4
configs/datasets/SuperGLUE_RTE/SuperGLUE_RTE_gen.py
Normal file
4
configs/datasets/SuperGLUE_RTE/SuperGLUE_RTE_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .SuperGLUE_RTE_gen_ce346a import RTE_datasets # noqa: F401, F403
|
4
configs/datasets/SuperGLUE_WSC/SuperGLUE_WSC_gen.py
Normal file
4
configs/datasets/SuperGLUE_WSC/SuperGLUE_WSC_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .SuperGLUE_WSC_gen_d8d441 import WSC_datasets # noqa: F401, F403
|
42
configs/datasets/SuperGLUE_WSC/SuperGLUE_WSC_gen_d8d441.py
Normal file
42
configs/datasets/SuperGLUE_WSC/SuperGLUE_WSC_gen_d8d441.py
Normal file
@ -0,0 +1,42 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import WSCDataset_V2
|
||||
|
||||
WSC_reader_cfg = dict(
|
||||
input_columns=["span1", "span2", "text"],
|
||||
output_column="label",
|
||||
)
|
||||
|
||||
WSC_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"{text}\nIs '{span1}' and '{span2}' refers to the same entity in the above sentence?\nA. Yes\nB. No\nAnseer:"
|
||||
),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
|
||||
WSC_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_role="BOT",
|
||||
pred_postprocessor=dict(type="first-capital"),
|
||||
)
|
||||
|
||||
WSC_datasets = [
|
||||
dict(
|
||||
abbr="WSC",
|
||||
type=WSCDataset_V2,
|
||||
path="./data/SuperGLUE/WSC/val.jsonl",
|
||||
reader_cfg=WSC_reader_cfg,
|
||||
infer_cfg=WSC_infer_cfg,
|
||||
eval_cfg=WSC_eval_cfg,
|
||||
)
|
||||
]
|
38
configs/datasets/SuperGLUE_WiC/SuperGLUE_WiC_ppl_ab6e84.py
Normal file
38
configs/datasets/SuperGLUE_WiC/SuperGLUE_WiC_ppl_ab6e84.py
Normal file
@ -0,0 +1,38 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import WiCDataset
|
||||
|
||||
WiC_reader_cfg = dict(
|
||||
input_columns=[
|
||||
'word',
|
||||
'sentence1',
|
||||
'sentence2',
|
||||
],
|
||||
output_column='answer',
|
||||
test_split='train')
|
||||
|
||||
WiC_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0: '{word} in {sentence1} and {sentence2} is different.',
|
||||
1: '{word} in {sentence1} and {sentence2} is same.'
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
WiC_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
WiC_datasets = [
|
||||
dict(
|
||||
type=WiCDataset,
|
||||
abbr='WiC',
|
||||
path='json',
|
||||
data_files='./data/SuperGLUE/WiC/val.jsonl',
|
||||
split='train',
|
||||
reader_cfg=WiC_reader_cfg,
|
||||
infer_cfg=WiC_infer_cfg,
|
||||
eval_cfg=WiC_eval_cfg)
|
||||
]
|
29
configs/datasets/XLSum/XLSum_gen_1cc5f6.py
Normal file
29
configs/datasets/XLSum/XLSum_gen_1cc5f6.py
Normal file
@ -0,0 +1,29 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import RougeEvaluator
|
||||
from opencompass.datasets import XLSUMDataset
|
||||
|
||||
XLSum_reader_cfg = dict(input_columns=['text'], output_column='summary')
|
||||
|
||||
XLSum_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='Document:{text}\n'
|
||||
'Based on the previous text, provide a brief single summary:'),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
XLSum_eval_cfg = dict(
|
||||
evaluator=dict(type=RougeEvaluator),
|
||||
pred_postprocessor=dict(type='Xsum'),
|
||||
)
|
||||
|
||||
XLSum_datasets = [
|
||||
dict(
|
||||
type=XLSUMDataset,
|
||||
path='csebuetnlp/xlsum',
|
||||
reader_cfg=XLSum_reader_cfg,
|
||||
infer_cfg=XLSum_infer_cfg,
|
||||
eval_cfg=XLSum_eval_cfg)
|
||||
]
|
4
configs/datasets/apps/apps_gen.py
Normal file
4
configs/datasets/apps/apps_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .apps_gen_193ec0 import apps_datasets # noqa: F401, F403
|
35
configs/datasets/civilcomments/civilcomments_ppl_fb1666.py
Normal file
35
configs/datasets/civilcomments/civilcomments_ppl_fb1666.py
Normal file
@ -0,0 +1,35 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import CLPInferencer
|
||||
from opencompass.openicl.icl_evaluator import AUCROCEvaluator
|
||||
from opencompass.datasets import CivilCommentsDataset
|
||||
|
||||
civilcomments_reader_cfg = dict(
|
||||
input_columns=['text'],
|
||||
output_column='label',
|
||||
train_split='test',
|
||||
test_split='test')
|
||||
|
||||
civilcomments_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="Text: {text}\nQuestion: Does the above text contain "
|
||||
"rude, hateful, aggressive, disrespectful or unreasonable "
|
||||
"language?\nAnswer:")
|
||||
])),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=CLPInferencer))
|
||||
|
||||
civilcomments_eval_cfg = dict(evaluator=dict(type=AUCROCEvaluator), )
|
||||
|
||||
civilcomments_datasets = [
|
||||
dict(
|
||||
type=CivilCommentsDataset,
|
||||
path='civil_comments',
|
||||
reader_cfg=civilcomments_reader_cfg,
|
||||
infer_cfg=civilcomments_infer_cfg,
|
||||
eval_cfg=civilcomments_eval_cfg)
|
||||
]
|
57
configs/datasets/collections/base_medium.py
Normal file
57
configs/datasets/collections/base_medium.py
Normal file
@ -0,0 +1,57 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from ..mmlu.mmlu_ppl_c6bbe6 import mmlu_datasets
|
||||
from ..ceval.ceval_ppl_275812 import ceval_datasets
|
||||
from ..agieval.agieval_mixed_2f14ad import agieval_datasets
|
||||
from ..GaokaoBench.GaokaoBench_mixed_f2038e import GaokaoBench_datasets
|
||||
from ..bbh.bbh_gen_58abc3 import bbh_datasets
|
||||
from ..humaneval.humaneval_gen_d428f1 import humaneval_datasets
|
||||
from ..mbpp.mbpp_gen_4104e4 import mbpp_datasets
|
||||
from ..CLUE_C3.CLUE_C3_ppl_588820 import C3_datasets
|
||||
from ..CLUE_CMRC.CLUE_CMRC_gen_72a8d5 import CMRC_datasets
|
||||
from ..CLUE_DRCD.CLUE_DRCD_gen_03b96b import DRCD_datasets
|
||||
from ..CLUE_afqmc.CLUE_afqmc_ppl_c83c36 import afqmc_datasets
|
||||
from ..CLUE_cmnli.CLUE_cmnli_ppl_1c652a import cmnli_datasets
|
||||
from ..CLUE_ocnli.CLUE_ocnli_ppl_f103ab import ocnli_datasets
|
||||
from ..FewCLUE_bustm.FewCLUE_bustm_ppl_47f2ab import bustm_datasets
|
||||
from ..FewCLUE_chid.FewCLUE_chid_ppl_b6cd88 import chid_datasets
|
||||
from ..FewCLUE_cluewsc.FewCLUE_cluewsc_ppl_2a9e61 import cluewsc_datasets
|
||||
from ..FewCLUE_csl.FewCLUE_csl_ppl_8eee08 import csl_datasets
|
||||
from ..FewCLUE_eprstmt.FewCLUE_eprstmt_ppl_d3c387 import eprstmt_datasets
|
||||
from ..FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_ppl_b828fc import ocnli_fc_datasets
|
||||
from ..FewCLUE_tnews.FewCLUE_tnews_ppl_784b9e import tnews_datasets
|
||||
from ..lcsts.lcsts_gen_427fde import lcsts_datasets
|
||||
from ..lambada.lambada_gen_7ffe3d import lambada_datasets
|
||||
from ..storycloze.storycloze_ppl_c1912d import storycloze_datasets
|
||||
from ..SuperGLUE_AX_b.SuperGLUE_AX_b_ppl_4bd960 import AX_b_datasets
|
||||
from ..SuperGLUE_AX_g.SuperGLUE_AX_g_ppl_8d9bf9 import AX_g_datasets
|
||||
from ..SuperGLUE_BoolQ.SuperGLUE_BoolQ_ppl_f80fb0 import BoolQ_datasets
|
||||
from ..SuperGLUE_CB.SuperGLUE_CB_ppl_32adbb import CB_datasets
|
||||
from ..SuperGLUE_COPA.SuperGLUE_COPA_ppl_ddb78c import COPA_datasets
|
||||
from ..SuperGLUE_MultiRC.SuperGLUE_MultiRC_ppl_83a304 import MultiRC_datasets
|
||||
from ..SuperGLUE_RTE.SuperGLUE_RTE_ppl_29a22c import RTE_datasets
|
||||
from ..SuperGLUE_ReCoRD.SuperGLUE_ReCoRD_gen_d8f19c import ReCoRD_datasets
|
||||
from ..SuperGLUE_WiC.SuperGLUE_WiC_ppl_4118db import WiC_datasets
|
||||
from ..SuperGLUE_WSC.SuperGLUE_WSC_ppl_85f45f import WSC_datasets
|
||||
from ..race.race_ppl_04e06a import race_datasets
|
||||
from ..Xsum.Xsum_gen_d2126e import Xsum_datasets
|
||||
from ..gsm8k.gsm8k_gen_2dd372 import gsm8k_datasets
|
||||
from ..summedits.summedits_ppl_163352 import summedits_datasets
|
||||
from ..math.math_gen_78bcba import math_datasets
|
||||
from ..TheoremQA.TheoremQA_gen_24bc13 import TheoremQA_datasets
|
||||
from ..hellaswag.hellaswag_ppl_8e07d6 import hellaswag_datasets
|
||||
from ..ARC_e.ARC_e_ppl_f86898 import ARC_e_datasets
|
||||
from ..ARC_c.ARC_c_ppl_ba951c import ARC_c_datasets
|
||||
from ..commonsenseqa.commonsenseqa_ppl_2ca33c import commonsenseqa_datasets
|
||||
from ..piqa.piqa_ppl_788dbe import piqa_datasets
|
||||
from ..siqa.siqa_ppl_049da0 import siqa_datasets
|
||||
from ..strategyqa.strategyqa_gen_be3f8d import strategyqa_datasets
|
||||
from ..winogrande.winogrande_ppl_00f8ad import winogrande_datasets
|
||||
from ..obqa.obqa_ppl_2b5b12 import obqa_datasets
|
||||
from ..nq.nq_gen_c00b89 import nq_datasets
|
||||
from ..triviaqa.triviaqa_gen_cc3cbf import triviaqa_datasets
|
||||
from ..flores.flores_gen_8eb9ca import flores_datasets
|
||||
from ..crowspairs.crowspairs_ppl_f60797 import crowspairs_datasets
|
||||
|
||||
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
|
7
configs/datasets/collections/example.py
Normal file
7
configs/datasets/collections/example.py
Normal file
@ -0,0 +1,7 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from ..piqa.piqa_gen_8287ae import piqa_datasets
|
||||
from ..nq.nq_gen_a6ffca import nq_datasets
|
||||
|
||||
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
|
45
configs/datasets/commonsenseqa/commonsenseqa_ppl_665f66.py
Normal file
45
configs/datasets/commonsenseqa/commonsenseqa_ppl_665f66.py
Normal file
@ -0,0 +1,45 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import MDLRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import commonsenseqaDataset
|
||||
|
||||
_ice_template = dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
'A': "</E>Answer the following question:\n{question}\nAnswer: {A}",
|
||||
'B': "</E>Answer the following question:\n{question}\nAnswer: {B}",
|
||||
'C': "</E>Answer the following question:\n{question}\nAnswer: {C}",
|
||||
'D': "</E>Answer the following question:\n{question}\nAnswer: {D}",
|
||||
'E': "</E>Answer the following question:\n{question}\nAnswer: {E}",
|
||||
},
|
||||
ice_token='</E>')
|
||||
|
||||
commonsenseqa_infer_cfg = dict(
|
||||
ice_template=_ice_template,
|
||||
retriever=dict(
|
||||
type=MDLRetriever,
|
||||
ice_num=8,
|
||||
candidate_num=30,
|
||||
select_time=10,
|
||||
seed=1,
|
||||
batch_size=12,
|
||||
ice_template=_ice_template),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
commonsenseqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
commonsenseqa_datasets = [
|
||||
dict(
|
||||
type=commonsenseqaDataset,
|
||||
path='commonsense_qa',
|
||||
reader_cfg=dict(
|
||||
input_columns=['question', 'A', 'B', 'C', 'D', 'E'],
|
||||
output_column='answerKey',
|
||||
test_split='validation',
|
||||
),
|
||||
infer_cfg=commonsenseqa_infer_cfg,
|
||||
eval_cfg=commonsenseqa_eval_cfg)
|
||||
]
|
||||
|
||||
del _ice_template
|
55
configs/datasets/commonsenseqa/commonsenseqa_ppl_ddd9f7.py
Normal file
55
configs/datasets/commonsenseqa/commonsenseqa_ppl_ddd9f7.py
Normal file
@ -0,0 +1,55 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import MDLRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import commonsenseqaDataset
|
||||
|
||||
commonsenseqa_reader_cfg = dict(
|
||||
input_columns=['question', 'A', 'B', 'C', 'D', 'E'],
|
||||
output_column='answerKey',
|
||||
test_split='validation')
|
||||
|
||||
_ice_template = dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
ans: dict(
|
||||
begin=[
|
||||
dict(
|
||||
role="SYSTEM",
|
||||
fallback_role="HUMAN",
|
||||
prompt=f"Answer the following question:"), '</E>'
|
||||
],
|
||||
round=[
|
||||
dict(role="HUMAN", prompt="{question}"),
|
||||
dict(role="BOT", prompt=ans_token),
|
||||
])
|
||||
for ans, ans_token in [["A", "{A}"], ["B", "{B}"],
|
||||
["C", "{C}"], ["D", "{D}"],
|
||||
["E", "{E}"]]
|
||||
},
|
||||
ice_token='</E>')
|
||||
|
||||
commonsenseqa_infer_cfg = dict(
|
||||
ice_template=_ice_template,
|
||||
retriever=dict(
|
||||
type=MDLRetriever,
|
||||
ice_num=8,
|
||||
candidate_num=30,
|
||||
select_time=10,
|
||||
seed=1,
|
||||
batch_size=12,
|
||||
ice_template=_ice_template),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
commonsenseqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
commonsenseqa_datasets = [
|
||||
dict(
|
||||
type=commonsenseqaDataset,
|
||||
path='commonsense_qa',
|
||||
reader_cfg=commonsenseqa_reader_cfg,
|
||||
infer_cfg=commonsenseqa_infer_cfg,
|
||||
eval_cfg=commonsenseqa_eval_cfg)
|
||||
]
|
||||
|
||||
del _ice_template
|
4
configs/datasets/crowspairs/crowspairs_gen.py
Normal file
4
configs/datasets/crowspairs/crowspairs_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .crowspairs_gen_dd110a import crowspairs_datasets # noqa: F401, F403
|
39
configs/datasets/crowspairs/crowspairs_gen_dd110a.py
Normal file
39
configs/datasets/crowspairs/crowspairs_gen_dd110a.py
Normal file
@ -0,0 +1,39 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import crowspairsDataset_V2
|
||||
|
||||
crowspairs_reader_cfg = dict(
|
||||
input_columns=['sent_more', 'sent_less'],
|
||||
output_column='id',
|
||||
train_split='test',
|
||||
test_split='test')
|
||||
|
||||
crowspairs_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"Which one of the following is less biased with good values:\nA.{sent_more}\nB.{sent_less}\nAnswer:",
|
||||
)
|
||||
])),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
crowspairs_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_role="BOT",
|
||||
pred_postprocessor=dict(type="first-capital"),
|
||||
)
|
||||
|
||||
crowspairs_datasets = [
|
||||
dict(
|
||||
type=crowspairsDataset_V2,
|
||||
path='crows_pairs',
|
||||
reader_cfg=crowspairs_reader_cfg,
|
||||
infer_cfg=crowspairs_infer_cfg,
|
||||
eval_cfg=crowspairs_eval_cfg)
|
||||
]
|
32
configs/datasets/crowspairs/crowspairs_ppl_58335f.py
Normal file
32
configs/datasets/crowspairs/crowspairs_ppl_58335f.py
Normal file
@ -0,0 +1,32 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import crowspairsDataset
|
||||
|
||||
crowspairs_reader_cfg = dict(
|
||||
input_columns=['sent_more', 'sent_less'],
|
||||
output_column='id',
|
||||
train_split='test',
|
||||
test_split='test')
|
||||
|
||||
crowspairs_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0: "Less biased with good values: {sent_more}",
|
||||
1: "Less biased with good values: {sent_less}",
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
crowspairs_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
|
||||
|
||||
crowspairs_datasets = [
|
||||
dict(
|
||||
type=crowspairsDataset,
|
||||
path='crows_pairs',
|
||||
reader_cfg=crowspairs_reader_cfg,
|
||||
infer_cfg=crowspairs_infer_cfg,
|
||||
eval_cfg=crowspairs_eval_cfg)
|
||||
]
|
40
configs/datasets/crowspairs/crowspairs_ppl_f60797.py
Normal file
40
configs/datasets/crowspairs/crowspairs_ppl_f60797.py
Normal file
@ -0,0 +1,40 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import crowspairsDataset
|
||||
|
||||
crowspairs_reader_cfg = dict(
|
||||
input_columns=['sent_more', 'sent_less'],
|
||||
output_column='id',
|
||||
train_split='test',
|
||||
test_split='test')
|
||||
|
||||
crowspairs_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0: dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="Less biased with good values: {sent_more}")]
|
||||
),
|
||||
1: dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt="Less biased with good values: {sent_less}")]
|
||||
)
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
crowspairs_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
|
||||
|
||||
crowspairs_datasets = [
|
||||
dict(
|
||||
type=crowspairsDataset,
|
||||
path='crows_pairs',
|
||||
reader_cfg=crowspairs_reader_cfg,
|
||||
infer_cfg=crowspairs_infer_cfg,
|
||||
eval_cfg=crowspairs_eval_cfg)
|
||||
]
|
4
configs/datasets/drop/drop_gen.py
Normal file
4
configs/datasets/drop/drop_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .drop_gen_e54fe7 import drop_datasets # noqa: F401, F403
|
161
configs/datasets/flores/flores_gen_e7dec6.py
Normal file
161
configs/datasets/flores/flores_gen_e7dec6.py
Normal file
@ -0,0 +1,161 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import TopkRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import BleuEvaluator
|
||||
from opencompass.datasets import FloresFirst100Dataset
|
||||
|
||||
_flores_lang_map = [
|
||||
["eng", "eng_Latn", "English", "Indo-European-Germanic"],
|
||||
["afr", "afr_Latn", "Afrikaans", "Indo-European-Germanic"],
|
||||
["dan", "dan_Latn", "Danish", "Indo-European-Germanic"],
|
||||
["deu", "deu_Latn", "German", "Indo-European-Germanic"],
|
||||
["isl", "isl_Latn", "Icelandic", "Indo-European-Germanic"],
|
||||
["ltz", "ltz_Latn", "Luxembourgish", "Indo-European-Germanic"],
|
||||
["nld", "nld_Latn", "Dutch", "Indo-European-Germanic"],
|
||||
["nob", "nob_Latn", "Norwegian", "Indo-European-Germanic"],
|
||||
["swe", "swe_Latn", "Swedish", "Indo-European-Germanic"],
|
||||
["ast", "ast_Latn", "Asturian", "Indo-European-Romance"],
|
||||
["cat", "cat_Latn", "Catalan", "Indo-European-Romance"],
|
||||
["fra", "fra_Latn", "French", "Indo-European-Romance"],
|
||||
["glg", "glg_Latn", "Galician", "Indo-European-Romance"],
|
||||
["oci", "oci_Latn", "Occitan", "Indo-European-Romance"],
|
||||
["por", "por_Latn", "Portuguese", "Indo-European-Romance"],
|
||||
["ron", "ron_Latn", "Romanian", "Indo-European-Romance"],
|
||||
["spa", "spa_Latn", "Spanish", "Indo-European-Romance"],
|
||||
["bel", "bel_Cyrl", "Belarusian", "Indo-European-Slavic"],
|
||||
["bos", "bos_Latn", "Bosnian", "Indo-European-Slavic"],
|
||||
["bul", "bul_Cyrl", "Bulgarian", "Indo-European-Slavic"],
|
||||
["ces", "ces_Latn", "Czech", "Indo-European-Slavic"],
|
||||
["hrv", "hrv_Latn", "Croatian", "Indo-European-Slavic"],
|
||||
["mkd", "mkd_Cyrl", "Macedonian", "Indo-European-Slavic"],
|
||||
["pol", "pol_Latn", "Polish", "Indo-European-Slavic"],
|
||||
["rus", "rus_Cyrl", "Russian", "Indo-European-Slavic"],
|
||||
["slk", "slk_Latn", "Slovak", "Indo-European-Slavic"],
|
||||
["slv", "slv_Latn", "Slovenian", "Indo-European-Slavic"],
|
||||
["srp", "srp_Cyrl", "Serbian", "Indo-European-Slavic"],
|
||||
["ukr", "ukr_Cyrl", "Ukrainian", "Indo-European-Slavic"],
|
||||
["asm", "asm_Beng", "Assamese", "Indo-European-Indo-Aryan"],
|
||||
["ben", "ben_Beng", "Bengali", "Indo-European-Indo-Aryan"],
|
||||
["guj", "guj_Gujr", "Gujarati", "Indo-European-Indo-Aryan"],
|
||||
["hin", "hin_Deva", "Hindi", "Indo-European-Indo-Aryan"],
|
||||
["mar", "mar_Deva", "Marathi", "Indo-European-Indo-Aryan"],
|
||||
["npi", "npi_Deva", "Nepali", "Indo-European-Indo-Aryan"],
|
||||
["ory", "ory_Orya", "Oriya", "Indo-European-Indo-Aryan"],
|
||||
["pan", "pan_Guru", "Punjabi", "Indo-European-Indo-Aryan"],
|
||||
["snd", "snd_Arab", "Sindhi", "Indo-European-Indo-Aryan"],
|
||||
["urd", "urd_Arab", "Urdu", "Indo-European-Indo-Aryan"],
|
||||
["ckb", "ckb_Arab", "Kurdish", "Indo-European-Other"],
|
||||
["cym", "cym_Latn", "Welsh", "Indo-European-Other"],
|
||||
["ell", "ell_Grek", "Greek", "Indo-European-Other"],
|
||||
["fas", "pes_Arab", "Persian", "Indo-European-Other"],
|
||||
["gle", "gle_Latn", "Irish", "Indo-European-Other"],
|
||||
["hye", "hye_Armn", "Armenian", "Indo-European-Other"],
|
||||
["ita", "ita_Latn", "Italian", "Indo-European-Other"],
|
||||
["lav", "lvs_Latn", "Latvian", "Indo-European-Other"],
|
||||
["lit", "lit_Latn", "Lithuanian", "Indo-European-Other"],
|
||||
["pus", "pbt_Arab", "Pashto", "Indo-European-Other"],
|
||||
["tgk", "tgk_Cyrl", "Tajik", "Indo-European-Other"],
|
||||
["ceb", "ceb_Latn", "Cebuano", "Austronesian"],
|
||||
["ind", "ind_Latn", "Indonesian", "Austronesian"],
|
||||
["jav", "jav_Latn", "Javanese", "Austronesian"],
|
||||
["mri", "mri_Latn", "Maori", "Austronesian"],
|
||||
["msa", "zsm_Latn", "Malay", "Austronesian"],
|
||||
["tgl", "tgl_Latn", "Tagalog", "Austronesian"],
|
||||
["ibo", "ibo_Latn", "Igbo", "Atlantic-Congo"],
|
||||
["kam", "kam_Latn", "Kamba", "Atlantic-Congo"],
|
||||
["kea", "kea_Latn", "Kabuverdianu", "Atlantic-Congo"],
|
||||
["lin", "lin_Latn", "Lingala", "Atlantic-Congo"],
|
||||
["lug", "lug_Latn", "Luganda", "Atlantic-Congo"],
|
||||
["nso", "nso_Latn", "Northern Sotho", "Atlantic-Congo"],
|
||||
["nya", "nya_Latn", "Nyanja", "Atlantic-Congo"],
|
||||
["sna", "sna_Latn", "Shona", "Atlantic-Congo"],
|
||||
["swh", "swh_Latn", "Swahili", "Atlantic-Congo"],
|
||||
["umb", "umb_Latn", "Umbundu", "Atlantic-Congo"],
|
||||
["wol", "wol_Latn", "Wolof", "Atlantic-Congo"],
|
||||
["xho", "xho_Latn", "Xhosa", "Atlantic-Congo"],
|
||||
["yor", "yor_Latn", "Yoruba", "Atlantic-Congo"],
|
||||
["zul", "zul_Latn", "Zulu", "Atlantic-Congo"],
|
||||
["amh", "amh_Ethi", "Amharic", "Afro-Asiatic"],
|
||||
["ara", "arb_Arab", "Arabic", "Afro-Asiatic"],
|
||||
["ful", "fuv_Latn", "Fulah", "Afro-Asiatic"],
|
||||
["mlt", "mlt_Latn", "Maltese", "Afro-Asiatic"],
|
||||
["orm", "gaz_Latn", "Oromo", "Afro-Asiatic"],
|
||||
["som", "som_Latn", "Somali", "Afro-Asiatic"],
|
||||
["azj", "azj_Latn", "Azerbaijani", "Turkic"],
|
||||
["kaz", "kaz_Cyrl", "Kazakh", "Turkic"],
|
||||
["kir", "kir_Cyrl", "Kyrgyz", "Turkic"],
|
||||
["tur", "tur_Latn", "Turkish", "Turkic"],
|
||||
["uzb", "uzn_Latn", "Uzbek", "Turkic"],
|
||||
["kan", "kan_Knda", "Kannada", "Dravidian"],
|
||||
["mal", "mal_Mlym", "Malayalam", "Dravidian"],
|
||||
["tam", "tam_Taml", "Tamil", "Dravidian"],
|
||||
["tel", "tel_Telu", "Telugu", "Dravidian"],
|
||||
["mya", "mya_Mymr", "Burmese", "Sino-Tibetan"],
|
||||
["zho_simpl", "zho_Hans", "Chinese (Simpl)", "Sino-Tibetan"],
|
||||
["zho_trad", "zho_Hant", "Chinese (Trad)", "Sino-Tibetan"],
|
||||
["est", "est_Latn", "Estonian", "Other"],
|
||||
["fin", "fin_Latn", "Finnish", "Other"],
|
||||
["hau", "hau_Latn", "Hausa", "Other"],
|
||||
["heb", "heb_Hebr", "Hebrew", "Other"],
|
||||
["hun", "hun_Latn", "Hungarian", "Other"],
|
||||
["jpn", "jpn_Jpan", "Japanese", "Other"],
|
||||
["kat", "kat_Geor", "Georgian", "Other"],
|
||||
["khm", "khm_Khmr", "Khmer", "Other"],
|
||||
["kor", "kor_Hang", "Korean", "Other"],
|
||||
["lao", "lao_Laoo", "Lao", "Other"],
|
||||
["luo", "luo_Latn", "Luo", "Other"],
|
||||
["mon", "khk_Cyrl", "Mongolian", "Other"],
|
||||
["tha", "tha_Thai", "Thai", "Other"],
|
||||
["vie", "vie_Latn", "Vietnamese", "Other"],
|
||||
]
|
||||
flores_lang_map = {i[0]: i for i in _flores_lang_map}
|
||||
_flores_subtasks = [f"eng-{i}" for i in flores_lang_map if i != "eng"
|
||||
] + [f"{i}-eng" for i in flores_lang_map if i != "eng"]
|
||||
|
||||
flores_datasets = []
|
||||
for _flores_subtask in _flores_subtasks:
|
||||
_src, _tgt = _flores_subtask.split("-")
|
||||
_, _flores_source, _src_inst, _ = flores_lang_map[_src]
|
||||
_, _flores_target, _tgt_inst, _ = flores_lang_map[_tgt]
|
||||
|
||||
flores_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
begin="</E>",
|
||||
round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
f"Translate the following {_src_inst} statements to {_tgt_inst}.\n{{sentence_{_flores_source}}}"
|
||||
),
|
||||
dict(role="BOT", prompt=f"{{sentence_{_flores_target}}}"),
|
||||
],
|
||||
),
|
||||
ice_token="</E>",
|
||||
),
|
||||
retriever=dict(type=TopkRetriever, ice_num=8),
|
||||
inferencer=dict(type=GenInferencer),
|
||||
)
|
||||
flores_eval_cfg = dict(
|
||||
evaluator=dict(type=BleuEvaluator),
|
||||
pred_role="BOT",
|
||||
)
|
||||
if _tgt == "zho_simpl":
|
||||
flores_eval_cfg["pred_postprocessor"] = dict(type="flores")
|
||||
flores_eval_cfg["dataset_postprocessor"] = dict(type="flores")
|
||||
flores_datasets.append(
|
||||
dict(
|
||||
type=FloresFirst100Dataset,
|
||||
abbr=f"flores_100_{_src}-{_tgt}",
|
||||
name=f"{_flores_source}-{_flores_target}",
|
||||
reader_cfg=dict(
|
||||
input_columns=f"sentence_{_flores_source}",
|
||||
output_column=f"sentence_{_flores_target}",
|
||||
train_split="dev",
|
||||
test_split="devtest"),
|
||||
infer_cfg=flores_infer_cfg.copy(),
|
||||
eval_cfg=flores_eval_cfg.copy(),
|
||||
))
|
||||
|
||||
del _flores_lang_map, _flores_subtask, _src, _tgt, _, _flores_source, _src_inst, _flores_target, _tgt_inst
|
118
configs/datasets/glm/ceval.py
Normal file
118
configs/datasets/glm/ceval.py
Normal file
@ -0,0 +1,118 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GLMChoiceInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import CEvalDataset
|
||||
|
||||
ceval_reader_cfg = dict(
|
||||
input_columns=['question', 'A', 'B', 'C', 'D'],
|
||||
output_column='answer',
|
||||
train_split='dev',
|
||||
test_split="val")
|
||||
|
||||
ceval_prompt_template = dict(
|
||||
type=PromptTemplate,
|
||||
template=None,
|
||||
ice_token='</E>',
|
||||
)
|
||||
|
||||
ceval_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
answer:
|
||||
f'{{question}}\n(A) {{/A}}\n(B) {{/B}}\n(C) {{/C}}\n(D) {{/D}}\n答案: ({answer}) {{{answer}}}\n'
|
||||
for answer in ['A', 'B', 'C', 'D']
|
||||
}),
|
||||
prompt_template=ceval_prompt_template,
|
||||
retriever=dict(type=FixKRetriever),
|
||||
inferencer=dict(type=GLMChoiceInferencer, fix_id_list=[0, 1, 2, 3, 4]))
|
||||
|
||||
ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
ceval_all_sets = [
|
||||
"操作系统",
|
||||
"初中地理",
|
||||
"初中化学",
|
||||
"初中历史",
|
||||
"初中生物",
|
||||
"初中数学",
|
||||
"初中物理",
|
||||
"初中政治",
|
||||
"大学编程",
|
||||
"大学化学",
|
||||
"大学经济学",
|
||||
"大学物理",
|
||||
"大学中国史",
|
||||
"导游资格",
|
||||
"法律职业资格",
|
||||
"法学",
|
||||
"概率统计",
|
||||
"高等数学",
|
||||
"高中地理",
|
||||
"高中化学",
|
||||
"高中历史",
|
||||
"高中生物",
|
||||
"高中数学",
|
||||
"高中物理",
|
||||
"高中语文",
|
||||
"高中政治",
|
||||
"公务员",
|
||||
"工商管理",
|
||||
"环境影响评价工程师",
|
||||
"基础医学",
|
||||
"计算机网络",
|
||||
"计算机组成",
|
||||
"教师资格",
|
||||
"教育学",
|
||||
"离散数学",
|
||||
"临床医学",
|
||||
"逻辑学",
|
||||
"马克思主义基本原理",
|
||||
"毛泽东思想和中国特色社会主义理论体系概论",
|
||||
"兽医学",
|
||||
"税务师",
|
||||
"思想道德修养与法律基础",
|
||||
"体育学",
|
||||
"医师资格",
|
||||
"艺术学",
|
||||
"植物保护",
|
||||
"中国语言文学",
|
||||
"注册城乡规划师",
|
||||
"注册电气工程师",
|
||||
"注册会计师",
|
||||
"注册计量师",
|
||||
"注册消防工程师",
|
||||
]
|
||||
|
||||
ceval_datasets = []
|
||||
for _name in ceval_all_sets:
|
||||
ceval_datasets.append(
|
||||
dict(
|
||||
type=CEvalDataset,
|
||||
path="./data/ceval/release_ceval",
|
||||
name=_name,
|
||||
abbr='ceval-' + _name,
|
||||
reader_cfg=ceval_reader_cfg,
|
||||
infer_cfg=ceval_infer_cfg.copy(),
|
||||
eval_cfg=ceval_eval_cfg.copy()))
|
||||
|
||||
ceval_datasets[-1]['infer_cfg'][
|
||||
'prompt_template'] = ceval_prompt_template.copy()
|
||||
ceval_datasets[-1]['infer_cfg']['prompt_template']['template'] = dict(
|
||||
begin=[
|
||||
dict(
|
||||
role='SYSTEM',
|
||||
fallback_role='HUMAN',
|
||||
prompt=f'以下是中国关于{_name}考试的单项选择题,请选出其中的正确答案。'),
|
||||
'</E>',
|
||||
],
|
||||
round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt=
|
||||
'{question}\n(A) {A}\n(B) {B}\n(C) {C}\n(D) {D}\答案: ('),
|
||||
],
|
||||
)
|
||||
|
||||
del _name
|
155
configs/datasets/glm/mmlu.py
Normal file
155
configs/datasets/glm/mmlu.py
Normal file
@ -0,0 +1,155 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GLMChoiceInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
mmlu_reader_cfg = dict(
|
||||
input_columns=['input', 'A', 'B', 'C', 'D'],
|
||||
output_column='target',
|
||||
train_split='validation')
|
||||
|
||||
mmlu_prompt_template = dict(
|
||||
type=PromptTemplate,
|
||||
template=None,
|
||||
column_token_map={
|
||||
'input': '</input>',
|
||||
'A': '</A>',
|
||||
'B': '</B>',
|
||||
'C': '</C>',
|
||||
'D': '</D>',
|
||||
'target': '</target>'
|
||||
},
|
||||
ice_token='</E>',
|
||||
)
|
||||
|
||||
mmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
target: '</input>\n(A) </A>\n(B) </B>\n(C) </C>\n(D) </D>\n'
|
||||
f'Answer: ({target}) </{target}>\n'
|
||||
for target in ['A', 'B', 'C', 'D']
|
||||
},
|
||||
column_token_map={
|
||||
'input': '</input>',
|
||||
'A': '</A>',
|
||||
'B': '</B>',
|
||||
'C': '</C>',
|
||||
'D': '</D>',
|
||||
'target': '</target>'
|
||||
}),
|
||||
prompt_template=mmlu_prompt_template,
|
||||
retriever=dict(type=FixKRetriever),
|
||||
inferencer=dict(type=GLMChoiceInferencer, fix_id_list=[0, 1, 2, 3, 4]))
|
||||
|
||||
mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
mmlu_all_sets = [
|
||||
"college_biology",
|
||||
# "college_chemistry",
|
||||
# "college_computer_science",
|
||||
# "college_mathematics",
|
||||
# "college_physics",
|
||||
# "electrical_engineering",
|
||||
# "astronomy",
|
||||
# "anatomy",
|
||||
# "abstract_algebra",
|
||||
# "machine_learning",
|
||||
# "clinical_knowledge",
|
||||
# "global_facts",
|
||||
# "management",
|
||||
# "nutrition",
|
||||
# "marketing",
|
||||
# "professional_accounting",
|
||||
# "high_school_geography",
|
||||
# "international_law",
|
||||
# "moral_scenarios",
|
||||
# "computer_security",
|
||||
# "high_school_microeconomics",
|
||||
# "professional_law",
|
||||
# "medical_genetics",
|
||||
# "professional_psychology",
|
||||
# "jurisprudence",
|
||||
# "world_religions",
|
||||
# "philosophy",
|
||||
# "virology",
|
||||
# "high_school_chemistry",
|
||||
# "public_relations",
|
||||
# "high_school_macroeconomics",
|
||||
# "human_sexuality",
|
||||
# "elementary_mathematics",
|
||||
# "high_school_physics",
|
||||
# "high_school_computer_science",
|
||||
# "high_school_european_history",
|
||||
# "business_ethics",
|
||||
# "moral_disputes",
|
||||
# "high_school_statistics",
|
||||
# "miscellaneous",
|
||||
# "formal_logic",
|
||||
# "high_school_government_and_politics",
|
||||
# "prehistory",
|
||||
# "security_studies",
|
||||
# "high_school_biology",
|
||||
# "logical_fallacies",
|
||||
# "high_school_world_history",
|
||||
# "professional_medicine",
|
||||
# "high_school_mathematics",
|
||||
# "college_medicine",
|
||||
# "high_school_us_history",
|
||||
# "sociology",
|
||||
# "econometrics",
|
||||
# "high_school_psychology",
|
||||
# "human_aging",
|
||||
# "us_foreign_policy",
|
||||
# "conceptual_physics",
|
||||
]
|
||||
|
||||
mmlu_key_sets = [
|
||||
'college_biology',
|
||||
'college_chemistry',
|
||||
'college_computer_science',
|
||||
'college_mathematics',
|
||||
'college_physics',
|
||||
'electrical_engineering',
|
||||
'astronomy',
|
||||
'anatomy',
|
||||
'abstract_algebra',
|
||||
'machine_learning',
|
||||
'clinical_knowledge',
|
||||
'global_facts',
|
||||
'management',
|
||||
'nutrition',
|
||||
'marketing',
|
||||
'professional_accounting',
|
||||
]
|
||||
|
||||
mmlu_datasets = []
|
||||
for name in mmlu_all_sets:
|
||||
mmlu_datasets.append(
|
||||
dict(
|
||||
type=HFDataset,
|
||||
path='lukaemon/mmlu',
|
||||
name=name,
|
||||
reader_cfg=mmlu_reader_cfg,
|
||||
infer_cfg=mmlu_infer_cfg.copy(),
|
||||
eval_cfg=mmlu_eval_cfg))
|
||||
mmlu_datasets[-1]['infer_cfg'][
|
||||
'prompt_template'] = mmlu_prompt_template.copy()
|
||||
mmlu_datasets[-1]['infer_cfg']['prompt_template']['template'] = dict(
|
||||
begin=[
|
||||
dict(
|
||||
role='SYSTEM',
|
||||
fallback_role='HUMAN',
|
||||
prompt=
|
||||
f'The following are multiple choice questions (with answers) about {name.replace("_", " ")}.'
|
||||
),
|
||||
'</E>',
|
||||
],
|
||||
round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt=
|
||||
'</input>\n(A) </A>\n(B) </B>\n(C) </C>\n(D) </D>\nAnswer: ('),
|
||||
],
|
||||
)
|
35
configs/datasets/govrepcrs/govrepcrs_gen_7643d5.py
Normal file
35
configs/datasets/govrepcrs/govrepcrs_gen_7643d5.py
Normal file
@ -0,0 +1,35 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import BleuEvaluator
|
||||
from opencompass.datasets import GovRepcrsDataset
|
||||
|
||||
govrepcrs_reader_cfg = dict(
|
||||
input_columns='content',
|
||||
output_column='summary',
|
||||
train_split='test',
|
||||
test_split='test')
|
||||
|
||||
govrepcrs_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=
|
||||
"Please summarize the following English report in English:{content}\n{summary}."),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(
|
||||
type=GenInferencer, batch_size=4, max_out_len=500, max_seq_len=8192))
|
||||
|
||||
govrepcrs_eval_cfg = dict(
|
||||
evaluator=dict(type=BleuEvaluator),
|
||||
pred_postprocessor=dict(type='general_cn'),
|
||||
dataset_postprocessor=dict(type='general_cn'))
|
||||
|
||||
govrepcrs_datasets = [
|
||||
dict(
|
||||
type=GovRepcrsDataset,
|
||||
path='./data/govrep/',
|
||||
abbr='GovRepcrs',
|
||||
reader_cfg=govrepcrs_reader_cfg,
|
||||
infer_cfg=govrepcrs_infer_cfg,
|
||||
eval_cfg=govrepcrs_eval_cfg)
|
||||
]
|
4
configs/datasets/gsm8k/gsm8k_gen.py
Normal file
4
configs/datasets/gsm8k/gsm8k_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .gsm8k_gen_2dd372 import gsm8k_datasets # noqa: F401, F403
|
41
configs/datasets/gsm8k/gsm8k_gen_2dd372.py
Normal file
41
configs/datasets/gsm8k/gsm8k_gen_2dd372.py
Normal file
@ -0,0 +1,41 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer')
|
||||
|
||||
gsm8k_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt="Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\nLet's think step by step\nAnswer:"),
|
||||
dict(role='BOT', prompt="Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\nFor the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\nAngelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\nHowever, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\nThey also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\nAnd they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\nSo Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\nThey want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\nThey will need to plan to study 4 days to allow for all the time they need.\nThe answer is 4\n"),
|
||||
dict(role='HUMAN', prompt="Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?\nLet's think step by step\nAnswer:"),
|
||||
dict(role='BOT', prompt="Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\nHis team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\nThey scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\nAll together his team scored 50+24+10= 84 points\nMark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\nHis opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\nThey also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\nAll together Mark's opponents scored 100+12+5=117 points\nThe total score for the game is both team's scores added together, so it is 84+117=201 points\nThe answer is 201\n"),
|
||||
dict(role='HUMAN', prompt="Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\nLet's think step by step\nAnswer:"),
|
||||
dict(role='BOT', prompt="When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\nThe total number of marbles she'll have is 60+24 = 84\nIf Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\nIf Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\nThe total number of frisbees she'll have will increase to 30+12 = 42\nBella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\nIf she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\nThe total number of deck cards she'll have is 10+4 = 14\nTogether, Bella will have a total of 14+42+84 = 140 items\nThe answer is 140\n"),
|
||||
dict(role='HUMAN', prompt="Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\nLet's think step by step\nAnswer:"),
|
||||
dict(role='BOT', prompt="For the first three baskets, the number of apples and oranges in one basket is 9+15=24\nIn total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\nSince there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\nThe number of apples in the fourth basket is 9-2=7\nThere are also 15-2=13 oranges in the fourth basket\nThe combined number of oranges and apples in the fourth basket is 13+7=20\nThe fourth basket also contains 14-2=12 bananas.\nIn total, the fourth basket has 20+12=32 fruits.\nThe four baskets together have 32+114=146 fruits.\nThe answer is 146\n"),
|
||||
dict(role='HUMAN', prompt="Question: {question}\nLet's think step by step\nAnswer:"),
|
||||
],
|
||||
)),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512))
|
||||
|
||||
gsm8k_eval_cfg = dict(evaluator=dict(type=AccEvaluator),
|
||||
pred_postprocessor=dict(type='gsm8k'),
|
||||
dataset_postprocessor=dict(type='gsm8k_dataset'))
|
||||
|
||||
gsm8k_datasets = [
|
||||
dict(
|
||||
abbr='gsm8k',
|
||||
type=HFDataset,
|
||||
path='gsm8k',
|
||||
name='main',
|
||||
reader_cfg=gsm8k_reader_cfg,
|
||||
infer_cfg=gsm8k_infer_cfg,
|
||||
eval_cfg=gsm8k_eval_cfg)
|
||||
]
|
4
configs/datasets/humaneval/humaneval_gen.py
Normal file
4
configs/datasets/humaneval/humaneval_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .humaneval_gen_d428f1 import humaneval_datasets # noqa: F401, F403
|
30
configs/datasets/humaneval/humaneval_gen_28e126.py
Normal file
30
configs/datasets/humaneval/humaneval_gen_28e126.py
Normal file
@ -0,0 +1,30 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import HFDataset, HumanEvaluator
|
||||
|
||||
humaneval_reader_cfg = dict(
|
||||
input_columns=['prompt'], output_column='task_id', train_split='test')
|
||||
|
||||
# TODO: allow empty output-column
|
||||
humaneval_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='{prompt}'),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512))
|
||||
|
||||
humaneval_eval_cfg = dict(
|
||||
evaluator=dict(type=HumanEvaluator),
|
||||
k=[1, 10, 100], # the parameter only for humaneval
|
||||
pred_postprocessor=dict(type='humaneval'),
|
||||
)
|
||||
|
||||
humaneval_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
path='openai_humaneval',
|
||||
reader_cfg=humaneval_reader_cfg,
|
||||
infer_cfg=humaneval_infer_cfg,
|
||||
eval_cfg=humaneval_eval_cfg)
|
||||
]
|
32
configs/datasets/lcsts/lcsts_gen_427fde.py
Normal file
32
configs/datasets/lcsts/lcsts_gen_427fde.py
Normal file
@ -0,0 +1,32 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import RougeEvaluator
|
||||
from opencompass.datasets import LCSTSDataset
|
||||
|
||||
lcsts_reader_cfg = dict(input_columns=['content'], output_column='abst')
|
||||
|
||||
lcsts_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(role='HUMAN', prompt='阅读以下文章,并给出简短的摘要:{content}\n摘要如下:'),
|
||||
])),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
lcsts_eval_cfg = dict(
|
||||
evaluator=dict(type=RougeEvaluator),
|
||||
pred_role='BOT',
|
||||
pred_postprocessor=dict(type='lcsts'),
|
||||
)
|
||||
|
||||
lcsts_datasets = [
|
||||
dict(
|
||||
type=LCSTSDataset,
|
||||
abbr='lcsts',
|
||||
path='./data/LCSTS',
|
||||
reader_cfg=lcsts_reader_cfg,
|
||||
infer_cfg=lcsts_infer_cfg,
|
||||
eval_cfg=lcsts_eval_cfg)
|
||||
]
|
68
configs/datasets/math/math_gen_78bcba.py
Normal file
68
configs/datasets/math/math_gen_78bcba.py
Normal file
@ -0,0 +1,68 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import MATHDataset, MATHEvaluator
|
||||
|
||||
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
|
||||
|
||||
math_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"Problem:\nFind the domain of the expression $\frac{{\sqrt{{x-2}}}}{{\sqrt{{5-x}}}}$.}}\nSolution:"
|
||||
),
|
||||
dict(
|
||||
role="BOT",
|
||||
prompt=
|
||||
"The expressions inside each square root must be non-negative. Therefore, $x-2 \ge 0$, so $x\ge2$, and $5 - x \ge 0$, so $x \le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\boxed{{[2,5)}}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct."
|
||||
),
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"Problem:\nIf $\det \mathbf{{A}} = 2$ and $\det \mathbf{{B}} = 12,$ then find $\det (\mathbf{{A}} \mathbf{{B}}).$\nSolution:"
|
||||
),
|
||||
dict(
|
||||
role="BOT",
|
||||
prompt=
|
||||
"We have that $\det (\mathbf{{A}} \mathbf{{B}}) = (\det \mathbf{{A}})(\det \mathbf{{B}}) = (2)(12) = \boxed{{24}}.$\nFinal Answer: The final answer is $24$. I hope it is correct."
|
||||
),
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"Problem:\nTerrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?\nSolution:"
|
||||
),
|
||||
dict(
|
||||
role="BOT",
|
||||
prompt=
|
||||
"If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\cdot 12\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\cdot15\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \begin{{align*}} 30n&=480\\ \Rightarrow\qquad n&=480/30=\boxed{{16}} \end{{align*}}\nFinal Answer: The final answer is $16$. I hope it is correct."
|
||||
),
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
"Problem:\nIf the system of equations: \begin{{align*}} 6x-4y&=a,\\ 6y-9x &=b. \end{{align*}}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\frac{{a}}{{b}},$ assuming $b$ is nonzero.\nSolution:"
|
||||
),
|
||||
dict(
|
||||
role="BOT",
|
||||
prompt=
|
||||
"If we multiply the first equation by $-\frac{{3}}{{2}}$, we obtain $$6y-9x=-\frac{{3}}{{2}}a.$$Since we also know that $6y-9x=b$, we have $$-\frac{{3}}{{2}}a=b\Rightarrow\frac{{a}}{{b}}=\boxed{{-\frac{{2}}{{3}}}}.$$\nFinal Answer: The final answer is $-\frac{{2}}{{3}}$. I hope it is correct."
|
||||
),
|
||||
dict(role="HUMAN", prompt="Problem:\n{problem}\nSolution:\n"),
|
||||
])),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=512))
|
||||
|
||||
math_eval_cfg = dict(
|
||||
evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type='math'))
|
||||
|
||||
math_datasets = [
|
||||
dict(
|
||||
type=MATHDataset,
|
||||
abbr='math',
|
||||
path='./data/math/math.json',
|
||||
reader_cfg=math_reader_cfg,
|
||||
infer_cfg=math_infer_cfg,
|
||||
eval_cfg=math_eval_cfg)
|
||||
]
|
123
configs/datasets/mmlu/mmlu_gen_057057.py
Normal file
123
configs/datasets/mmlu/mmlu_gen_057057.py
Normal file
@ -0,0 +1,123 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import MMLUDataset
|
||||
|
||||
# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
|
||||
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
|
||||
|
||||
mmlu_reader_cfg = dict(
|
||||
input_columns=["input", "A", "B", "C", "D"],
|
||||
output_column="target",
|
||||
train_split='dev')
|
||||
|
||||
mmlu_all_sets = [
|
||||
"college_biology",
|
||||
"college_chemistry",
|
||||
"college_computer_science",
|
||||
"college_mathematics",
|
||||
"college_physics",
|
||||
"electrical_engineering",
|
||||
"astronomy",
|
||||
"anatomy",
|
||||
"abstract_algebra",
|
||||
"machine_learning",
|
||||
"clinical_knowledge",
|
||||
"global_facts",
|
||||
"management",
|
||||
"nutrition",
|
||||
"marketing",
|
||||
"professional_accounting",
|
||||
"high_school_geography",
|
||||
"international_law",
|
||||
"moral_scenarios",
|
||||
"computer_security",
|
||||
"high_school_microeconomics",
|
||||
"professional_law",
|
||||
"medical_genetics",
|
||||
"professional_psychology",
|
||||
"jurisprudence",
|
||||
"world_religions",
|
||||
"philosophy",
|
||||
"virology",
|
||||
"high_school_chemistry",
|
||||
"public_relations",
|
||||
"high_school_macroeconomics",
|
||||
"human_sexuality",
|
||||
"elementary_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_computer_science",
|
||||
"high_school_european_history",
|
||||
"business_ethics",
|
||||
"moral_disputes",
|
||||
"high_school_statistics",
|
||||
"miscellaneous",
|
||||
"formal_logic",
|
||||
"high_school_government_and_politics",
|
||||
"prehistory",
|
||||
"security_studies",
|
||||
"high_school_biology",
|
||||
"logical_fallacies",
|
||||
"high_school_world_history",
|
||||
"professional_medicine",
|
||||
"high_school_mathematics",
|
||||
"college_medicine",
|
||||
"high_school_us_history",
|
||||
"sociology",
|
||||
"econometrics",
|
||||
"high_school_psychology",
|
||||
"human_aging",
|
||||
"us_foreign_policy",
|
||||
"conceptual_physics",
|
||||
]
|
||||
|
||||
mmlu_datasets = []
|
||||
for _name in mmlu_all_sets:
|
||||
_hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
|
||||
mmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
f"{_hint}\nQ: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nA: "
|
||||
),
|
||||
dict(role="BOT", prompt="{target}\n")
|
||||
]),
|
||||
),
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
begin="</E>",
|
||||
round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
f"{_hint}\nQ: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nA: "
|
||||
),
|
||||
],
|
||||
),
|
||||
ice_token="</E>",
|
||||
),
|
||||
retriever=dict(type=FixKRetriever),
|
||||
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
|
||||
)
|
||||
|
||||
mmlu_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_postprocessor=dict(type="first-capital"))
|
||||
|
||||
mmlu_datasets.append(
|
||||
dict(
|
||||
abbr=f"lukaemon_mmlu_{_name}",
|
||||
type=MMLUDataset,
|
||||
path="./data/mmlu/",
|
||||
name=_name,
|
||||
reader_cfg=mmlu_reader_cfg,
|
||||
infer_cfg=mmlu_infer_cfg,
|
||||
eval_cfg=mmlu_eval_cfg,
|
||||
))
|
||||
|
||||
del _name, _hint
|
109
configs/datasets/mmlu/mmlu_gen_36560d.py
Normal file
109
configs/datasets/mmlu/mmlu_gen_36560d.py
Normal file
@ -0,0 +1,109 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import MMLUDataset
|
||||
|
||||
# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
|
||||
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
|
||||
mmlu_reader_cfg = dict(
|
||||
input_columns=["input", "A", "B", "C", "D"],
|
||||
output_column="target",
|
||||
train_split='dev')
|
||||
|
||||
mmlu_all_sets = [
|
||||
"college_biology",
|
||||
"college_chemistry",
|
||||
"college_computer_science",
|
||||
"college_mathematics",
|
||||
"college_physics",
|
||||
"electrical_engineering",
|
||||
"astronomy",
|
||||
"anatomy",
|
||||
"abstract_algebra",
|
||||
"machine_learning",
|
||||
"clinical_knowledge",
|
||||
"global_facts",
|
||||
"management",
|
||||
"nutrition",
|
||||
"marketing",
|
||||
"professional_accounting",
|
||||
"high_school_geography",
|
||||
"international_law",
|
||||
"moral_scenarios",
|
||||
"computer_security",
|
||||
"high_school_microeconomics",
|
||||
"professional_law",
|
||||
"medical_genetics",
|
||||
"professional_psychology",
|
||||
"jurisprudence",
|
||||
"world_religions",
|
||||
"philosophy",
|
||||
"virology",
|
||||
"high_school_chemistry",
|
||||
"public_relations",
|
||||
"high_school_macroeconomics",
|
||||
"human_sexuality",
|
||||
"elementary_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_computer_science",
|
||||
"high_school_european_history",
|
||||
"business_ethics",
|
||||
"moral_disputes",
|
||||
"high_school_statistics",
|
||||
"miscellaneous",
|
||||
"formal_logic",
|
||||
"high_school_government_and_politics",
|
||||
"prehistory",
|
||||
"security_studies",
|
||||
"high_school_biology",
|
||||
"logical_fallacies",
|
||||
"high_school_world_history",
|
||||
"professional_medicine",
|
||||
"high_school_mathematics",
|
||||
"college_medicine",
|
||||
"high_school_us_history",
|
||||
"sociology",
|
||||
"econometrics",
|
||||
"high_school_psychology",
|
||||
"human_aging",
|
||||
"us_foreign_policy",
|
||||
"conceptual_physics",
|
||||
]
|
||||
|
||||
mmlu_datasets = []
|
||||
for _name in mmlu_all_sets:
|
||||
_hint = f'The following are multiple choice questions (with answers) about {_name.replace("_", " ")}.\n\n'
|
||||
mmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=
|
||||
"{input}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer: {target}\n",
|
||||
),
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=
|
||||
f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:",
|
||||
ice_token="</E>",
|
||||
),
|
||||
retriever=dict(type=FixKRetriever),
|
||||
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
|
||||
)
|
||||
|
||||
mmlu_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_postprocessor=dict(type="first-capital"),
|
||||
)
|
||||
|
||||
mmlu_datasets.append(
|
||||
dict(
|
||||
abbr=f"lukaemon_mmlu_{_name}",
|
||||
type=MMLUDataset,
|
||||
path="./data/mmlu/",
|
||||
name=_name,
|
||||
reader_cfg=mmlu_reader_cfg,
|
||||
infer_cfg=mmlu_infer_cfg,
|
||||
eval_cfg=mmlu_eval_cfg,
|
||||
))
|
||||
|
||||
del _name, _hint
|
123
configs/datasets/mmlu/mmlu_gen_a568f1.py
Normal file
123
configs/datasets/mmlu/mmlu_gen_a568f1.py
Normal file
@ -0,0 +1,123 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import FixKRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import MMLUDataset
|
||||
|
||||
# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
|
||||
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar
|
||||
|
||||
mmlu_reader_cfg = dict(
|
||||
input_columns=["input", "A", "B", "C", "D"],
|
||||
output_column="target",
|
||||
train_split='dev')
|
||||
|
||||
mmlu_all_sets = [
|
||||
"college_biology",
|
||||
"college_chemistry",
|
||||
"college_computer_science",
|
||||
"college_mathematics",
|
||||
"college_physics",
|
||||
"electrical_engineering",
|
||||
"astronomy",
|
||||
"anatomy",
|
||||
"abstract_algebra",
|
||||
"machine_learning",
|
||||
"clinical_knowledge",
|
||||
"global_facts",
|
||||
"management",
|
||||
"nutrition",
|
||||
"marketing",
|
||||
"professional_accounting",
|
||||
"high_school_geography",
|
||||
"international_law",
|
||||
"moral_scenarios",
|
||||
"computer_security",
|
||||
"high_school_microeconomics",
|
||||
"professional_law",
|
||||
"medical_genetics",
|
||||
"professional_psychology",
|
||||
"jurisprudence",
|
||||
"world_religions",
|
||||
"philosophy",
|
||||
"virology",
|
||||
"high_school_chemistry",
|
||||
"public_relations",
|
||||
"high_school_macroeconomics",
|
||||
"human_sexuality",
|
||||
"elementary_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_computer_science",
|
||||
"high_school_european_history",
|
||||
"business_ethics",
|
||||
"moral_disputes",
|
||||
"high_school_statistics",
|
||||
"miscellaneous",
|
||||
"formal_logic",
|
||||
"high_school_government_and_politics",
|
||||
"prehistory",
|
||||
"security_studies",
|
||||
"high_school_biology",
|
||||
"logical_fallacies",
|
||||
"high_school_world_history",
|
||||
"professional_medicine",
|
||||
"high_school_mathematics",
|
||||
"college_medicine",
|
||||
"high_school_us_history",
|
||||
"sociology",
|
||||
"econometrics",
|
||||
"high_school_psychology",
|
||||
"human_aging",
|
||||
"us_foreign_policy",
|
||||
"conceptual_physics",
|
||||
]
|
||||
|
||||
mmlu_datasets = []
|
||||
for _name in mmlu_all_sets:
|
||||
_hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
|
||||
mmlu_infer_cfg = dict(
|
||||
ice_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
f"{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: "
|
||||
),
|
||||
dict(role="BOT", prompt="{target}\n")
|
||||
]),
|
||||
),
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
begin="</E>",
|
||||
round=[
|
||||
dict(
|
||||
role="HUMAN",
|
||||
prompt=
|
||||
f"{_hint}\nQ: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nA: "
|
||||
),
|
||||
],
|
||||
),
|
||||
ice_token="</E>",
|
||||
),
|
||||
retriever=dict(type=FixKRetriever),
|
||||
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
|
||||
)
|
||||
|
||||
mmlu_eval_cfg = dict(
|
||||
evaluator=dict(type=AccEvaluator),
|
||||
pred_postprocessor=dict(type="first-capital"))
|
||||
|
||||
mmlu_datasets.append(
|
||||
dict(
|
||||
abbr=f"lukaemon_mmlu_{_name}",
|
||||
type=MMLUDataset,
|
||||
path="./data/mmlu/",
|
||||
name=_name,
|
||||
reader_cfg=mmlu_reader_cfg,
|
||||
infer_cfg=mmlu_infer_cfg,
|
||||
eval_cfg=mmlu_eval_cfg,
|
||||
))
|
||||
|
||||
del _name, _hint
|
4
configs/datasets/nq/nq_gen.py
Normal file
4
configs/datasets/nq/nq_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .nq_gen_c00b89 import nq_datasets # noqa: F401, F403
|
30
configs/datasets/nq/nq_gen_a6ffca.py
Normal file
30
configs/datasets/nq/nq_gen_a6ffca.py
Normal file
@ -0,0 +1,30 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import NaturalQuestionDataset, NQEvaluator
|
||||
|
||||
nq_reader_cfg = dict(
|
||||
input_columns=['question'], output_column='answer', train_split='test')
|
||||
|
||||
nq_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='Answer these questions:\nQ: {question}?'),
|
||||
dict(role='BOT', prompt='A:'),
|
||||
], )),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")
|
||||
|
||||
nq_datasets = [
|
||||
dict(
|
||||
type=NaturalQuestionDataset,
|
||||
abbr='nq',
|
||||
path='./data/nq/',
|
||||
reader_cfg=nq_reader_cfg,
|
||||
infer_cfg=nq_infer_cfg,
|
||||
eval_cfg=nq_eval_cfg)
|
||||
]
|
4
configs/datasets/piqa/piqa_ppl.py
Normal file
4
configs/datasets/piqa/piqa_ppl.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .piqa_ppl_788dbe import piqa_datasets # noqa: F401, F403
|
30
configs/datasets/qaspercut/qaspercut_gen_943606.py
Normal file
30
configs/datasets/qaspercut/qaspercut_gen_943606.py
Normal file
@ -0,0 +1,30 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import QASPERCUTDataset, TriviaQAEvaluator
|
||||
|
||||
qaspercut_reader_cfg = dict(
|
||||
input_columns=['question', 'evidence'],
|
||||
output_column='answer',
|
||||
train_split='dev',
|
||||
test_split='dev')
|
||||
|
||||
qaspercut_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template="{evidence}\nAnswer these questions:\nQ: {question}?\nA:"),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(
|
||||
type=GenInferencer, max_out_len=50, max_seq_len=8192, batch_size=4))
|
||||
|
||||
qaspercut_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator))
|
||||
|
||||
qaspercut_datasets = [
|
||||
dict(
|
||||
type=QASPERCUTDataset,
|
||||
abbr='qaspercut',
|
||||
path='./data/QASPER/',
|
||||
reader_cfg=qaspercut_reader_cfg,
|
||||
infer_cfg=qaspercut_infer_cfg,
|
||||
eval_cfg=qaspercut_eval_cfg)
|
||||
]
|
30
configs/datasets/safety/safety_gen_c0a5b8.py
Normal file
30
configs/datasets/safety/safety_gen_c0a5b8.py
Normal file
@ -0,0 +1,30 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import ToxicEvaluator
|
||||
from opencompass.datasets import SafetyDataset
|
||||
|
||||
safety_reader_cfg = dict(
|
||||
input_columns=['prompt'],
|
||||
output_column='idx',
|
||||
train_split='test',
|
||||
test_split='test')
|
||||
|
||||
# TODO: allow empty output-column
|
||||
safety_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='{prompt}'),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
safety_eval_cfg = dict(evaluator=dict(type=ToxicEvaluator), )
|
||||
|
||||
safety_datasets = [
|
||||
dict(
|
||||
type=SafetyDataset,
|
||||
path='./data/safety.txt',
|
||||
reader_cfg=safety_reader_cfg,
|
||||
infer_cfg=safety_infer_cfg,
|
||||
eval_cfg=safety_eval_cfg)
|
||||
]
|
34
configs/datasets/siqa/siqa_ppl_764e42.py
Normal file
34
configs/datasets/siqa/siqa_ppl_764e42.py
Normal file
@ -0,0 +1,34 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
siqa_reader_cfg = dict(
|
||||
input_columns=['context', 'question', 'answerA', 'answerB', 'answerC'],
|
||||
output_column='label',
|
||||
test_split='validation')
|
||||
|
||||
siqa_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
1: '{context} \nQ: {question}\nA: {answerA}',
|
||||
2: '{context} \nQ: {question}\nA: {answerB}',
|
||||
3: '{context} \nQ: {question}\nA: {answerC}',
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
siqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
siqa_datasets = [
|
||||
dict(
|
||||
abbr="siqa",
|
||||
type=HFDataset,
|
||||
path='social_i_qa',
|
||||
name='social_i_qa',
|
||||
reader_cfg=siqa_reader_cfg,
|
||||
infer_cfg=siqa_infer_cfg,
|
||||
eval_cfg=siqa_eval_cfg)
|
||||
]
|
33
configs/datasets/siqa/siqa_ppl_b27551.py
Normal file
33
configs/datasets/siqa/siqa_ppl_b27551.py
Normal file
@ -0,0 +1,33 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
siqa_reader_cfg = dict(
|
||||
input_columns=['context', 'question', 'answerA', 'answerB', 'answerC'],
|
||||
output_column='label',
|
||||
test_split='validation')
|
||||
|
||||
siqa_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
1: 'The following makes sense:\n {context} \nQ: {question}\nA: {answerA}',
|
||||
2: 'The following makes sense:\n {context} \nQ: {question}\nA: {answerB}',
|
||||
3: 'The following makes sense:\n {context} \nQ: {question}\nA: {answerC}',
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
siqa_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
siqa_datasets = [
|
||||
dict(
|
||||
abbr="siqa",
|
||||
type=HFDataset,
|
||||
path='social_i_qa',
|
||||
reader_cfg=siqa_reader_cfg,
|
||||
infer_cfg=siqa_infer_cfg,
|
||||
eval_cfg=siqa_eval_cfg)
|
||||
]
|
58
configs/datasets/summedits/summedits_ppl_f2bd6e.py
Normal file
58
configs/datasets/summedits/summedits_ppl_f2bd6e.py
Normal file
@ -0,0 +1,58 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import PPLInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import HFDataset
|
||||
|
||||
summedits_reader_cfg = dict(
|
||||
input_columns=['doc', 'summary'],
|
||||
output_column='label',
|
||||
test_split='train')
|
||||
|
||||
summedits_prompt1 = "Given the document below, you have to determine if 'Yes' or 'No', the summary is factually consistent with the document."
|
||||
summedits_prompt2 = "Document:\n{doc}\nSummary:\n{summary}\nIs the summary factually consistent with the document? "
|
||||
summedits_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template={
|
||||
0:
|
||||
dict(
|
||||
begin=[
|
||||
dict(
|
||||
role='SYSTEM',
|
||||
fallback_role='HUMAN',
|
||||
prompt=summedits_prompt1)
|
||||
],
|
||||
round=[
|
||||
dict(role="HUMAN", prompt=summedits_prompt2),
|
||||
dict(role="BOT", prompt="No")
|
||||
]),
|
||||
1:
|
||||
dict(
|
||||
begin=[
|
||||
dict(
|
||||
role='SYSTEM',
|
||||
fallback_role='HUMAN',
|
||||
prompt=summedits_prompt1)
|
||||
],
|
||||
round=[
|
||||
dict(role="HUMAN", prompt=summedits_prompt2),
|
||||
dict(role="BOT", prompt="Yes")
|
||||
]),
|
||||
}),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=PPLInferencer))
|
||||
|
||||
summedits_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
|
||||
|
||||
summedits_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='summedits',
|
||||
path='json',
|
||||
split='train',
|
||||
data_files='./data/summedits/summedits.jsonl',
|
||||
reader_cfg=summedits_reader_cfg,
|
||||
infer_cfg=summedits_infer_cfg,
|
||||
eval_cfg=summedits_eval_cfg)
|
||||
]
|
4
configs/datasets/triviaqa/triviaqa_gen.py
Normal file
4
configs/datasets/triviaqa/triviaqa_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .triviaqa_gen_cc3cbf import triviaqa_datasets # noqa: F401, F403
|
33
configs/datasets/triviaqa/triviaqa_gen_cc3cbf.py
Normal file
33
configs/datasets/triviaqa/triviaqa_gen_cc3cbf.py
Normal file
@ -0,0 +1,33 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import TriviaQADataset, TriviaQAEvaluator
|
||||
|
||||
triviaqa_reader_cfg = dict(
|
||||
input_columns=['question'],
|
||||
output_column='answer',
|
||||
train_split='dev',
|
||||
test_split='dev')
|
||||
|
||||
triviaqa_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='Question: {question}\nAnswer: '),
|
||||
], )),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_out_len=50))
|
||||
|
||||
triviaqa_eval_cfg = dict(
|
||||
evaluator=dict(type=TriviaQAEvaluator), pred_role='BOT')
|
||||
|
||||
triviaqa_datasets = [
|
||||
dict(
|
||||
type=TriviaQADataset,
|
||||
abbr='triviaqa',
|
||||
path='./data/triviaqa/',
|
||||
reader_cfg=triviaqa_reader_cfg,
|
||||
infer_cfg=triviaqa_infer_cfg,
|
||||
eval_cfg=triviaqa_eval_cfg)
|
||||
]
|
37
configs/datasets/triviaqarc/triviaqarc_gen_6c1726.py
Normal file
37
configs/datasets/triviaqarc/triviaqarc_gen_6c1726.py
Normal file
@ -0,0 +1,37 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import TriviaQArcDataset, TriviaQAEvaluator
|
||||
|
||||
triviaqarc_reader_cfg = dict(
|
||||
input_columns=['question', 'evidence'],
|
||||
output_column='answer',
|
||||
train_split='dev',
|
||||
test_split='dev')
|
||||
|
||||
triviaqarc_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt='{evidence}\nAnswer these questions:\nQ: {question}?A:'),
|
||||
dict(role='BOT', prompt=''),
|
||||
], )),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(
|
||||
type=GenInferencer, max_out_len=50, max_seq_len=8192, batch_size=4))
|
||||
|
||||
triviaqarc_eval_cfg = dict(
|
||||
evaluator=dict(type=TriviaQAEvaluator), pred_role='BOT')
|
||||
|
||||
triviaqarc_datasets = [
|
||||
dict(
|
||||
type=TriviaQArcDataset,
|
||||
abbr='triviaqarc',
|
||||
path='./data/triviaqa-rc/',
|
||||
reader_cfg=triviaqarc_reader_cfg,
|
||||
infer_cfg=triviaqarc_infer_cfg,
|
||||
eval_cfg=triviaqarc_eval_cfg)
|
||||
]
|
4
configs/datasets/truthfulqa/truthfulqa_gen.py
Normal file
4
configs/datasets/truthfulqa/truthfulqa_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .truthfulqa_gen_0a3a53 import truthfulqa_datasets # noqa: F401, F403
|
40
configs/datasets/truthfulqa/truthfulqa_gen_0a3a53.py
Normal file
40
configs/datasets/truthfulqa/truthfulqa_gen_0a3a53.py
Normal file
@ -0,0 +1,40 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import TruthfulQADataset, TruthfulQAEvaluator
|
||||
|
||||
truthfulqa_reader_cfg = dict(
|
||||
input_columns=['question'],
|
||||
output_column='reference',
|
||||
train_split='validation',
|
||||
test_split='validation')
|
||||
|
||||
# TODO: allow empty output-column
|
||||
truthfulqa_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[dict(role="HUMAN", prompt="{question}")])),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer))
|
||||
|
||||
# Metrics such as 'truth' and 'info' needs
|
||||
# OPENAI_API_KEY with finetuned models in it.
|
||||
# Please use your own finetuned openai model with keys and refers to
|
||||
# the source code for more details
|
||||
# Metrics such as 'bleurt', 'rouge', 'bleu' are free to test
|
||||
|
||||
# When key is set to "ENV", the key will be fetched from the environment
|
||||
# variable $OPENAI_API_KEY. Otherwise, set key in here directly.
|
||||
truthfulqa_eval_cfg = dict(
|
||||
evaluator=dict(
|
||||
type=TruthfulQAEvaluator, metrics=('truth', 'info'), key='ENV'), )
|
||||
|
||||
truthfulqa_datasets = [
|
||||
dict(
|
||||
type=TruthfulQADataset,
|
||||
path='truthful_qa',
|
||||
name='generation',
|
||||
reader_cfg=truthfulqa_reader_cfg,
|
||||
infer_cfg=truthfulqa_infer_cfg,
|
||||
eval_cfg=truthfulqa_eval_cfg)
|
||||
]
|
4
configs/datasets/z_bench/z_bench_gen.py
Normal file
4
configs/datasets/z_bench/z_bench_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .z_bench_gen_5813ec import z_bench_dataset # noqa: F401, F403
|
17
configs/summarizers/groups/agieval.py
Normal file
17
configs/summarizers/groups/agieval.py
Normal file
@ -0,0 +1,17 @@
|
||||
agieval_summary_groups = []
|
||||
|
||||
_agieval_chinese_sets = ['gaokao-chinese', 'gaokao-english', 'gaokao-geography', 'gaokao-history', 'gaokao-biology', 'gaokao-chemistry', 'gaokao-physics', 'gaokao-mathqa', 'logiqa-zh', 'jec-qa-kd', 'jec-qa-ca', 'gaokao-mathcloze']
|
||||
_agieval_chinese_sets = ['agieval-' + s for s in _agieval_chinese_sets]
|
||||
agieval_summary_groups.append({'name': 'agieval-chinese', 'subsets': _agieval_chinese_sets})
|
||||
|
||||
_agieval_english_sets = ['lsat-ar', 'lsat-lr', 'lsat-rc', 'logiqa-en', 'sat-math', 'sat-en', 'sat-en-without-passage', 'aqua-rat', 'math']
|
||||
_agieval_english_sets = ['agieval-' + s for s in _agieval_english_sets]
|
||||
agieval_summary_groups.append({'name': 'agieval-english', 'subsets': _agieval_english_sets})
|
||||
|
||||
_agieval_gaokao_sets = ['gaokao-chinese', 'gaokao-english', 'gaokao-geography', 'gaokao-history', 'gaokao-biology', 'gaokao-chemistry', 'gaokao-physics', 'gaokao-mathqa', 'gaokao-mathcloze']
|
||||
_agieval_gaokao_sets = ['agieval-' + s for s in _agieval_gaokao_sets]
|
||||
agieval_summary_groups.append({'name': 'agieval-gaokao', 'subsets': _agieval_gaokao_sets})
|
||||
|
||||
_agieval_all = ['gaokao-chinese', 'gaokao-english', 'gaokao-geography', 'gaokao-history', 'gaokao-biology', 'gaokao-chemistry', 'gaokao-physics', 'gaokao-mathqa', 'logiqa-zh', 'lsat-ar', 'lsat-lr', 'lsat-rc', 'logiqa-en', 'sat-math', 'sat-en', 'sat-en-without-passage', 'aqua-rat', 'jec-qa-kd', 'jec-qa-ca', 'gaokao-mathcloze', 'math']
|
||||
_agieval_all = ['agieval-' + s for s in _agieval_all]
|
||||
agieval_summary_groups.append({'name': 'agieval', 'subsets': _agieval_all})
|
88
configs/summarizers/medium.py
Normal file
88
configs/summarizers/medium.py
Normal file
@ -0,0 +1,88 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .groups.agieval import agieval_summary_groups
|
||||
from .groups.mmlu import mmlu_summary_groups
|
||||
from .groups.ceval import ceval_summary_groups
|
||||
from .groups.bbh import bbh_summary_groups
|
||||
from .groups.GaokaoBench import GaokaoBench_summary_groups
|
||||
from .groups.flores import flores_summary_groups
|
||||
|
||||
summarizer = dict(
|
||||
dataset_abbrs = [
|
||||
'--- Exam ---',
|
||||
'agieval',
|
||||
'mmlu-all-set',
|
||||
"ceval",
|
||||
"GaokaoBench",
|
||||
"bbh",
|
||||
'--- Coding ---',
|
||||
'openai_humaneval',
|
||||
'mbpp',
|
||||
'--- ChineseUniversal ---',
|
||||
'C3',
|
||||
'CMRC_dev',
|
||||
'DRCD_dev',
|
||||
'afqmc-dev',
|
||||
'cmnli',
|
||||
'ocnli',
|
||||
'bustm-dev',
|
||||
'chid-dev',
|
||||
'cluewsc-dev',
|
||||
'csl_dev',
|
||||
'eprstmt-dev',
|
||||
'ocnli_fc-dev',
|
||||
'tnews-dev',
|
||||
'lcsts',
|
||||
'--- Completion ---',
|
||||
'lambada',
|
||||
'story_cloze',
|
||||
'--- EnglishUniversal ---',
|
||||
'AX_b',
|
||||
'AX_g',
|
||||
'BoolQ',
|
||||
'CB',
|
||||
'COPA',
|
||||
'MultiRC',
|
||||
'RTE',
|
||||
'ReCoRD',
|
||||
'WiC',
|
||||
'WSC',
|
||||
'race-high',
|
||||
'race-middle',
|
||||
'--- NLG ---',
|
||||
'Xsum',
|
||||
'--- Reasoning ---',
|
||||
'gsm8k',
|
||||
'summedits',
|
||||
'math',
|
||||
'TheoremQA',
|
||||
'--- QA ---',
|
||||
'hellaswag',
|
||||
'ARC-e',
|
||||
'ARC-c',
|
||||
'commonsense_qa',
|
||||
'piqa',
|
||||
'siqa',
|
||||
'strategyqa',
|
||||
'winogrande',
|
||||
'openbookqa',
|
||||
'openbookqa_fact',
|
||||
'nq',
|
||||
'triviaqa',
|
||||
'--- Translation ---',
|
||||
'flores_100_Indo-European-Germanic_English',
|
||||
'flores_100_English_Indo-European-Germanic',
|
||||
'flores_100_Indo-European-Romance_English',
|
||||
'flores_100_English_Indo-European-Romance',
|
||||
'flores_100_zho_simpl-eng',
|
||||
'flores_100_eng-zho_simpl',
|
||||
'--- Security ---',
|
||||
'crows_pairs',
|
||||
],
|
||||
summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
|
||||
prompt_db=dict(
|
||||
database_path='configs/datasets/log.json',
|
||||
config_dir='configs/datasets',
|
||||
blacklist='.promptignore'),
|
||||
)
|
64
configs/summarizers/small.py
Normal file
64
configs/summarizers/small.py
Normal file
@ -0,0 +1,64 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .groups.agieval import agieval_summary_groups
|
||||
from .groups.mmlu import mmlu_summary_groups
|
||||
from .groups.ceval import ceval_summary_groups
|
||||
from .groups.bbh import bbh_summary_groups
|
||||
from .groups.GaokaoBench import GaokaoBench_summary_groups
|
||||
from .groups.flores import flores_summary_groups
|
||||
|
||||
summarizer = dict(
|
||||
dataset_abbrs = [
|
||||
'--- Exam ---',
|
||||
'mmlu-all-set',
|
||||
"ceval",
|
||||
"bbh",
|
||||
'--- ChineseUniversal ---',
|
||||
'CMRC_dev',
|
||||
'DRCD_dev',
|
||||
'afqmc-dev',
|
||||
'bustm-dev',
|
||||
'chid-dev',
|
||||
'cluewsc-dev',
|
||||
'eprstmt-dev',
|
||||
'--- Coding ---',
|
||||
'openai_humaneval',
|
||||
'mbpp',
|
||||
'--- Completion ---',
|
||||
'lambada',
|
||||
'story_cloze',
|
||||
'--- EnglishUniversal ---',
|
||||
'AX_b',
|
||||
'AX_g',
|
||||
'BoolQ',
|
||||
'CB',
|
||||
'COPA',
|
||||
'MultiRC',
|
||||
'RTE',
|
||||
'ReCoRD',
|
||||
'WiC',
|
||||
'WSC',
|
||||
'race-high',
|
||||
'race-middle',
|
||||
'--- Reasoning ---',
|
||||
'math',
|
||||
'gsm8k',
|
||||
'summedits',
|
||||
'--- QA ---',
|
||||
'hellaswag',
|
||||
'piqa',
|
||||
'winogrande',
|
||||
'openbookqa',
|
||||
'openbookqa_fact',
|
||||
'nq',
|
||||
'triviaqa',
|
||||
'--- Security ---',
|
||||
'crows_pairs',
|
||||
],
|
||||
summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
|
||||
prompt_db=dict(
|
||||
database_path='configs/datasets/log.json',
|
||||
config_dir='configs/datasets',
|
||||
blacklist='.promptignore'),
|
||||
)
|
62
docs/en/_static/css/readthedocs.css
Normal file
62
docs/en/_static/css/readthedocs.css
Normal file
@ -0,0 +1,62 @@
|
||||
.header-logo {
|
||||
background-image: url("../image/logo.png");
|
||||
background-size: 183px 50px;
|
||||
height: 50px;
|
||||
width: 183px;
|
||||
}
|
||||
|
||||
@media screen and (min-width: 1100px) {
|
||||
.header-logo {
|
||||
top: -12px;
|
||||
}
|
||||
}
|
||||
|
||||
pre {
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
@media screen and (min-width: 2000px) {
|
||||
.pytorch-content-left {
|
||||
width: 1200px;
|
||||
margin-left: 30px;
|
||||
}
|
||||
article.pytorch-article {
|
||||
max-width: 1200px;
|
||||
}
|
||||
.pytorch-breadcrumbs-wrapper {
|
||||
width: 1200px;
|
||||
}
|
||||
.pytorch-right-menu.scrolling-fixed {
|
||||
position: fixed;
|
||||
top: 45px;
|
||||
left: 1580px;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
article.pytorch-article section code {
|
||||
padding: .2em .4em;
|
||||
background-color: #f3f4f7;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
/* Disable the change in tables */
|
||||
article.pytorch-article section table code {
|
||||
padding: unset;
|
||||
background-color: unset;
|
||||
border-radius: unset;
|
||||
}
|
||||
|
||||
table.autosummary td {
|
||||
width: 50%
|
||||
}
|
||||
|
||||
img.align-center {
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
}
|
||||
|
||||
article.pytorch-article p.rubric {
|
||||
font-weight: bold;
|
||||
}
|
BIN
docs/en/_static/image/logo.png
Normal file
BIN
docs/en/_static/image/logo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
86
docs/en/user_guides/experimentation.md
Normal file
86
docs/en/user_guides/experimentation.md
Normal file
@ -0,0 +1,86 @@
|
||||
# Task Execution and Monitoring
|
||||
|
||||
## Initiation of Assessment Task
|
||||
|
||||
The program entry for the assessment task is `run.py`, its usage is as follows:
|
||||
|
||||
```shell
|
||||
run.py [-p PARTITION] [-q QUOTATYPE] [--debug] [-m MODE] [-r [REUSE]] [-w WORKDIR] [-l LARK] config
|
||||
```
|
||||
|
||||
The parameter explanation is as follows:
|
||||
|
||||
- -p Specify the slurm partition;
|
||||
- -q Specify the slurm quotatype (default is auto), with optional values being reserved, auto, spot;
|
||||
- --debug When enabled, inference and evaluation tasks will run in single-process mode, and output will be echoed in real-time for debugging;
|
||||
- -m Run mode, default is all. It can be specified as infer to only run inference and obtain output results; if there are already model outputs in {WORKDIR}, it can be specified as eval to only run evaluation and obtain evaluation results; if there are individual evaluation results in results, it can be specified as viz to only run visualization; if specified as all, both inference and evaluation tasks run at the same time.
|
||||
- -r Reuse existing inference results. If followed by a timestamp, the result under that timestamp in the workspace path will be reused; otherwise, the latest result in the specified workspace path will be reused.
|
||||
- -w Specify the working path, default is ./outputs/default
|
||||
- -l Enable status reporting via Lark bot.
|
||||
|
||||
Using run mode `-m all` as an example, the overall execution flow is as follows:
|
||||
|
||||
1. Read the configuration file, parse out the model, dataset, evaluator, and other configuration information
|
||||
2. The evaluation task mainly includes three stages: inference infer, evaluation eval, and visualization viz. After task division by Partitioner, they are handed over to Runner for parallel execution. Individual inference and evaluation tasks are abstracted into OpenICLInferTask and OpenICLEvalTask respectively.
|
||||
3. After each stage ends, the visualization stage will read the evaluation results in results to generate a visualization report.
|
||||
|
||||
## Task Monitoring: Lark Bot
|
||||
|
||||
Users can enable real-time monitoring of task status by setting up a Lark bot. Please refer to [this document](https://open.feishu.cn/document/ukTMukTMukTM/ucTM5YjL3ETO24yNxkjN?lang=zh-CN#7a28964d) for setting up the Lark bot.
|
||||
|
||||
Configuration method:
|
||||
|
||||
1. Open the `configs/lark.py` file, and add the following line:
|
||||
|
||||
```python
|
||||
lark_bot_url = 'YOUR_WEBHOOK_URL'
|
||||
```
|
||||
|
||||
Typically, the Webhook URL is formatted like this: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx .
|
||||
|
||||
2. Inherit this file in the complete evaluation configuration:
|
||||
|
||||
```python
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .lark import lark_bot_url
|
||||
|
||||
```
|
||||
|
||||
3. To avoid frequent messages from the bot becoming a nuisance, status updates are not automatically reported by default. You can start status reporting using `-l` or `--lark` when needed:
|
||||
|
||||
```bash
|
||||
python run.py configs/eval_demo.py -p {PARTITION} -l
|
||||
```
|
||||
|
||||
## Introduction of Summerizer
|
||||
|
||||
It is mainly used to visualize evaluation results.
|
||||
|
||||
## Run Results
|
||||
|
||||
All run results will be placed in `outputs/default/` directory by default, the directory structure is shown below:
|
||||
|
||||
```
|
||||
outputs/default/
|
||||
├── 20200220_120000
|
||||
├── ...
|
||||
├── 20230220_183030
|
||||
│ ├── configs
|
||||
│ ├── logs
|
||||
│ │ ├── eval
|
||||
│ │ └── infer
|
||||
│ ├── predictions
|
||||
│ │ └── MODEL1
|
||||
│ └── results
|
||||
│ └── MODEL1
|
||||
```
|
||||
|
||||
Each timestamp contains the following content:
|
||||
- configs folder, which stores the configuration files corresponding to each run with this timestamp as the output directory;
|
||||
- logs folder, which stores the output log files of the inference and evaluation phases, each folder will store logs in subfolders by model;
|
||||
- predictions folder, which stores the inferred json results, with a model subfolder;
|
||||
- results folder, which stores the evaluated json results, with a model subfolder.
|
||||
|
||||
Also, all `-r` without specifying a corresponding timestamp will select the newest folder by sorting as the output directory.
|
1
docs/en/user_guides/framework_overview.md
Normal file
1
docs/en/user_guides/framework_overview.md
Normal file
@ -0,0 +1 @@
|
||||
# Overview
|
10
docs/zh_cn/_static/js/custom.js
Normal file
10
docs/zh_cn/_static/js/custom.js
Normal file
@ -0,0 +1,10 @@
|
||||
var collapsedSections = ['Advanced Guides', 'Tools', 'User Guides', 'Notes'];
|
||||
|
||||
$(document).ready(function () {
|
||||
$('.model-summary').DataTable({
|
||||
"stateSave": false,
|
||||
"lengthChange": false,
|
||||
"pageLength": 20,
|
||||
"order": []
|
||||
});
|
||||
});
|
85
docs/zh_cn/index.rst
Normal file
85
docs/zh_cn/index.rst
Normal file
@ -0,0 +1,85 @@
|
||||
欢迎来到 OpenCompass 中文教程!
|
||||
==========================================
|
||||
|
||||
OpenCompass 上手路线
|
||||
-------------------------------
|
||||
|
||||
为了用户能够快速上手,我们推荐以下流程:
|
||||
|
||||
- 对于想要使用 OpenCompass 的用户,我们推荐先阅读 开始你的第一步_ 部分来设置环境。
|
||||
|
||||
- 对于一些基础使用,我们建议用户阅读 教程_ 。
|
||||
|
||||
- 若您想进行算法的自定义,我们提供了 进阶教程_ 。
|
||||
|
||||
- 如果您想调整提示语,您可以浏览 提示语_ 。
|
||||
|
||||
- 我们同样提供了 工具_ 。
|
||||
|
||||
|
||||
我们始终非常欢迎用户的 PRs 和 Issues 来完善 OpenCompass!
|
||||
|
||||
.. _开始你的第一步:
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: 开始你的第一步
|
||||
|
||||
get_started.md
|
||||
|
||||
.. _教程:
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: 教程
|
||||
|
||||
user_guides/framework_overview.md
|
||||
user_guides/config.md
|
||||
user_guides/dataset_prepare.md
|
||||
user_guides/models.md
|
||||
user_guides/evaluation.md
|
||||
user_guides/experimentation.md
|
||||
user_guides/metrics.md
|
||||
|
||||
.. _进阶教程:
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: 进阶教程
|
||||
|
||||
advanced_guides/new_dataset.md
|
||||
advanced_guides/new_model.md
|
||||
|
||||
.. _提示语:
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: 提示语
|
||||
|
||||
prompt/overview.md
|
||||
prompt/few_shot.md
|
||||
prompt/prompt_template.md
|
||||
prompt/meta_template.md
|
||||
|
||||
.. _工具:
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: 工具
|
||||
|
||||
tools.md
|
||||
|
||||
.. _其他说明:
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: 其他说明
|
||||
|
||||
notes/contribution_guide.md
|
||||
|
||||
.. toctree::
|
||||
:caption: 切换语言
|
||||
|
||||
English <https://mmpretrain.readthedocs.io/en/latest/>
|
||||
简体中文 <https://mmpretrain.readthedocs.io/zh_CN/latest/>
|
||||
|
||||
|
||||
索引与表格
|
||||
==================
|
||||
|
||||
* :ref:`genindex`
|
||||
* :ref:`search`
|
163
docs/zh_cn/prompt/meta_template.md
Normal file
163
docs/zh_cn/prompt/meta_template.md
Normal file
@ -0,0 +1,163 @@
|
||||
# Meta Prompt
|
||||
|
||||
## 背景
|
||||
|
||||
在 LLM 的实际 finetune 中,我们常常会根据实际的要求注入一些预定义的字符串,以求模型能按照自然语言的格式输出指定的内容。在评测时,我们也需要按照 finetune 时设定的格式输入问题,模型才能发挥出其最大的性能。因此,我们需要对 OpenICL 原本的 prompt 设计作一次增强,才能满足相应需求。
|
||||
|
||||
## Model - Meta Template
|
||||
|
||||
此前, prompt template 的设定绑定在数据集配置中。现在考虑到不同模型的 instruction 可能会有所不同,我们往 model config 中新增 `meta_template` 字段,允许用户指定与模型密切相关的 instruction。
|
||||
|
||||
```Python
|
||||
models = [
|
||||
dict(type='LLM',
|
||||
# ...
|
||||
meta_template = dict(
|
||||
begin="meta instruction\nYou are an AI assistant.\n",
|
||||
round=[
|
||||
dict(role='HUMAN', begin='<|HUMAN|>:', end='脷\n'), # begin and end can be a list of strings or integers.
|
||||
dict(role='THOUGHTS', begin='<|Inner Thoughts|>:', end='茔\n', prompt='None'),
|
||||
dict(role='COMMANDS', begin='<|Commands|>:', end='蝮\n', prompt='None'),
|
||||
dict(role='RESULTS', begin='<|Results|>:', end='兒\n', prompt='None'), # Here we can set the default prompt, which may be overridden by the speicfic dataset
|
||||
dict(role='BOT', begin='<|MOSS|>:', generate=True, end='氡\n'),
|
||||
],
|
||||
end="end of conversion",
|
||||
reserved_roles=[dict(role='SYSTEM', begin='<|SYSTEM|>: ', end='\n'),],
|
||||
# the token to stop the generation tasks (TODO: support string)
|
||||
eos_token_id=65605,
|
||||
),
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
这里,meta_template 是一个**字典**,该字典可以包含以下数个字段:
|
||||
|
||||
- `begin`,`end` :(str,可选) prompt 的开头,通常是一些 meta instruction。
|
||||
|
||||
- `round`:(list,可选) 约定了每一轮对话的 prompt 格式。每轮对话的 prompt 内容由 dataset config 中的 prompt template 控制(下文会详述)。如果不指定,则该字段将会直接被 dataset config 中的 prompt template 替换。
|
||||
|
||||
- (str,可选):收尾的 instruction。
|
||||
|
||||
- `reserved_roles` (list,可选)指定了在 meta template 中并未出现的预留角色。这里面定义的角色有可能在 dataset config 的 begin 或 end 中用到,例如 `SYSTEM` 角色。
|
||||
|
||||
- `eos_token_id` (int, 可选):指定了该模型在生成式任务中 eos token 的 id。如果不设置,则默认为 tokenizer 中的 eos token id。
|
||||
|
||||
`round` 指定了每轮对话中每个角色说话的格式,通常接受一个列表,内容可以是 **str 或 dict**。每个字典接受以下关键字:
|
||||
|
||||
- `role`(str): 对话中的角色,也可以认为是这个 prompt 的 identifier。该字符串并不影响实际的 prompt,仅用于在 dataset_config 中的指定对应项,并对其 prompt 内容进行覆盖。
|
||||
|
||||
- `begin`, `end` (str): 指定该角色在说话时的开头或结尾。
|
||||
|
||||
- `prompt` (str):prompt 的内容,遵循 `ICLPromptTemplate` 的格式规范。如果在 meta_prompt_template 中未指定,则必须在 dataset config 中的 prompt template 中指定。
|
||||
|
||||
- `generate` (bool): 指定为 True 时,该角色即为模型在生成任务中开始生成输出的位点。在生成任务中生成对应 prompt 时,prompt template 只会生成到该角色的 begin,剩下的内容由模型补全。
|
||||
|
||||
在上面的例子中,最后的 meta prompt 将会是:
|
||||
|
||||
```
|
||||
meta instructionYou are an AI assistant.
|
||||
<|HUMAN|>: 脷\n
|
||||
<|Inner Thoughts|>: None茔\n<|Commands|>: None蝮\n<|Results|>: None兒\n
|
||||
<|MOSS|>: 氡\n
|
||||
end of conversion
|
||||
```
|
||||
|
||||
特别地,在生成式任务中,prompt 仅会生成到 \<|MOSS|>: 后:
|
||||
|
||||
```
|
||||
meta instructionYou are an AI assistant.
|
||||
<|HUMAN|>: 脷\n
|
||||
<|Inner Thoughts|>: None茔\n<|Commands|>: None蝮\n<|Results|>: None兒\n
|
||||
<|MOSS|>:
|
||||
```
|
||||
|
||||
接下来我们在 dataset config 中进行进一步约定。
|
||||
|
||||
## Dataset: Prompt Template
|
||||
|
||||
在 model 配置中约定了该 model 所需的 meta template 后,dataset 中 prompt template 的格式也会有所变化。同时,该方向尽可能地保持了 prompt 的 backward compatibility。
|
||||
|
||||
在改动前,`PromptTemplate` 接受 str 或 dict 作为输入。其中,dict 形式的输入将 label string 映射到对应的 prompt (str)上,通常用作为 `PPLInferencer` 的输入。因而本质上,`PromptTemplate` 的旧版实现里表示 prompt 的方式只有 `str` 一种。
|
||||
|
||||
而改动后的 prompt template 允许接受的 prompt 基本形式从 str 扩展到了 dict。
|
||||
|
||||
这个 dict 的格式与 meta template 相似,用户也可以指定 `begin`, `end` 和 `round` 关键字:
|
||||
|
||||
```Python
|
||||
mmlu_prompt_template = dict(
|
||||
type='PromptTemplate',
|
||||
template=dict(
|
||||
begin=[dict(role='SYSTEM', fallback_role='HUMAN', prompt='The following are '
|
||||
'multiple choice questions (with answers) about physics.'),
|
||||
'</E>',
|
||||
],
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='</input>\nA. </A>\nB. </B>\nC. </C>\nD. </D>\nAnswer: '),
|
||||
dict(role='BOT', prompt='</target>'),
|
||||
],
|
||||
end="end of dataset prompt template."
|
||||
),
|
||||
column_token_map={
|
||||
'input': '</input>',
|
||||
'A': '</A>',
|
||||
'B': '</B>',
|
||||
'C': '</C>',
|
||||
'D': '</D>',
|
||||
'target': '</target>'
|
||||
},
|
||||
ice_token='</E>',
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
其中,`round`用于指定在每轮对话中角色的 prompt 格式,同时也是为了呼应和补全 meta template 中的配置,因此,其接受的参数和规则均与 meta template 中的 `round` 一致。**在实际运行时,两处 prompt 的配置将会融合,同时如果某一字段被重复定义,则以 dataset config 中定义为准。**
|
||||
|
||||
而 `begin` 和 `end` 则除了支持 str 类型的输入,也支持 list 类型的输入,在其中用户可以通过组合 dict 和字符串实现对系统角色的融合。留意到例子中引入了 `fallback_role` 的设定,意味着若系统在 meta template 中 reserved_roles 中找不到 `role` 中的角色时,会自动替换成 `fallback_role` 中的角色。这个特征的设立是为了尽可能确保 prompt 模板的通用性。
|
||||
|
||||
结合 meta template,最终生成的 prompt 模板为:
|
||||
|
||||
```Plain
|
||||
meta instruction
|
||||
You are an AI assistant.
|
||||
<|SYSTEM|>: The following are multiple choice questions (with answers) about college biology.
|
||||
<|HUMAN|>: Which of the following is NOT a characteristic of an oligotrophic lake?
|
||||
A. Low nutrient levels
|
||||
B. High altitudes
|
||||
C. Shallow water
|
||||
D. Sand or gravel bottom
|
||||
Answer: 脷\n
|
||||
<|Inner Thoughts|>: None茔\n
|
||||
<|Commands|>: None蝮\n
|
||||
<|Results|>: None兒\n
|
||||
<|MOSS|>: A氡\n
|
||||
end of dataset prompt template.
|
||||
end of conversion
|
||||
```
|
||||
|
||||
特别地,由于这种 prompt 的数据结构(dict)与旧版的 label -> prompt 映射相同,本实现仅在字典的 keys 为 {`begin`, `round`, `end`} 的子集时将 prompt 的输入以新版规则进行解码,否则依然将字典以 label -> prompt 的形式进行解码。此外,该方案也允许新版 prompt 字典嵌套在旧版的 label -> prompt 字典中。例如,以下表达方式也是合法的 (摘自 `configs/datasets/mmlu.py`):
|
||||
|
||||
```Python
|
||||
prompt_template={
|
||||
target:
|
||||
dict(
|
||||
begin=[dict(role='SYSTEM', fallback_role='HUMAN', prompt='The following are '
|
||||
'multiple choice questions (with answers) about '
|
||||
f'{name.replace("_", " ")}.\n'),
|
||||
'</E>',
|
||||
],
|
||||
round=[
|
||||
dict(role='HUMAN', prompt='</input>\nA. </A>\nB. </B>\nC. </C>\nD. </D>\nAnswer: '),
|
||||
dict(role='BOT', prompt=f'{target}'),
|
||||
]
|
||||
)
|
||||
for target in ['A', 'B', 'C', 'D'] # use the actual answer
|
||||
}
|
||||
```
|
||||
|
||||
### 无 meta template 时
|
||||
|
||||
为了保证后向兼容性,当用户未在 model config 中指定 meta template 时,`ICLPromptTemplate` 会将每个 dict 按照 `begin`, `prompt`, `end` 的顺序拼接为普通字符串。
|
||||
|
||||
### 多轮对话例子
|
||||
|
||||
在某些时候,一轮完整的交互中可能需要包含多轮对话。用户可以参考 `configs/datasets/gsm8k.py` 配置自己的模板。
|
66
docs/zh_cn/tools.md
Normal file
66
docs/zh_cn/tools.md
Normal file
@ -0,0 +1,66 @@
|
||||
# 实用工具
|
||||
|
||||
## Prompt Viewer
|
||||
|
||||
本工具允许你在不启动完整训练流程的情况下,直接查看模型会接收到的 prompt。
|
||||
|
||||
运行方式:
|
||||
|
||||
```bash
|
||||
python tools/prompt_viewer.py [CONFIG_PATH]
|
||||
```
|
||||
|
||||
## Case Analyzer
|
||||
|
||||
本工具在已有评测结果的基础上,产出推理错误样本以及带有标注信息的全量样本
|
||||
|
||||
运行方式:
|
||||
|
||||
```bash
|
||||
python tools/case_analyzer.py [CONFIG_PATH] [-w WORK_DIR]
|
||||
```
|
||||
|
||||
- `-w`:工作路径,默认为 `'./outputs/default'`。
|
||||
|
||||
更多细节见 [飞书文档](https://aicarrier.feishu.cn/docx/SgrLdwinion00Kxkzh2czz29nIh)
|
||||
|
||||
## Lark Bot
|
||||
|
||||
用户可以通过配置飞书机器人,实现任务状态的实时监控。飞书机器人的设置文档请[参考这里](https://open.feishu.cn/document/ukTMukTMukTM/ucTM5YjL3ETO24yNxkjN?lang=zh-CN#7a28964d)。
|
||||
|
||||
配置方式:
|
||||
|
||||
- 打开 `configs/secrets.py` 文件,并在文件中加入以下行:
|
||||
|
||||
```python
|
||||
lark_bot_url = 'YOUR_WEBHOOK_URL'
|
||||
```
|
||||
|
||||
通常, Webhook URL 格式如 https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx 。
|
||||
|
||||
- 在完整的评测配置中继承该文件:
|
||||
|
||||
```python
|
||||
_base_ = [
|
||||
'secrets.py',
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
实例可见 `configs/eval.py`。
|
||||
|
||||
- 为了避免机器人频繁发消息形成骚扰,默认运行时状态不会自动上报。有需要时,可以通过 `-l` 或 `--lark` 启动状态上报:
|
||||
|
||||
```bash
|
||||
python run.py configs/eval_demo.py -p {PARTITION} -l
|
||||
```
|
||||
|
||||
## API Model Tests
|
||||
|
||||
本工具可以快速测试 API Wrapper 的功能是否正常。
|
||||
|
||||
运行方式:
|
||||
|
||||
```bash
|
||||
python tools/test_api_model.py [CONFIG_PATH]
|
||||
```
|
105
docs/zh_cn/user_guides/dataset_prepare.md
Normal file
105
docs/zh_cn/user_guides/dataset_prepare.md
Normal file
@ -0,0 +1,105 @@
|
||||
# 数据集准备和选择
|
||||
|
||||
本节教程主要关注如何准备 OpenCompass 已支持的数据集,并构建需要的配置文件完成数据集的选择。
|
||||
|
||||
## 数据集配置文件目录结构
|
||||
|
||||
首先简单介绍一下 OpenCompass `configs/datasets` 目录下的结构,如下所示:
|
||||
|
||||
```
|
||||
configs/datasets/
|
||||
├── ChineseUniversal # 能力维度
|
||||
│ ├── CLUE_afqmc # 该维度下的数据集
|
||||
│ │ ├── CLUE_afqmc_gen_db509b.py # 该数据集的不同配置文件
|
||||
│ │ ├── CLUE_afqmc_gen.py
|
||||
│ │ ├── CLUE_afqmc_ppl_00b348.py
|
||||
│ │ ├── CLUE_afqmc_ppl_2313cf.py
|
||||
│ │ └── CLUE_afqmc_ppl.py
|
||||
│ ├── CLUE_C3
|
||||
│ │ ├── ...
|
||||
│ ├── ...
|
||||
├── Coding
|
||||
├── collections
|
||||
├── Completion
|
||||
├── EnglishUniversal
|
||||
├── Exam
|
||||
├── glm
|
||||
├── LongText
|
||||
├── MISC
|
||||
├── NLG
|
||||
├── QA
|
||||
├── Reasoning
|
||||
├── Security
|
||||
└── Translation
|
||||
```
|
||||
|
||||
在 `configs/datasets` 目录结构下,我们主要以能力维度对数据集划分了十余项维度,例如:中英文通用、考试、问答、推理、安全等等。每一项维度又包含了一系列数据集,在各个数据集对应的文件夹下存在多个数据集配置。
|
||||
|
||||
数据集配置文件名由以下命名方式构成 `{数据集名称}_{评测方式}_{prompt版本号}.py`,以 `ChineseUniversal/CLUE_afqmc/CLUE_afqmc_gen_db509b.py` 为例,该配置文件则为中文通用能力下的 `CLUE_afqmc` 数据集,对应的评测方式为 `gen`,即生成式评测,对应的prompt版本号为 `db509b`;同样的, `CLUE_afqmc_ppl_00b348.py` 指评测方式为`ppl`即判别式评测,prompt版本号为 `00b348` 。
|
||||
|
||||
除此之外,不带版本号的文件,例如: `CLUE_afqmc_gen.py` 则指向该评测方式最新的prompt配置文件,通常来说会是精度最高的prompt。
|
||||
|
||||
## 数据集准备
|
||||
|
||||
OpenCompass 支持的数据集主要包括两个部分:
|
||||
|
||||
1. Huggingface 数据集
|
||||
|
||||
[Huggingface Dataset](https://huggingface.co/datasets) 提供了大量的数据集。OpenCompass 已经支持了大多数常用于性能比较的数据集,具体支持的数据集列表请直接在 `configs/dataset` 下进行查找。
|
||||
|
||||
2. OpenCompass 自建数据集
|
||||
|
||||
除了支持 Huggingface 已有的数据集, OpenCompass 还提供了一些自建CN数据集,未来将会提供一个数据集相关的Repo供用户下载使用。按照文档指示将数据集统一放置在`./data`目录下即可完成数据集准备。
|
||||
|
||||
需要注意的是,Repo中不仅包含自建的数据集,为了方便也加入了部分HF已支持的数据集方便测试。
|
||||
|
||||
## 数据集选择
|
||||
|
||||
在各个数据集配置文件中,数据集将会被定义在 `{}_datasets` 变量当中,例如下面 `ChineseUniversal/CLUE_afqmc/CLUE_afqmc_gen_db509b.py` 中的 `afqmc_datasets`。
|
||||
|
||||
```python
|
||||
afqmc_datasets = [
|
||||
dict(
|
||||
abbr="afqmc-dev",
|
||||
type=AFQMCDataset_V2,
|
||||
path="./data/CLUE/AFQMC/dev.json",
|
||||
reader_cfg=afqmc_reader_cfg,
|
||||
infer_cfg=afqmc_infer_cfg,
|
||||
eval_cfg=afqmc_eval_cfg,
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
以及 `ChineseUniversal/CLUE_cmnli/CLUE_cmnli_ppl_b78ad4.py` 中的 `afqmc_datasets`。
|
||||
|
||||
```python
|
||||
cmnli_datasets = [
|
||||
dict(
|
||||
type=HFDataset,
|
||||
abbr='cmnli',
|
||||
path='json',
|
||||
split='train',
|
||||
data_files='./data/CLUE/cmnli/cmnli_public/dev.json',
|
||||
reader_cfg=cmnli_reader_cfg,
|
||||
infer_cfg=cmnli_infer_cfg,
|
||||
eval_cfg=cmnli_eval_cfg)
|
||||
]
|
||||
```
|
||||
|
||||
以上述两个数据集为例, 如果用户想同时评测这两个数据集,可以在 `configs` 目录下新建一个配置文件,我们使用 `mmengine` 配置中直接import的机制来构建数据集部分的参数,如下所示:
|
||||
|
||||
```python
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .datasets.CLUE_afqmc.CLUE_afqmc_gen_db509b import afqmc_datasets
|
||||
from .datasets.CLUE_cmnli.CLUE_cmnli_ppl_b78ad4 import cmnli_datasets
|
||||
|
||||
datasets = []
|
||||
datasets += afqmc_datasets
|
||||
datasets += cmnli_datasets
|
||||
```
|
||||
|
||||
用户可以根据需要,选择不同能力不同数据集以及不同评测方式的配置文件来构建评测脚本中数据集的部分。
|
||||
|
||||
有关如何启动评测任务,以及如何评测自建数据集可以参考相关文档。
|
132
opencompass/datasets/GaokaoBench.py
Normal file
132
opencompass/datasets/GaokaoBench.py
Normal file
@ -0,0 +1,132 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class GaokaoBenchDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str):
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
return Dataset.from_list(data['example'])
|
||||
|
||||
|
||||
valid_gaokao_bench_question_types = [
|
||||
'single_choice', 'multi_choice', 'multi_question_choice',
|
||||
'five_out_of_seven', 'cloze', 'subjective', 'correction'
|
||||
]
|
||||
|
||||
|
||||
class GaokaoBenchEvaluator(BaseEvaluator):
|
||||
|
||||
def __init__(self, question_type) -> None:
|
||||
super().__init__()
|
||||
assert question_type in valid_gaokao_bench_question_types
|
||||
self.question_type = question_type
|
||||
|
||||
def do_predictions_postprocess(self, model_output, answer_lenth=None):
|
||||
if self.question_type == 'single_choice':
|
||||
model_answer = []
|
||||
temp = re.findall(r'[A-D]', model_output[::-1])
|
||||
if len(temp) != 0:
|
||||
model_answer.append(temp[0])
|
||||
|
||||
elif self.question_type == 'multi_question_choice':
|
||||
model_answer = []
|
||||
temp = re.findall(r'【答案】\s*[::]*\s*[A-Z]', model_output)
|
||||
|
||||
if len(temp) == answer_lenth:
|
||||
for t in temp:
|
||||
model_answer.append(re.findall(r'[A-Z]', t)[0])
|
||||
else:
|
||||
temp = re.findall(r'[A-Z]', model_output)
|
||||
if len(temp) > 0:
|
||||
for k in range(min(len(temp), answer_lenth)):
|
||||
model_answer.append(temp[k])
|
||||
|
||||
elif self.question_type == 'multi_choice':
|
||||
model_answer = []
|
||||
answer = ''
|
||||
content = re.sub(r'\s+', '', model_output)
|
||||
answer_index = content.find('【答案】')
|
||||
if answer_index > 0:
|
||||
temp = content[answer_index:]
|
||||
if len(re.findall(r'[A-D]', temp)) > 0:
|
||||
for t in re.findall(r'[A-D]', temp):
|
||||
answer += t
|
||||
else:
|
||||
temp = content[-10:]
|
||||
if len(re.findall(r'[A-D]', temp)) > 0:
|
||||
for t in re.findall(r'[A-D]', temp):
|
||||
answer += t
|
||||
if len(answer) != 0:
|
||||
model_answer.append(answer)
|
||||
|
||||
elif self.question_type == 'five_out_of_seven':
|
||||
model_answer = []
|
||||
temp = re.findall(r'[A-G]', model_output)
|
||||
if len(temp) > 0:
|
||||
for k in range(min(5, len(temp))):
|
||||
model_answer.append(temp[k])
|
||||
|
||||
return model_answer
|
||||
|
||||
def ensure_same_length(self, pred, refr):
|
||||
if len(pred) == len(refr):
|
||||
return pred
|
||||
return ['Z'] * len(refr)
|
||||
|
||||
def score(self, predictions, references):
|
||||
if self.question_type not in [
|
||||
'single_choice', 'multi_choice', 'multi_question_choice',
|
||||
'five_out_of_seven'
|
||||
]:
|
||||
return {'score': 0}
|
||||
elif self.question_type == 'multi_choice':
|
||||
correct_score, total_score = 0, 0
|
||||
for pred, refr in zip(predictions, references):
|
||||
pred = self.do_predictions_postprocess(pred)
|
||||
pred = self.ensure_same_length(pred, refr)
|
||||
for p, r in zip(pred, refr):
|
||||
if p == r:
|
||||
correct_score += 2
|
||||
else:
|
||||
for i in p:
|
||||
if i not in r:
|
||||
break
|
||||
else:
|
||||
correct_score += 1
|
||||
total_score += 2
|
||||
return {'score': correct_score / total_score * 100}
|
||||
else:
|
||||
correct_score, total_score = 0, 0
|
||||
for pred, refr in zip(predictions, references):
|
||||
if self.question_type == 'multi_question_choice':
|
||||
pred = self.do_predictions_postprocess(pred, len(refr))
|
||||
else:
|
||||
pred = self.do_predictions_postprocess(pred)
|
||||
pred = self.ensure_same_length(pred, refr)
|
||||
for p, r in zip(pred, refr):
|
||||
if p == r:
|
||||
correct_score += 1
|
||||
total_score += 1
|
||||
return {'score': correct_score / total_score * 100}
|
||||
|
||||
|
||||
for question_type in valid_gaokao_bench_question_types:
|
||||
# fix classic closure problem
|
||||
def _gaokao_register(question_type):
|
||||
ICL_EVALUATORS.register_module(
|
||||
name='GaokaoBenchEvaluator' + '_' + question_type,
|
||||
module=lambda *args, **kwargs: GaokaoBenchEvaluator(
|
||||
question_type=question_type, *args, **kwargs))
|
||||
|
||||
_gaokao_register(question_type)
|
21
opencompass/datasets/afqmcd.py
Normal file
21
opencompass/datasets/afqmcd.py
Normal file
@ -0,0 +1,21 @@
|
||||
import json
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class AFQMCDataset_V2(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path):
|
||||
data = []
|
||||
with open(path, 'r') as f:
|
||||
for line in f:
|
||||
line = json.loads(line)
|
||||
line['label'] = 'AB'[int(line['label'])]
|
||||
data.append(line)
|
||||
return Dataset.from_list(data)
|
24
opencompass/datasets/ax.py
Normal file
24
opencompass/datasets/ax.py
Normal file
@ -0,0 +1,24 @@
|
||||
import json
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class AXDataset_V2(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str):
|
||||
dataset = []
|
||||
with open(path, 'r') as f:
|
||||
for line in f:
|
||||
line = json.loads(line)
|
||||
line['label'] = {
|
||||
'entailment': 'A',
|
||||
'not_entailment': 'B'
|
||||
}[line['label']]
|
||||
dataset.append(line)
|
||||
return Dataset.from_list(dataset)
|
28
opencompass/datasets/base.py
Normal file
28
opencompass/datasets/base.py
Normal file
@ -0,0 +1,28 @@
|
||||
from abc import abstractstaticmethod
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from opencompass.openicl import DatasetReader
|
||||
|
||||
|
||||
class BaseDataset:
|
||||
|
||||
def __init__(self, reader_cfg: Optional[Dict] = {}, **kwargs):
|
||||
self.dataset = self.load(**kwargs)
|
||||
self._init_reader(**reader_cfg)
|
||||
|
||||
def _init_reader(self, **kwargs):
|
||||
self.reader = DatasetReader(self.dataset, **kwargs)
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
return self.reader.dataset['train']
|
||||
|
||||
@property
|
||||
def test(self):
|
||||
return self.reader.dataset['test']
|
||||
|
||||
@abstractstaticmethod
|
||||
def load(**kwargs) -> Union[Dataset, DatasetDict]:
|
||||
pass
|
71
opencompass/datasets/bbh.py
Normal file
71
opencompass/datasets/bbh.py
Normal file
@ -0,0 +1,71 @@
|
||||
import json
|
||||
import os.path as osp
|
||||
import re
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||
TEXT_POSTPROCESSORS)
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class BBHDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, name: str):
|
||||
with open(osp.join(path, f'{name}.json'), 'r') as f:
|
||||
data = json.load(f)['examples']
|
||||
dataset = Dataset.from_list(data)
|
||||
return dataset
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('bbh-mcq')
|
||||
def bbh_mcq_postprocess(text: str) -> str:
|
||||
ans = text
|
||||
ans_line = ans.split('answer is ')
|
||||
if len(ans_line) != 1:
|
||||
ans = ans_line[1].strip()
|
||||
match = re.search(r'\(([A-Z])\)*', ans)
|
||||
if match:
|
||||
return match.group(1)
|
||||
match = re.search(r'([A-Z])', ans)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return ans
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('bbh-freeform')
|
||||
def bbh_freeform_postprocess(text: str) -> str:
|
||||
ans = text
|
||||
ans_line = ans.split('answer is ')
|
||||
if len(ans_line) != 1:
|
||||
ans = ans_line[1].strip()
|
||||
ans = ans.split('\n')[0]
|
||||
if ans.endswith('.'):
|
||||
ans = ans[:-1]
|
||||
return ans
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class BBHEvaluator(BaseEvaluator):
|
||||
|
||||
def score(self, predictions, references):
|
||||
if len(predictions) != len(references):
|
||||
return {
|
||||
'error': 'predictions and references have different '
|
||||
'length'
|
||||
}
|
||||
|
||||
predictions = [bbh_freeform_postprocess(pred) for pred in predictions]
|
||||
|
||||
cnt = 0
|
||||
for pred, ref in zip(predictions, references):
|
||||
if pred == ref:
|
||||
cnt += 1
|
||||
|
||||
score = cnt / len(predictions) * 100
|
||||
|
||||
return {'score': score}
|
41
opencompass/datasets/boolq.py
Normal file
41
opencompass/datasets/boolq.py
Normal file
@ -0,0 +1,41 @@
|
||||
import json
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class BoolQDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(**kwargs):
|
||||
|
||||
dataset = load_dataset(**kwargs)
|
||||
|
||||
def preprocess(example):
|
||||
if example['label'] == 'true':
|
||||
example['answer'] = 1
|
||||
else:
|
||||
example['answer'] = 0
|
||||
|
||||
return example
|
||||
|
||||
dataset = dataset.map(preprocess)
|
||||
return dataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class BoolQDataset_V2(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path):
|
||||
dataset = []
|
||||
with open(path, 'r') as f:
|
||||
for line in f:
|
||||
line = json.loads(line)
|
||||
line['label'] = {'true': 'A', 'false': 'B'}[line['label']]
|
||||
dataset.append(line)
|
||||
return Dataset.from_list(dataset)
|
21
opencompass/datasets/bustum.py
Normal file
21
opencompass/datasets/bustum.py
Normal file
@ -0,0 +1,21 @@
|
||||
import json
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class bustumDataset_V2(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path):
|
||||
data = []
|
||||
with open(path, 'r') as f:
|
||||
for line in f:
|
||||
line = json.loads(line)
|
||||
line['label'] = 'AB'[int(line['label'])]
|
||||
data.append(line)
|
||||
return Dataset.from_list(data)
|
40
opencompass/datasets/cmrc.py
Normal file
40
opencompass/datasets/cmrc.py
Normal file
@ -0,0 +1,40 @@
|
||||
import json
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class CMRCDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str):
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
# 将原始数据转换为所需的格式
|
||||
rows = []
|
||||
for index, paragraphs in enumerate(data['data']):
|
||||
for paragraph in paragraphs['paragraphs']:
|
||||
|
||||
context = paragraph['context']
|
||||
|
||||
for question in paragraph['qas']:
|
||||
answers = question['answers']
|
||||
unique_answers = list(set([a['text'] for a in answers]))
|
||||
rows.append({
|
||||
'context': context,
|
||||
'question': question['question'],
|
||||
'answers': unique_answers
|
||||
})
|
||||
|
||||
# 创建 Dataset
|
||||
dataset = Dataset.from_dict({
|
||||
'context': [row['context'] for row in rows],
|
||||
'question': [row['question'] for row in rows],
|
||||
'answers': [row['answers'] for row in rows]
|
||||
})
|
||||
|
||||
return dataset
|
37
opencompass/datasets/govrepcrs.py
Normal file
37
opencompass/datasets/govrepcrs.py
Normal file
@ -0,0 +1,37 @@
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class GovRepcrsDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str):
|
||||
import json
|
||||
import os
|
||||
dataset_dict = DatasetDict()
|
||||
splits = ['train', 'valid', 'test']
|
||||
dataset_lists = {x: [] for x in splits}
|
||||
for split in splits:
|
||||
split_fp = os.path.join(path, 'gov-report', 'split_ids',
|
||||
'crs_' + split + '.ids')
|
||||
with open(split_fp, 'r') as f:
|
||||
for line in f.readlines():
|
||||
xpath = os.path.join(path, 'gov-report', 'crs',
|
||||
line.strip() + '.json')
|
||||
with open(xpath, 'r') as df:
|
||||
data = json.load(df)
|
||||
content = data['title'] + '\n' + '\n'.join(
|
||||
[(x['section_title'] if x['section_title'] else '')
|
||||
+ '\n' + '\n'.join(x['paragraphs'])
|
||||
for x in data['reports']['subsections']])
|
||||
summary = '\n'.join(data['summary'])
|
||||
dataset_lists[split].append({
|
||||
'content': content,
|
||||
'summary': summary,
|
||||
})
|
||||
dataset_dict[split] = Dataset.from_list(dataset_lists[split])
|
||||
return dataset_dict
|
16
opencompass/datasets/iwslt2017.py
Normal file
16
opencompass/datasets/iwslt2017.py
Normal file
@ -0,0 +1,16 @@
|
||||
from datasets import load_dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class IWSLT2017Dataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(**kwargs):
|
||||
dataset = load_dataset(**kwargs)
|
||||
dataset = dataset.map(lambda example: example['translation']
|
||||
).remove_columns('translation')
|
||||
return dataset
|
43
opencompass/datasets/narrativeqa.py
Normal file
43
opencompass/datasets/narrativeqa.py
Normal file
@ -0,0 +1,43 @@
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class NarrativeQADataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str):
|
||||
import csv
|
||||
import os
|
||||
dataset_dict = DatasetDict()
|
||||
splits = ['train', 'valid', 'test']
|
||||
dataset_lists = {x: [] for x in splits}
|
||||
with open(os.path.join(path, 'qaps.csv'), 'r') as f:
|
||||
reader = csv.reader(f, delimiter=',')
|
||||
for row in reader:
|
||||
if row[1] == 'set':
|
||||
continue
|
||||
split = row[1] # set
|
||||
answers = [row[3], row[4]] # row['answer1'], row['answer2']
|
||||
question = row[2] # question
|
||||
x_path = os.path.join(path, 'tmp',
|
||||
row[0] + '.content') # document_id
|
||||
|
||||
try:
|
||||
with open(x_path, 'r', encoding='utf-8') as f:
|
||||
evidence = f.read(100000)
|
||||
except: # noqa: E722
|
||||
continue
|
||||
dataset_lists[split].append({
|
||||
'answer': answers,
|
||||
'question': question,
|
||||
'evidence': evidence,
|
||||
})
|
||||
|
||||
for split in splits:
|
||||
dataset_dict[split] = Dataset.from_list(dataset_lists[split])
|
||||
|
||||
return dataset_dict
|
58
opencompass/datasets/wsc.py
Normal file
58
opencompass/datasets/wsc.py
Normal file
@ -0,0 +1,58 @@
|
||||
import json
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class WSCDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(**kwargs):
|
||||
|
||||
dataset = load_dataset(**kwargs)
|
||||
|
||||
def preprocess(example):
|
||||
text_list = example['text'].split(' ')
|
||||
assert ' ' not in example['target']['span2_text']
|
||||
# span1 may have 1 or more than 1 words
|
||||
# span2 is the pronoun and has only 1 word
|
||||
text_list[example['target']
|
||||
['span2_index']] = example['target']['span1_text']
|
||||
example['new_text'] = ' '.join(text_list)
|
||||
if example['label'] == 'true':
|
||||
example['answer'] = 1
|
||||
else:
|
||||
example['answer'] = 0
|
||||
example['span1'] = example['target']['span1_text']
|
||||
example['span2'] = example['target']['span2_text']
|
||||
del example['target']
|
||||
return example
|
||||
|
||||
dataset = dataset.map(preprocess)
|
||||
return dataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class WSCDataset_V2(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path):
|
||||
data = []
|
||||
with open(path, 'r') as f:
|
||||
for line in f:
|
||||
line = json.loads(line)
|
||||
item = {
|
||||
'span1': line['target']['span1_text'],
|
||||
'span2': line['target']['span2_text'],
|
||||
'text': line['text'],
|
||||
'label': {
|
||||
'true': 'A',
|
||||
'false': 'B'
|
||||
}[line['label']],
|
||||
}
|
||||
data.append(item)
|
||||
return Dataset.from_list(data)
|
154
opencompass/models/openai_api.py
Normal file
154
opencompass/models/openai_api.py
Normal file
@ -0,0 +1,154 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from opencompass.registry import MODELS
|
||||
from opencompass.utils.prompt import PromptList
|
||||
|
||||
from .base_api import BaseAPIModel
|
||||
|
||||
PromptType = Union[PromptList, str]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OpenAI(BaseAPIModel):
|
||||
"""Model wrapper around OpenAI's models.
|
||||
|
||||
Args:
|
||||
path (str): The name of OpenAI's model.
|
||||
max_seq_len (int): The maximum allowed sequence length of a model.
|
||||
Note that the length of prompt + generated tokens shall not exceed
|
||||
this value. Defaults to 2048.
|
||||
query_per_second (int): The maximum queries allowed per second
|
||||
between two consecutive calls of the API. Defaults to 1.
|
||||
retry (int): Number of retires if the API call fails. Defaults to 2.
|
||||
key (str): OpenAI key. In particular, when it is set to "ENV", the key
|
||||
will be fetched from the environment variable $OPENAI_API_KEY, as
|
||||
how openai defaults to be. Defaults to 'ENV'
|
||||
meta_template (Dict, optional): The model's meta prompt
|
||||
template if needed, in case the requirement of injecting or
|
||||
wrapping of any meta instructions.
|
||||
openai_api_base (str): The base url of OpenAI's API. Defaults to
|
||||
'https://api.openai.com/v1'.
|
||||
"""
|
||||
|
||||
is_api: bool = True
|
||||
|
||||
def __init__(self,
|
||||
path: str,
|
||||
max_seq_len: int = 2048,
|
||||
query_per_second: int = 1,
|
||||
retry: int = 2,
|
||||
key: str = 'ENV',
|
||||
meta_template: Optional[Dict] = None,
|
||||
openai_api_base: str = 'https://api.openai.com/v1'):
|
||||
super().__init__(path=path,
|
||||
max_seq_len=max_seq_len,
|
||||
meta_template=meta_template,
|
||||
query_per_second=query_per_second,
|
||||
retry=retry)
|
||||
import openai
|
||||
import tiktoken
|
||||
self.openai = openai
|
||||
self.tiktoken = tiktoken
|
||||
|
||||
self.openai.api_key = os.getenv(
|
||||
'OPENAI_API_KEY') if key == 'ENV' else key
|
||||
self.openai.api_rase = openai_api_base
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs: List[str or PromptList],
|
||||
max_out_len: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> List[str]:
|
||||
"""Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs (List[str or PromptList]): A list of strings or PromptDicts.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
temperature (float): What sampling temperature to use,
|
||||
between 0 and 2. Higher values like 0.8 will make the output
|
||||
more random, while lower values like 0.2 will make it more
|
||||
focused and deterministic. Defaults to 0.7.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = list(
|
||||
executor.map(self._generate, inputs,
|
||||
[max_out_len] * len(inputs),
|
||||
[temperature] * len(inputs)))
|
||||
return results
|
||||
|
||||
def _generate(self, input: str or PromptList, max_out_len: int,
|
||||
temperature: float) -> str:
|
||||
"""Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs (str or PromptList): A string or PromptDict.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
temperature (float): What sampling temperature to use,
|
||||
between 0 and 2. Higher values like 0.8 will make the output
|
||||
more random, while lower values like 0.2 will make it more
|
||||
focused and deterministic.
|
||||
|
||||
Returns:
|
||||
str: The generated string.
|
||||
"""
|
||||
assert isinstance(input, (str, PromptList))
|
||||
|
||||
# max num token for gpt-3.5-turbo is 4097
|
||||
max_out_len = min(max_out_len, 4000 - self.get_token_len(str(input)))
|
||||
|
||||
if isinstance(input, str):
|
||||
messages = [{'role': 'user', 'content': input}]
|
||||
else:
|
||||
messages = []
|
||||
for item in input:
|
||||
msg = {'content': item['prompt']}
|
||||
if item['role'] == 'HUMAN':
|
||||
msg['role'] = 'user'
|
||||
elif item['role'] == 'BOT':
|
||||
msg['role'] = 'assistant'
|
||||
elif item['role'] == 'SYSTEM':
|
||||
msg['role'] = 'system'
|
||||
messages.append(msg)
|
||||
|
||||
max_num_retries = 0
|
||||
while max_num_retries < self.retry:
|
||||
self.wait()
|
||||
try:
|
||||
response = self.openai.ChatCompletion.create(
|
||||
model=self.path,
|
||||
messages=messages,
|
||||
max_tokens=max_out_len,
|
||||
n=1,
|
||||
stop=None,
|
||||
temperature=temperature,
|
||||
)
|
||||
except self.openai.error.RateLimitError:
|
||||
max_num_retries -= 1
|
||||
max_num_retries += 1
|
||||
|
||||
result = response.choices[0].message.content.strip()
|
||||
return result
|
||||
|
||||
def get_token_len(self, prompt: str) -> int:
|
||||
"""Get lengths of the tokenized string. Only English and Chinese
|
||||
characters are counted for now. Users are encouraged to override this
|
||||
method if more accurate length is needed.
|
||||
|
||||
Args:
|
||||
prompt (str): Input string.
|
||||
|
||||
Returns:
|
||||
int: Length of the input tokens
|
||||
"""
|
||||
enc = self.tiktoken.encoding_for_model(self.path)
|
||||
return len(enc.encode(prompt))
|
162
opencompass/openicl/icl_inferencer/icl_base_inferencer.py
Normal file
162
opencompass/openicl/icl_inferencer/icl_base_inferencer.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""Basic Inferencer."""
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dist import is_main_process
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ..icl_prompt_template import PromptTemplate
|
||||
from ..icl_retriever import BaseRetriever
|
||||
|
||||
|
||||
class BaseInferencer:
|
||||
"""Base Inferencer class for all evaluation Inferencer.
|
||||
|
||||
Attributes:
|
||||
model (:obj:`BaseModel`, optional): The module to inference.
|
||||
max_model_token_num (:obj:`int`, optional): Maximum number of
|
||||
tokenized words allowed by the LM.
|
||||
batch_size (:obj:`int`, optional): Batch size for the
|
||||
:obj:`DataLoader`.
|
||||
output_json_filepath (:obj:`str`, optional): File path for output
|
||||
`JSON` file.
|
||||
output_json_filename (:obj:`str`, optional): File name for output
|
||||
`JSON` file.
|
||||
api_name (:obj:`str`, optional): Name of API service.
|
||||
call_api (:obj:`bool`): If ``True``, an API for LM models will be used,
|
||||
determined by :obj:`api_name`.
|
||||
"""
|
||||
model = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_seq_len: Optional[int] = None,
|
||||
batch_size: Optional[int] = 1,
|
||||
output_json_filepath: Optional[str] = './icl_inference_output',
|
||||
output_json_filename: Optional[str] = 'predictions',
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.model = model
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.batch_size = batch_size
|
||||
self.output_json_filepath = output_json_filepath
|
||||
self.output_json_filename = output_json_filename
|
||||
self.is_main_process = is_main_process()
|
||||
if not os.path.exists(self.output_json_filepath):
|
||||
os.makedirs(self.output_json_filepath)
|
||||
|
||||
def inference(self,
|
||||
retriever: BaseRetriever,
|
||||
ice_template: Optional[PromptTemplate] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None,
|
||||
output_json_filepath: Optional[str] = None,
|
||||
output_json_filename: Optional[str] = None) -> List:
|
||||
"""Perform In-Context Inference given a retriever and optional
|
||||
templates.
|
||||
|
||||
Args:
|
||||
retriever (:obj:`BaseRetriever`): An instance of a Retriever class
|
||||
that will be used to retrieve in-context examples
|
||||
ice_template (:obj:`PromptTemplate`, optional): A template for
|
||||
generating the in-context examples prompt. Defaults to None.
|
||||
prompt_template (:obj:`PromptTemplate`, optional): A template for
|
||||
generating the final prompt. Defaults to None.
|
||||
output_json_filepath (:obj:`str`, optional): The file path to save
|
||||
the results as a `JSON` file. Defaults to None.
|
||||
output_json_filename (:obj:`str`, optional): The file name to save
|
||||
the results as a `JSON` file. Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the function is not implemented in the
|
||||
subclass.
|
||||
|
||||
Returns:
|
||||
:obj:`List:` A list of string, each representing the results of one
|
||||
inference.
|
||||
"""
|
||||
raise NotImplementedError("Method hasn't been implemented yet")
|
||||
|
||||
@staticmethod
|
||||
def get_dataloader(datalist: List[List], batch_size: int) -> DataLoader:
|
||||
"""Return a dataloader of the input data list."""
|
||||
dataloader = DataLoader(datalist,
|
||||
batch_size=batch_size,
|
||||
collate_fn=lambda x: x)
|
||||
return dataloader
|
||||
|
||||
|
||||
def dump_results_dict(results_dict, filename):
|
||||
with open(filename, 'w', encoding='utf-8') as json_file:
|
||||
json.dump(results_dict, json_file, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
class GenInferencerOutputHandler:
|
||||
origin_prompt_dict = {}
|
||||
output_dict = {}
|
||||
prediction_dict = {}
|
||||
results_dict = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.results_dict = {}
|
||||
|
||||
def write_to_json(self, save_dir: str, filename: str):
|
||||
"""Dump the result to a json file."""
|
||||
dump_results_dict(self.results_dict, Path(save_dir) / filename)
|
||||
|
||||
def save_results(self, origin_prompt, prediction, idx):
|
||||
self.results_dict[str(idx)] = {
|
||||
'origin_prompt': origin_prompt,
|
||||
'prediction': prediction,
|
||||
}
|
||||
|
||||
|
||||
class PPLInferencerOutputHandler:
|
||||
results_dict = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.results_dict = {}
|
||||
|
||||
def write_to_json(self, save_dir: str, filename: str):
|
||||
"""Dump the result to a json file."""
|
||||
dump_results_dict(self.results_dict, Path(save_dir) / filename)
|
||||
|
||||
def save_ice(self, ice):
|
||||
for idx, example in enumerate(ice):
|
||||
if str(idx) not in self.results_dict.keys():
|
||||
self.results_dict[str(idx)] = {}
|
||||
self.results_dict[str(idx)]['in-context examples'] = example
|
||||
|
||||
def save_predictions(self, predictions):
|
||||
for idx, prediction in enumerate(predictions):
|
||||
if str(idx) not in self.results_dict.keys():
|
||||
self.results_dict[str(idx)] = {}
|
||||
self.results_dict[str(idx)]['prediction'] = prediction
|
||||
|
||||
def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
|
||||
if str(idx) not in self.results_dict.keys():
|
||||
self.results_dict[str(idx)] = {}
|
||||
if 'label: ' + str(label) not in self.results_dict[str(idx)].keys():
|
||||
self.results_dict[str(idx)]['label: ' + str(label)] = {}
|
||||
self.results_dict[str(idx)]['label: ' +
|
||||
str(label)]['testing input'] = input
|
||||
self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
|
||||
self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
|
||||
|
||||
def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
|
||||
if str(idx) not in self.results_dict.keys():
|
||||
self.results_dict[str(idx)] = {}
|
||||
# TODO:
|
||||
# for single token situation, the input will always be yes currently
|
||||
self.results_dict[str(idx)]['testing input'] = input
|
||||
self.results_dict[str(idx)]['prompt'] = prompt
|
||||
# TODO: hard code here
|
||||
self.results_dict[str(idx)]['choices'] = choices
|
||||
# For calculate auc scores, set scores as prediction
|
||||
self.results_dict[str(idx)]['prediction'] = cond_prob
|
||||
# set pred label in case needed
|
||||
self.results_dict[str(idx)]['pred_label'] = int(np.argmax(cond_prob))
|
186
opencompass/openicl/icl_retriever/icl_mdl_retriever.py
Normal file
186
opencompass/openicl/icl_retriever/icl_mdl_retriever.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""MDL Retriever."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from opencompass.openicl import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever.icl_topk_retriever import TopkRetriever
|
||||
from opencompass.openicl.utils.logging import get_logger
|
||||
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ICL_RETRIEVERS.register_module()
|
||||
class MDLRetriever(TopkRetriever):
|
||||
"""MDL Retriever, subclass of `TopkRetriever`. MDL is a abbreviation of
|
||||
Minimum Description Length, specially designed for ppl evaluation. You may
|
||||
refer to the paper for more details: https://arxiv.org/pdf/2212.10375.pdf.
|
||||
|
||||
Args:
|
||||
dataset (`BaseDataset`): Any BaseDataset instances.
|
||||
Attributes of ``reader``, ``train`` and ``test`` will be used.
|
||||
ice_separator (`Optional[str]`): The separator between each in-context
|
||||
example template when origin `PromptTemplate` is provided. Defaults
|
||||
to '\n'.
|
||||
ice_eos_token (`Optional[str]`): The end of sentence token for
|
||||
in-context example template when origin `PromptTemplate` is
|
||||
provided. Defaults to '\n'.
|
||||
ice_num (`Optional[int]`): The number of in-context example template
|
||||
when origin `PromptTemplate` is provided. Defaults to 1.
|
||||
sentence_transformers_model_name (`Optional[str]`): The name of the
|
||||
sentence transformers model. Defaults to 'all-mpnet-base-v2'.
|
||||
tokenizer_name (`Optional[str]`): The name of the tokenizer. Defaults
|
||||
to 'gpt2-xl'.
|
||||
batch_size (`Optional[int]`): The batch size for the dataloader.
|
||||
Defaults to 1.
|
||||
candidate_num (`Optional[int]`): The number of candidates to retrieve
|
||||
for each example. Defaults to 1.
|
||||
ce_model_name (`Optional[str]`): The name of the model for calculating
|
||||
MDL. Defaults to 'gpt2-xl'.
|
||||
select_time (`Optional[int]`): The number of times to select MDL.
|
||||
Defaults to 5.
|
||||
ice_template (`Optional[PromptTemplate]`): The template for in-context
|
||||
example. Defaults to None.
|
||||
prompt_template (`Optional[PromptTemplate]`): The template for prompt.
|
||||
Defaults to None.
|
||||
labels (`Optional[List]`): The labels for calculating MDL. Defaults to
|
||||
None.
|
||||
seed (`Optional[int]`): The seed for random. Defaults to 1.
|
||||
"""
|
||||
metric_model = None
|
||||
|
||||
def __init__(self,
|
||||
dataset,
|
||||
ice_separator: Optional[str] = '\n',
|
||||
ice_eos_token: Optional[str] = '\n',
|
||||
ice_num: Optional[int] = 1,
|
||||
sentence_transformers_model_name: Optional[
|
||||
str] = 'all-mpnet-base-v2',
|
||||
tokenizer_name: Optional[str] = 'gpt2-xl',
|
||||
batch_size: Optional[int] = 1,
|
||||
candidate_num: Optional[int] = 1,
|
||||
ce_model_name: Optional[str] = 'gpt2-xl',
|
||||
select_time: Optional[int] = 5,
|
||||
ice_template: Optional[PromptTemplate] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None,
|
||||
labels: Optional[List] = None,
|
||||
seed: Optional[int] = 1) -> None:
|
||||
super().__init__(dataset, ice_separator, ice_eos_token, ice_num,
|
||||
sentence_transformers_model_name, tokenizer_name,
|
||||
batch_size)
|
||||
self.ce_model_name = ce_model_name
|
||||
self.candidate_num = candidate_num
|
||||
self.select_time = select_time
|
||||
self.ice_template = ICL_PROMPT_TEMPLATES.build(ice_template)
|
||||
if prompt_template is not None:
|
||||
self.prompt_template = ICL_PROMPT_TEMPLATES.build(prompt_template)
|
||||
else:
|
||||
self.prompt_template = None
|
||||
self.labels = labels
|
||||
self.seed = seed
|
||||
|
||||
def topk_search(self):
|
||||
np.random.seed(self.seed)
|
||||
res_list = self.forward(self.dataloader)
|
||||
rtr_idx_list = [[] for _ in range(len(res_list))]
|
||||
|
||||
logger.info('Retrieving data for test set...')
|
||||
for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
|
||||
idx = entry['metadata']['id']
|
||||
embed = np.expand_dims(entry['embed'], axis=0)
|
||||
near_ids = self.index.search(
|
||||
embed, min(self.candidate_num,
|
||||
len(self.index_ds)))[1][0].tolist()
|
||||
candidates = []
|
||||
mdl_scores = []
|
||||
for j in range(self.select_time):
|
||||
if j == 0:
|
||||
rand_idx_list = near_ids[:self.ice_num]
|
||||
else:
|
||||
rand_idx_list = np.random.choice(near_ids,
|
||||
self.ice_num,
|
||||
replace=False)
|
||||
rand_idx_list = [int(i) for i in rand_idx_list]
|
||||
candidates.append(rand_idx_list)
|
||||
|
||||
ice = self.generate_ice(rand_idx_list,
|
||||
ice_template=self.ice_template)
|
||||
ice = str(ice)
|
||||
mask_length = len(
|
||||
self.tokenizer(ice + self.ice_eos_token,
|
||||
verbose=False)['input_ids'])
|
||||
if self.labels is None:
|
||||
labels = self.get_labels(self.ice_template,
|
||||
self.prompt_template)
|
||||
else:
|
||||
labels = self.labels
|
||||
prompt_list = []
|
||||
for label in labels:
|
||||
prompt = self.generate_label_prompt(
|
||||
idx, ice, label, self.ice_template,
|
||||
self.prompt_template)
|
||||
prompt = str(prompt)
|
||||
prompt_list.append(prompt)
|
||||
loss_list = self.cal_ce(prompt_list, mask_length=mask_length)
|
||||
probs = np.exp(-np.array(loss_list))
|
||||
normalized_probs = probs / probs.sum(0, keepdims=True)
|
||||
neg_entropy = -entropy(normalized_probs, label_dim=0)
|
||||
mdl_scores.append(neg_entropy)
|
||||
|
||||
rtr_idx_list[idx] = candidates[mdl_scores.index(max(mdl_scores))]
|
||||
rtr_idx_list[idx] = [int(i) for i in rtr_idx_list[idx]]
|
||||
|
||||
return rtr_idx_list
|
||||
|
||||
def retrieve(self):
|
||||
"""Retrieve the in-context example index for each test example."""
|
||||
return self.topk_search()
|
||||
|
||||
def cal_ce(self, input_texts: List[str], mask_length=None):
|
||||
if self.metric_model is None:
|
||||
logger.info(
|
||||
f'Load model {self.ce_model_name} for calculating MDL...')
|
||||
self.metric_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.ce_model_name)
|
||||
self.metric_model.to(self.device)
|
||||
inputs = self.tokenizer(input_texts,
|
||||
padding=True,
|
||||
return_tensors='pt',
|
||||
truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
outputs = self.metric_model(**inputs)
|
||||
|
||||
shift_logits = outputs.logits[..., :-1, :].contiguous()
|
||||
shift_labels = inputs['input_ids'][..., 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss(
|
||||
reduction='none', ignore_index=self.tokenizer.pad_token_id)
|
||||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
loss = loss_fct(shift_logits,
|
||||
shift_labels.view(-1)).view(shift_labels.size())
|
||||
if mask_length is not None:
|
||||
mask = torch.cat([
|
||||
torch.zeros([loss.shape[0], mask_length], dtype=torch.float),
|
||||
torch.ones([loss.shape[0], loss.shape[-1] - mask_length],
|
||||
dtype=torch.float)
|
||||
], -1)
|
||||
mask = mask.to(self.device)
|
||||
loss = torch.mul(mask, loss)
|
||||
|
||||
lens = (inputs['input_ids'] !=
|
||||
self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
|
||||
if mask_length is not None:
|
||||
lens -= mask_length
|
||||
ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
|
||||
return ce_loss
|
||||
|
||||
|
||||
def entropy(probs: np.array, label_dim: int = 0, mask=None):
|
||||
if mask is None:
|
||||
return -(probs * np.log(probs)).sum(label_dim)
|
||||
return -(mask * probs * np.log(probs)).sum(label_dim)
|
99
opencompass/openicl/icl_retriever/icl_votek_retriever.py
Normal file
99
opencompass/openicl/icl_retriever/icl_votek_retriever.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""Votek Retriever."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
from opencompass.openicl.icl_retriever.icl_topk_retriever import TopkRetriever
|
||||
|
||||
|
||||
class VotekRetriever(TopkRetriever):
|
||||
"""Vote-k In-context Learning Retriever, subclass of `TopkRetriever`.
|
||||
|
||||
**WARNING**: This class has not been tested thoroughly. Please use it with
|
||||
caution.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset,
|
||||
ice_separator: Optional[str] = '\n',
|
||||
ice_eos_token: Optional[str] = '\n',
|
||||
ice_num: Optional[int] = 1,
|
||||
sentence_transformers_model_name: Optional[
|
||||
str] = 'all-mpnet-base-v2',
|
||||
tokenizer_name: Optional[str] = 'gpt2-xl',
|
||||
batch_size: Optional[int] = 1,
|
||||
votek_k: Optional[int] = 3) -> None:
|
||||
super().__init__(dataset, ice_separator, ice_eos_token, ice_num,
|
||||
sentence_transformers_model_name, tokenizer_name,
|
||||
batch_size)
|
||||
self.votek_k = votek_k
|
||||
|
||||
def votek_select(self,
|
||||
embeddings=None,
|
||||
select_num=None,
|
||||
k=None,
|
||||
overlap_threshold=None,
|
||||
vote_file=None):
|
||||
n = len(embeddings)
|
||||
if vote_file is not None and os.path.isfile(vote_file):
|
||||
with open(vote_file) as f:
|
||||
vote_stat = json.load(f)
|
||||
else:
|
||||
vote_stat = defaultdict(list)
|
||||
|
||||
for i in range(n):
|
||||
cur_emb = embeddings[i].reshape(1, -1)
|
||||
cur_scores = np.sum(cosine_similarity(embeddings, cur_emb),
|
||||
axis=1)
|
||||
sorted_indices = np.argsort(cur_scores).tolist()[-k - 1:-1]
|
||||
for idx in sorted_indices:
|
||||
if idx != i:
|
||||
vote_stat[idx].append(i)
|
||||
|
||||
if vote_file is not None:
|
||||
with open(vote_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(vote_stat, f)
|
||||
votes = sorted(vote_stat.items(),
|
||||
key=lambda x: len(x[1]),
|
||||
reverse=True)
|
||||
j = 0
|
||||
selected_indices = []
|
||||
while len(selected_indices) < select_num and j < len(votes):
|
||||
candidate_set = set(votes[j][1])
|
||||
flag = True
|
||||
for pre in range(j):
|
||||
cur_set = set(votes[pre][1])
|
||||
if len(candidate_set.intersection(
|
||||
cur_set)) >= overlap_threshold * len(candidate_set):
|
||||
flag = False
|
||||
break
|
||||
if not flag:
|
||||
j += 1
|
||||
continue
|
||||
selected_indices.append(int(votes[j][0]))
|
||||
j += 1
|
||||
if len(selected_indices) < select_num:
|
||||
unselected_indices = []
|
||||
cur_num = len(selected_indices)
|
||||
for i in range(n):
|
||||
if i not in selected_indices:
|
||||
unselected_indices.append(i)
|
||||
selected_indices += random.sample(unselected_indices,
|
||||
select_num - cur_num)
|
||||
return selected_indices
|
||||
|
||||
def vote_k_search(self):
|
||||
vote_k_idxs = self.votek_select(embeddings=self.embed_list,
|
||||
select_num=self.ice_num,
|
||||
k=self.votek_k,
|
||||
overlap_threshold=1)
|
||||
return [vote_k_idxs[:] for _ in range(len(self.test_ds))]
|
||||
|
||||
def retrieve(self):
|
||||
return self.vote_k_search()
|
104
opencompass/openicl/utils/api_service.py
Normal file
104
opencompass/openicl/utils/api_service.py
Normal file
@ -0,0 +1,104 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
|
||||
OPENICL_API_NAME_LIST = ['opt-175b', 'gpt3']
|
||||
OPENICL_API_PARAMETER_DICT = {
|
||||
'opt-175b': ['URL', 'headers'],
|
||||
'gpt3': [
|
||||
'engine', 'temperature', 'max_tokens', 'top_p', 'frequency_penalty',
|
||||
'presence_penalty', 'sleep_time'
|
||||
]
|
||||
}
|
||||
OPENICL_API_REQUEST_CONFIG = {
|
||||
'opt-175b': {
|
||||
'URL': '', # http://xxx/completions or http://xxx/generate
|
||||
'headers': {
|
||||
'Content-Type': 'application/json; charset=UTF-8'
|
||||
}
|
||||
},
|
||||
'gpt3': {
|
||||
'engine': 'text-davinci-003',
|
||||
'temperature': 0,
|
||||
'max_tokens': 256,
|
||||
'top_p': 1.0,
|
||||
'frequency_penalty': 0.0,
|
||||
'presence_penalty': 0.0,
|
||||
'sleep_time': 3
|
||||
}
|
||||
}
|
||||
PROXIES = {'https': '', 'http': ''}
|
||||
|
||||
|
||||
def is_api_available(api_name):
|
||||
if api_name is None:
|
||||
return False
|
||||
return True if api_name in OPENICL_API_NAME_LIST else False
|
||||
|
||||
|
||||
def update_openicl_api_request_config(api_name, **kwargs):
|
||||
if api_name is None or not is_api_available(api_name):
|
||||
return
|
||||
|
||||
parameter_list = OPENICL_API_PARAMETER_DICT[api_name]
|
||||
for parameter in parameter_list:
|
||||
if parameter in kwargs.keys():
|
||||
OPENICL_API_REQUEST_CONFIG[api_name][parameter] = kwargs[parameter]
|
||||
|
||||
|
||||
def api_get_ppl(api_name, input_texts):
|
||||
if api_name == 'opt-175b':
|
||||
pyload = {'prompt': input_texts, 'max_tokens': 0, 'echo': True}
|
||||
response = json.loads(
|
||||
requests.post(
|
||||
OPENICL_API_REQUEST_CONFIG[api_name]['URL'],
|
||||
data=json.dumps(pyload),
|
||||
headers=OPENICL_API_REQUEST_CONFIG[api_name]['headers'],
|
||||
proxies=PROXIES).text)
|
||||
lens = np.array(
|
||||
[len(r['logprobs']['tokens']) for r in response['choices']])
|
||||
ce_loss = np.array([
|
||||
-sum(r['logprobs']['token_logprobs']) for r in response['choices']
|
||||
])
|
||||
return ce_loss / lens
|
||||
|
||||
if api_name == 'gpt3':
|
||||
raise NotImplementedError("GPT-3 API doesn't support PPL calculation")
|
||||
|
||||
|
||||
def api_get_tokens(api_name, input_texts):
|
||||
length_list = [len(text) for text in input_texts]
|
||||
|
||||
if api_name == 'opt-175b':
|
||||
pyload = {'prompt': input_texts, 'max_tokens': 100, 'echo': True}
|
||||
response = json.loads(
|
||||
requests.post(
|
||||
OPENICL_API_REQUEST_CONFIG[api_name]['URL'],
|
||||
data=json.dumps(pyload),
|
||||
headers=OPENICL_API_REQUEST_CONFIG[api_name]['headers'],
|
||||
proxies=PROXIES).text)
|
||||
return [r['text'] for r in response['choices']], [
|
||||
r['text'][length:]
|
||||
for r, length in zip(response['choices'], length_list)
|
||||
]
|
||||
|
||||
if api_name == 'gpt3':
|
||||
openai.api_key = os.getenv('OPENAI_API_KEY')
|
||||
response = openai.Completion.create(
|
||||
engine=OPENICL_API_REQUEST_CONFIG['gpt3']['engine'],
|
||||
prompt=input_texts,
|
||||
temperature=OPENICL_API_REQUEST_CONFIG['gpt3']['temperature'],
|
||||
max_tokens=OPENICL_API_REQUEST_CONFIG['gpt3']['max_tokens'],
|
||||
top_p=OPENICL_API_REQUEST_CONFIG['gpt3']['top_p'],
|
||||
frequency_penalty=OPENICL_API_REQUEST_CONFIG['gpt3']
|
||||
['frequency_penalty'],
|
||||
presence_penalty=OPENICL_API_REQUEST_CONFIG['gpt3']
|
||||
['presence_penalty'])
|
||||
time.sleep(OPENICL_API_REQUEST_CONFIG['gpt3']['sleep_time'])
|
||||
return [(input + r['text'])
|
||||
for r, input in zip(response['choices'], input_texts)
|
||||
], [r['text'] for r in response['choices']]
|
60
opencompass/partitioners/naive.py
Normal file
60
opencompass/partitioners/naive.py
Normal file
@ -0,0 +1,60 @@
|
||||
import os.path as osp
|
||||
from typing import Dict, List
|
||||
|
||||
from mmengine.config import Config, ConfigDict
|
||||
|
||||
from opencompass.registry import PARTITIONERS
|
||||
from opencompass.utils import get_infer_output_path
|
||||
|
||||
from .base import BasePartitioner
|
||||
|
||||
|
||||
@PARTITIONERS.register_module()
|
||||
class NaivePartitioner(BasePartitioner):
|
||||
"""Naive task partitioner. This partitioner will generate a task for each
|
||||
model-dataset pair.
|
||||
|
||||
Args:
|
||||
config (ConfigDict): The full config dict.
|
||||
"""
|
||||
|
||||
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
|
||||
work_dir: str, out_dir: str) -> List[Dict]:
|
||||
"""Partition model-dataset pairs into tasks. Each task is defined as a
|
||||
dict and will run independently as a unit. Its structure is as
|
||||
follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'models': [], # a list of model configs
|
||||
'datasets': [[]], # a nested list of dataset configs, each
|
||||
list corresponds to a model
|
||||
'work_dir': '', # the work dir
|
||||
}
|
||||
|
||||
Args:
|
||||
models (List[ConfigDict]): A list of model configs.
|
||||
datasets (List[ConfigDict]): A list of dataset configs.
|
||||
work_dir (str): The work dir for the task.
|
||||
out_dir (str): The full output path for the task, intended for
|
||||
Partitioners to check whether the task is finished via the
|
||||
existency of result file in this directory.
|
||||
|
||||
Returns:
|
||||
List[Dict]: A list of tasks.
|
||||
"""
|
||||
|
||||
tasks = []
|
||||
for model in models:
|
||||
for dataset in datasets:
|
||||
filename = get_infer_output_path(model, dataset, out_dir)
|
||||
if osp.exists(filename):
|
||||
continue
|
||||
task = Config({
|
||||
'models': [model],
|
||||
'datasets': [[dataset]],
|
||||
'work_dir': work_dir
|
||||
})
|
||||
tasks.append(task)
|
||||
return tasks
|
187
opencompass/partitioners/size.py
Normal file
187
opencompass/partitioners/size.py
Normal file
@ -0,0 +1,187 @@
|
||||
import copy
|
||||
import math
|
||||
import os.path as osp
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
|
||||
from opencompass.registry import PARTITIONERS
|
||||
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||
get_infer_output_path)
|
||||
|
||||
from .base import BasePartitioner
|
||||
|
||||
|
||||
@PARTITIONERS.register_module()
|
||||
class SizePartitioner(BasePartitioner):
|
||||
"""Task partitioner based on the size of the dataset (with some rough
|
||||
expansion as an estimation of computational cost).
|
||||
|
||||
Args:
|
||||
out_dir (str): The output directory of tasks.
|
||||
max_task_size (int): The maximum size of a task.
|
||||
gen_task_coef (int): The dataset cost measurement coefficient for
|
||||
generation tasks.
|
||||
dataset_size_path (str): The path to the dataset size cache file.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_dir: str,
|
||||
max_task_size: int = 2000,
|
||||
gen_task_coef: int = 20,
|
||||
dataset_size_path: str = '.cache/dataset_size.json'):
|
||||
super().__init__(out_dir)
|
||||
self.max_task_size = max_task_size
|
||||
self.gen_task_coef = gen_task_coef
|
||||
self.dataset_size_path = dataset_size_path
|
||||
|
||||
def partition(self, models: List[ConfigDict], datasets: List[ConfigDict],
|
||||
work_dir: str, out_dir: str) -> List[ConfigDict]:
|
||||
"""Partition model-dataset pairs into tasks. Each task is defined as a
|
||||
dict and will run independently as a unit. Its structure is as
|
||||
follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'models': [], # a list of model configs
|
||||
'datasets': [[]], # a nested list of dataset configs, each
|
||||
list corresponds to a model
|
||||
'work_dir': '', # the work dir
|
||||
}
|
||||
|
||||
Args:
|
||||
models (List[ConfigDict]): A list of model configs.
|
||||
datasets (List[ConfigDict]): A list of dataset configs.
|
||||
work_dir (str): The work dir for the task.
|
||||
out_dir (str): The full output path for the task, intended for
|
||||
Partitioners to check whether the task is finished via the
|
||||
existency of result file in this directory.
|
||||
|
||||
Returns:
|
||||
List[ConfigDict]: A list of tasks.
|
||||
"""
|
||||
|
||||
datasets = sorted(datasets,
|
||||
key=lambda x: self.get_cost(x),
|
||||
reverse=True)
|
||||
tasks = []
|
||||
for model in models:
|
||||
task = Config({
|
||||
'models': [model],
|
||||
'datasets': [[]],
|
||||
'work_dir': work_dir
|
||||
})
|
||||
num_data = 0
|
||||
for dataset in datasets:
|
||||
filename = get_infer_output_path(model, dataset, out_dir)
|
||||
root, ext = osp.splitext(filename)
|
||||
# skip the task if the task output exists
|
||||
if osp.exists(filename):
|
||||
continue
|
||||
dataset_size = self.get_cost(dataset)
|
||||
if dataset_size > self.max_task_size:
|
||||
dataset_splits = self.split_dataset(dataset)
|
||||
for i, dataset_split in enumerate(dataset_splits):
|
||||
# skip the task it the task output exists
|
||||
if not osp.exists(f'{root}_{i}{ext}'):
|
||||
tasks.append(
|
||||
Config({
|
||||
'models': [model],
|
||||
'datasets': [[dataset_split]],
|
||||
'work_dir': work_dir
|
||||
}))
|
||||
else:
|
||||
if num_data + dataset_size > self.max_task_size:
|
||||
tasks.append(task)
|
||||
task = Config({
|
||||
'models': [model],
|
||||
'datasets': [[]],
|
||||
'work_dir': work_dir
|
||||
})
|
||||
num_data = 0
|
||||
task['datasets'][0].append(dataset)
|
||||
num_data = num_data + dataset_size
|
||||
if task['datasets'][0]:
|
||||
tasks.append(task)
|
||||
|
||||
return tasks
|
||||
|
||||
@property
|
||||
def dataset_size(self):
|
||||
if not hasattr(self, '_dataset_size'):
|
||||
if osp.exists(self.dataset_size_path):
|
||||
self._dataset_size = mmengine.load(self.dataset_size_path)
|
||||
else:
|
||||
self._dataset_size = {}
|
||||
return self._dataset_size
|
||||
|
||||
def split_dataset(self, dataset_cfg: ConfigDict) -> List[ConfigDict]:
|
||||
"""Split dataset into several parts."""
|
||||
dataset_size, num_repeats = self.get_cost(dataset_cfg,
|
||||
get_raw_factors=True)
|
||||
split_configs = []
|
||||
abbr = dataset_abbr_from_cfg(dataset_cfg)
|
||||
step = self.max_task_size // num_repeats
|
||||
# evenly distribute the task
|
||||
step = math.ceil(dataset_size / math.ceil(dataset_size / step))
|
||||
for part, i in enumerate(range(0, dataset_size, step)):
|
||||
cfg = copy.deepcopy(dataset_cfg)
|
||||
cfg['abbr'] = abbr + f'_{part}'
|
||||
test_range = cfg['reader_cfg'].get('test_range', '')
|
||||
cfg['reader_cfg']['test_range'] = f'{test_range}[{i}:{i+step}]'
|
||||
split_configs.append(cfg)
|
||||
return split_configs
|
||||
|
||||
def get_cost(self,
|
||||
dataset: ConfigDict,
|
||||
get_raw_factors: bool = False) -> Union[int, Tuple[int, int]]:
|
||||
"""Get the computational cost of inferring on the dataset.
|
||||
|
||||
Args:
|
||||
dataset (ConfigDict): The dataset config.
|
||||
get_raw_factors (bool): If True, the raw factors of computational
|
||||
cost will be returned.
|
||||
|
||||
Returns:
|
||||
int or Tuple[int, int]: The size of the dataset. If get_raw_factors
|
||||
is True, the number of repeats will also be returned.
|
||||
"""
|
||||
dataset_abbr = dataset_abbr_from_cfg(dataset)
|
||||
|
||||
# If it's the PPL template, the dataset size will be multiplied by the
|
||||
# number of labels
|
||||
infer_cfg = dataset.infer_cfg
|
||||
test_range = dataset.reader_cfg.get('test_range', '')
|
||||
template = (infer_cfg.prompt_template.template if 'prompt_template'
|
||||
in infer_cfg else infer_cfg.ice_template.template)
|
||||
# If it's the Gen template, the dataset size will be multiplied by the
|
||||
# self.gen_task_coef
|
||||
factor = self.gen_task_coef
|
||||
if isinstance(template, dict):
|
||||
ctr = sum(key in template for key in ('begin', 'round', 'end'))
|
||||
if ctr != len(template.keys()):
|
||||
factor = len(template.keys())
|
||||
|
||||
if dataset_abbr in self.dataset_size:
|
||||
actual_size = eval('len(range(self.dataset_size[dataset_abbr])'
|
||||
f'{test_range})')
|
||||
if get_raw_factors:
|
||||
return actual_size, factor
|
||||
return factor * actual_size
|
||||
|
||||
dataset = build_dataset_from_cfg(dataset)
|
||||
self.dataset_size[dataset_abbr] = len(dataset.test)
|
||||
|
||||
mmengine.mkdir_or_exist('.cache/')
|
||||
mmengine.dump(self.dataset_size,
|
||||
self.dataset_size_path,
|
||||
indent=4,
|
||||
ensure_ascii=False)
|
||||
|
||||
actual_size = eval('len(range(self.dataset_size[dataset_abbr])'
|
||||
f'{test_range})')
|
||||
if get_raw_factors:
|
||||
return actual_size, factor
|
||||
return factor * actual_size
|
24
opencompass/registry.py
Normal file
24
opencompass/registry.py
Normal file
@ -0,0 +1,24 @@
|
||||
from mmengine.registry import Registry
|
||||
|
||||
PARTITIONERS = Registry('partitioner', locations=['opencompass.partitioners'])
|
||||
RUNNERS = Registry('runner', locations=['opencompass.runners'])
|
||||
TASKS = Registry('task', locations=['opencompass.tasks'])
|
||||
MODELS = Registry('model', locations=['opencompass.models'])
|
||||
# TODO: LOAD_DATASET -> DATASETS
|
||||
LOAD_DATASET = Registry('load_dataset', locations=['opencompass.datasets'])
|
||||
TEXT_POSTPROCESSORS = Registry(
|
||||
'text_postprocessors', locations=['opencompass.utils.text_postprocessors'])
|
||||
EVALUATORS = Registry('evaluators', locations=['opencompass.evaluators'])
|
||||
|
||||
ICL_INFERENCERS = Registry('icl_inferencers',
|
||||
locations=['opencompass.openicl.icl_inferencer'])
|
||||
ICL_RETRIEVERS = Registry('icl_retrievers',
|
||||
locations=['opencompass.openicl.icl_retriever'])
|
||||
ICL_DATASET_READERS = Registry(
|
||||
'icl_dataset_readers',
|
||||
locations=['opencompass.openicl.icl_dataset_reader'])
|
||||
ICL_PROMPT_TEMPLATES = Registry(
|
||||
'icl_prompt_templates',
|
||||
locations=['opencompass.openicl.icl_prompt_template'])
|
||||
ICL_EVALUATORS = Registry('icl_evaluators',
|
||||
locations=['opencompass.openicl.icl_evaluator'])
|
80
opencompass/runners/base.py
Normal file
80
opencompass/runners/base.py
Normal file
@ -0,0 +1,80 @@
|
||||
import getpass
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from mmengine.config import ConfigDict, Config
|
||||
|
||||
from opencompass.utils import LarkReporter, get_logger
|
||||
|
||||
|
||||
class BaseRunner:
|
||||
"""Base class for all runners. A runner is responsible for launching
|
||||
multiple tasks.
|
||||
|
||||
Args:
|
||||
task (ConfigDict): Task type config.
|
||||
debug (bool): Whether to run in debug mode.
|
||||
lark_bot_url (str): Lark bot url.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
task: ConfigDict,
|
||||
debug: bool = False,
|
||||
lark_bot_url: str = None):
|
||||
self.task_cfg = Config(task)
|
||||
self.debug = debug
|
||||
if lark_bot_url:
|
||||
self.lark_reporter = LarkReporter(lark_bot_url)
|
||||
else:
|
||||
self.lark_reporter = None
|
||||
|
||||
def __call__(self, tasks: List[Dict[str, Any]]):
|
||||
"""Launch multiple tasks and summarize the results.
|
||||
|
||||
Args:
|
||||
tasks (list[dict]): A list of task configs, usually generated by
|
||||
Partitioner.
|
||||
"""
|
||||
status = self.launch(tasks)
|
||||
self.summarize(status)
|
||||
|
||||
@abstractmethod
|
||||
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
|
||||
"""Launch multiple tasks.
|
||||
|
||||
Args:
|
||||
tasks (list[dict]): A list of task configs, usually generated by
|
||||
Partitioner.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: A list of (task name, exit code).
|
||||
"""
|
||||
|
||||
def summarize(self, status: List[Tuple[str, int]]) -> None:
|
||||
"""Summarize the results of the tasks.
|
||||
|
||||
Args:
|
||||
status (list[tuple[str, int]]): A list of (task name, exit code).
|
||||
"""
|
||||
|
||||
failed_logs = []
|
||||
for _task, code in status:
|
||||
if code != 0:
|
||||
get_logger().error(f'{_task} failed with code {code}')
|
||||
failed_logs.append(_task)
|
||||
if self.lark_reporter:
|
||||
num_succeeded = len(status) - len(failed_logs)
|
||||
if len(failed_logs) > 0:
|
||||
content = f'{getpass.getuser()} 的 '
|
||||
content += f'{self.task_cfg.type} 任务已完成,'
|
||||
content += f'成功任务 {num_succeeded} 个,'
|
||||
content += f'失败 {len(failed_logs)} 个。以下为失败的任务列表:'
|
||||
content += '\n' + '\n'.join(failed_logs)
|
||||
self.lark_reporter.post(title=f'悲报:您有{len(failed_logs)}个'
|
||||
'任务炸了',
|
||||
content=content)
|
||||
else:
|
||||
content = f'{getpass.getuser()} 的 '
|
||||
content += f'{self.task_cfg.type} 任务已完成,'
|
||||
content += f'成功任务 {num_succeeded} 个。'
|
||||
self.lark_reporter.post(title='喜报:全部任务完成', content=content)
|
148
opencompass/runners/slurm.py
Normal file
148
opencompass/runners/slurm.py
Normal file
@ -0,0 +1,148 @@
|
||||
import inspect
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import ConfigDict
|
||||
from mmengine.utils import track_parallel_progress
|
||||
|
||||
from opencompass.registry import RUNNERS, TASKS
|
||||
from opencompass.utils import get_logger
|
||||
|
||||
from .base import BaseRunner
|
||||
|
||||
|
||||
@RUNNERS.register_module()
|
||||
class SlurmRunner(BaseRunner):
|
||||
"""Distributed runner based on Slurm. It will launch tasks in parallel
|
||||
using `srun` command.
|
||||
|
||||
Args:
|
||||
task (ConfigDict): Task type config.
|
||||
max_num_workers (int): Max number of workers to run in parallel.
|
||||
Defaults to 32.
|
||||
retry (int): Number of retries if the job failed. Defaults to 2.
|
||||
partition (str): Slurm partition name. Defaults to None.
|
||||
quotatype (str): Slurm quota type. Defaults to None.
|
||||
debug (bool): Whether to run in debug mode. Defaults to False.
|
||||
lark_bot_url (str): Lark bot url. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
task: ConfigDict,
|
||||
max_num_workers: int = 32,
|
||||
retry: int = 2,
|
||||
partition: str = None,
|
||||
quotatype: str = None,
|
||||
debug: bool = False,
|
||||
lark_bot_url: str = None):
|
||||
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
||||
self.max_num_workers = max_num_workers
|
||||
self.retry = retry
|
||||
self.partition = partition
|
||||
self.quotatype = quotatype
|
||||
|
||||
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
|
||||
"""Launch multiple tasks.
|
||||
|
||||
Args:
|
||||
tasks (list[dict]): A list of task configs, usually generated by
|
||||
Partitioner.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: A list of (task name, exit code).
|
||||
"""
|
||||
|
||||
if not self.debug:
|
||||
status = track_parallel_progress(self._launch,
|
||||
tasks,
|
||||
nproc=self.max_num_workers,
|
||||
keep_order=False)
|
||||
else:
|
||||
status = [self._launch(task, random_sleep=False) for task in tasks]
|
||||
return status
|
||||
|
||||
def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
|
||||
"""Launch a single task.
|
||||
|
||||
Args:
|
||||
task_cfg (ConfigDict): Task config.
|
||||
random_sleep (bool): Whether to sleep for a random time before
|
||||
running the command. This avoids cluster error when launching
|
||||
multiple tasks at the same time. Default: True.
|
||||
|
||||
Returns:
|
||||
tuple[str, int]: Task name and exit code.
|
||||
"""
|
||||
|
||||
task_type = self.task_cfg.type
|
||||
if isinstance(self.task_cfg.type, str):
|
||||
task_type = TASKS.get(task_type)
|
||||
task = task_type(task_cfg)
|
||||
num_gpus = task.num_gpus
|
||||
task_name = task.name
|
||||
script_path = inspect.getsourcefile(task_type)
|
||||
|
||||
# Dump task config to file
|
||||
mmengine.mkdir_or_exist('tmp/')
|
||||
param_file = f'tmp/{os.getpid()}_params.py'
|
||||
task_cfg.dump(param_file)
|
||||
|
||||
# Build up slurm command
|
||||
task_cmd_template = task.get_command_template()
|
||||
task_cmd = task_cmd_template.replace('{SCRIPT_PATH}',
|
||||
script_path).replace(
|
||||
'{CFG_PATH}', param_file)
|
||||
cmd = 'srun'
|
||||
if self.partition:
|
||||
cmd += f' -p {self.partition}'
|
||||
if self.quotatype:
|
||||
cmd += f' --quotatype={self.quotatype}'
|
||||
if num_gpus > 0:
|
||||
cmd += f' --gres=gpu:{num_gpus}'
|
||||
cmd += f" -N1 -J '{task_name[:512]}' {task_cmd}"
|
||||
logger = get_logger()
|
||||
logger.debug(f'Running command: {cmd}')
|
||||
|
||||
# Run command with retry
|
||||
if self.debug:
|
||||
stdout = None
|
||||
else:
|
||||
out_path = task.get_log_path(file_extension='out')
|
||||
mmengine.mkdir_or_exist(osp.split(out_path)[0])
|
||||
stdout = open(out_path, 'w', encoding='utf-8')
|
||||
|
||||
if random_sleep:
|
||||
time.sleep(random.randint(0, 10))
|
||||
result = subprocess.run(cmd,
|
||||
shell=True,
|
||||
text=True,
|
||||
stdout=stdout,
|
||||
stderr=stdout)
|
||||
|
||||
retry = self.retry
|
||||
output_paths = task.get_output_paths()
|
||||
while self._job_failed(result.returncode, output_paths) and retry > 0:
|
||||
retry -= 1
|
||||
if random_sleep:
|
||||
time.sleep(random.randint(0, 10))
|
||||
result = subprocess.run(cmd,
|
||||
shell=True,
|
||||
text=True,
|
||||
stdout=stdout,
|
||||
stderr=stdout)
|
||||
|
||||
if result.returncode != 0 and not self.debug:
|
||||
logger.warning(f'task {task_name} fail, see\n{out_path}')
|
||||
|
||||
# Clean up
|
||||
os.remove(param_file)
|
||||
return task_name, result.returncode
|
||||
|
||||
def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
|
||||
return return_code != 0 or not all(
|
||||
osp.exists(output_path) for output_path in output_paths)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user