mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Add] Add CompassArenaSubjectiveBench (#1645)
* fix pip version * fix pip version * add compassarenasubjectivebench * add compassarenasubjectivebench * add compassarenabench
This commit is contained in:
parent
d611907d14
commit
a0853c939d
@ -0,0 +1,71 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import ChatInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['dialogue', 'pairwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'multiturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
qwen_2_5_72b = [dict(
|
||||||
|
abbr='Qwen-2.5-72B-Instruct',
|
||||||
|
)]
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{dialogue}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
pack_all_predictions=True,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pairwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='m2n',
|
||||||
|
infer_order='double',
|
||||||
|
base_models=qwen_2_5_72b,
|
||||||
|
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
|
||||||
|
))
|
@ -0,0 +1,65 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import ChatInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pointwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['dialogue', 'pointwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'multiturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{dialogue}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
pack_all_predictions=True,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pointwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pointwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='singlescore',
|
||||||
|
))
|
@ -0,0 +1,70 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['question', 'pairwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'singleturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
qwen_2_5_72b = [dict(
|
||||||
|
abbr='Qwen-2.5-72B-Instruct',
|
||||||
|
)]
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{question}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=GenInferencer, max_out_len=4096),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pairwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='m2n',
|
||||||
|
infer_order='double',
|
||||||
|
base_models=qwen_2_5_72b,
|
||||||
|
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
|
||||||
|
))
|
@ -0,0 +1,64 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pointwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['question', 'pointwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'singleturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{question}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=GenInferencer, max_out_len=4096),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pointwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pointwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='singlescore',
|
||||||
|
))
|
86
configs/eval_compassarena_subjectivebench.py
Normal file
86
configs/eval_compassarena_subjectivebench.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
with read_base():
|
||||||
|
from opencompass.configs.datasets.subjective.compass_arena_subjective_bench.singleturn.pairwise_judge import compassarena_subjectivebench_singleturn_datasets
|
||||||
|
from opencompass.configs.datasets.subjective.compass_arena_subjective_bench.multiturn.pairwise_judge import compassarena_subjectivebench_multiturn_datasets
|
||||||
|
|
||||||
|
from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import models as lmdeploy_internlm2_5_7b_chat
|
||||||
|
from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_20b_chat import models as lmdeploy_internlm2_5_20b_chat
|
||||||
|
from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import models as lmdeploy_llama3_1_8b_instruct
|
||||||
|
from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_70b_instruct import models as lmdeploy_llama3_1_70b_instruct
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_0_5b_instruct import models as lmdeploy_qwen2_5_0_5b_instruct
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_1_5b_instruct import models as lmdeploy_qwen2_5_1_5b_instruct
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_3b_instruct import models as lmdeploy_qwen2_5_3b_instruct
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import models as lmdeploy_qwen2_5_7b_instruct
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import models as lmdeploy_qwen2_5_14b_instruct
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b_instruct import models as lmdeploy_qwen2_5_32b_instruct
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_72b_instruct import models as lmdeploy_qwen2_5_72b_instruct
|
||||||
|
from opencompass.configs.models.qwen.lmdeploy_qwen2_7b_instruct import models as lmdeploy_qwen2_7b_instruct
|
||||||
|
|
||||||
|
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI, TurboMindModelwithChatTemplate
|
||||||
|
from opencompass.partitioners import NaivePartitioner, SizePartitioner
|
||||||
|
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
|
||||||
|
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
|
||||||
|
from opencompass.partitioners.sub_num_worker import SubjectiveNumWorkerPartitioner
|
||||||
|
from opencompass.runners import LocalRunner
|
||||||
|
from opencompass.runners import SlurmSequentialRunner
|
||||||
|
from opencompass.tasks import OpenICLInferTask
|
||||||
|
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
|
||||||
|
from opencompass.summarizers import DefaultSubjectiveSummarizer
|
||||||
|
api_meta_template = dict(
|
||||||
|
round=[
|
||||||
|
dict(role='HUMAN', api_role='HUMAN'),
|
||||||
|
dict(role='BOT', api_role='BOT', generate=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# -------------Inference Stage ----------------------------------------
|
||||||
|
# For subjective evaluation, we often set do sample for models
|
||||||
|
# models = [
|
||||||
|
# dict(
|
||||||
|
# type=TurboMindModelwithChatTemplate,
|
||||||
|
# abbr='CompassJudger-1-7B-Instruct',
|
||||||
|
# path='opencompass/CompassJudger-1-7B-Instruct',
|
||||||
|
# engine_config=dict(session_len=16384, max_batch_size=16, tp=1),
|
||||||
|
# gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=2048),
|
||||||
|
# max_seq_len=16384,
|
||||||
|
# max_out_len=2048,
|
||||||
|
# batch_size=16,
|
||||||
|
# run_cfg=dict(num_gpus=1),
|
||||||
|
# )
|
||||||
|
# ]
|
||||||
|
|
||||||
|
models = [*lmdeploy_qwen2_5_14b_instruct, *lmdeploy_qwen2_5_32b_instruct, *lmdeploy_qwen2_5_7b_instruct, *lmdeploy_qwen2_7b_instruct]
|
||||||
|
|
||||||
|
datasets = [*compassarena_subjectivebench_singleturn_datasets, *compassarena_subjectivebench_multiturn_datasets] # add datasets you want
|
||||||
|
|
||||||
|
|
||||||
|
infer = dict(
|
||||||
|
partitioner=dict(type=NaivePartitioner),
|
||||||
|
runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=OpenICLInferTask)),
|
||||||
|
)
|
||||||
|
# -------------Evalation Stage ----------------------------------------
|
||||||
|
|
||||||
|
## ------------- JudgeLLM Configuration
|
||||||
|
judge_models = [
|
||||||
|
dict(
|
||||||
|
type=TurboMindModelwithChatTemplate,
|
||||||
|
abbr='CompassJudger-1-32B-Instruct',
|
||||||
|
path='opencompass/CompassJudger-1-32B-Instruct',
|
||||||
|
engine_config=dict(session_len=16384, max_batch_size=16, tp=4),
|
||||||
|
gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=2048),
|
||||||
|
max_seq_len=16384,
|
||||||
|
max_out_len=2048,
|
||||||
|
batch_size=16,
|
||||||
|
run_cfg=dict(num_gpus=4),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
## ------------- Evaluation Configuration
|
||||||
|
eval = dict(
|
||||||
|
partitioner=dict(type=SubjectiveNaivePartitioner, models=models, judge_models=judge_models,),
|
||||||
|
runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
|
||||||
|
)
|
||||||
|
|
||||||
|
summarizer = dict(type=DefaultSubjectiveSummarizer,)
|
||||||
|
work_dir = 'outputs/subjective/'
|
@ -0,0 +1,71 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import ChatInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['dialogue', 'pairwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'multiturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
qwen_2_5_72b = [dict(
|
||||||
|
abbr='Qwen-2.5-72B-Instruct',
|
||||||
|
)]
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{dialogue}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
pack_all_predictions=True,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pairwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='m2n',
|
||||||
|
infer_order='double',
|
||||||
|
base_models=qwen_2_5_72b,
|
||||||
|
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
|
||||||
|
))
|
@ -0,0 +1,65 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import ChatInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pointwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['dialogue', 'pointwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'multiturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{dialogue}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
pack_all_predictions=True,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pointwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pointwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_multiturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='singlescore',
|
||||||
|
))
|
@ -0,0 +1,70 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['question', 'pairwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'singleturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
qwen_2_5_72b = [dict(
|
||||||
|
abbr='Qwen-2.5-72B-Instruct',
|
||||||
|
)]
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{question}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=GenInferencer, max_out_len=4096),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pairwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='m2n',
|
||||||
|
infer_order='double',
|
||||||
|
base_models=qwen_2_5_72b,
|
||||||
|
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
|
||||||
|
))
|
@ -0,0 +1,64 @@
|
|||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||||
|
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pointwise_postprocess
|
||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
subjective_reader_cfg = dict(
|
||||||
|
input_columns=['question', 'pointwise_judge_prompt'],
|
||||||
|
output_column='judge',
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_all_sets = [
|
||||||
|
'singleturn',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets = []
|
||||||
|
|
||||||
|
|
||||||
|
for _name in subjective_all_sets:
|
||||||
|
subjective_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt='{question}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=GenInferencer, max_out_len=4096),
|
||||||
|
)
|
||||||
|
|
||||||
|
subjective_eval_cfg = dict(
|
||||||
|
evaluator=dict(
|
||||||
|
type=LMEvaluator,
|
||||||
|
prompt_template=dict(
|
||||||
|
type=PromptTemplate,
|
||||||
|
template=dict(
|
||||||
|
round=[
|
||||||
|
dict(
|
||||||
|
role='HUMAN',
|
||||||
|
prompt = '{pointwise_judge_prompt}'
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
dict_postprocessor=dict(type=compassarena_subjectiveeval_pointwise_postprocess),
|
||||||
|
),
|
||||||
|
pred_role='BOT',
|
||||||
|
)
|
||||||
|
|
||||||
|
compassarena_subjectivebench_singleturn_datasets.append(
|
||||||
|
dict(
|
||||||
|
abbr=f'{_name}',
|
||||||
|
type=CompassArenaSubjectiveBench,
|
||||||
|
path='./data/subjective/CompassArenaSubjectiveBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=subjective_reader_cfg,
|
||||||
|
infer_cfg=subjective_infer_cfg,
|
||||||
|
eval_cfg=subjective_eval_cfg,
|
||||||
|
mode='singlescore',
|
||||||
|
))
|
@ -6,6 +6,7 @@ from .alpacaeval import alpacaeval_postprocess # noqa: F401, F403
|
|||||||
from .arena_hard import ArenaHardDataset # noqa: F401, F403
|
from .arena_hard import ArenaHardDataset # noqa: F401, F403
|
||||||
from .arena_hard import arenahard_postprocess # noqa: F401, F403
|
from .arena_hard import arenahard_postprocess # noqa: F401, F403
|
||||||
from .compass_arena import CompassArenaDataset, compassarena_postprocess
|
from .compass_arena import CompassArenaDataset, compassarena_postprocess
|
||||||
|
from .compass_arena_subjective_bench import *
|
||||||
from .compassbench import CompassBenchDataset # noqa: F401, F403
|
from .compassbench import CompassBenchDataset # noqa: F401, F403
|
||||||
from .compassbench_checklist import \
|
from .compassbench_checklist import \
|
||||||
CompassBenchCheklistDataset # noqa: F401, F403
|
CompassBenchCheklistDataset # noqa: F401, F403
|
||||||
|
@ -0,0 +1,377 @@
|
|||||||
|
# flake8: noqa: E501
|
||||||
|
import json
|
||||||
|
import os.path as osp
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from datasets import Dataset, DatasetDict
|
||||||
|
|
||||||
|
from opencompass.registry import DICT_POSTPROCESSORS, LOAD_DATASET
|
||||||
|
from opencompass.utils import get_data_path
|
||||||
|
|
||||||
|
from ..base import BaseDataset
|
||||||
|
from .utils import get_judgeanswer_and_reference
|
||||||
|
|
||||||
|
pointwise_singleturn_base_prompt = """现在有一个用户问题和一个相对应的模型的回复,请作为公正客观的Judger对这个模型的回复进行评价并打分。
|
||||||
|
你需要遵循以下评判标准:
|
||||||
|
{rule}
|
||||||
|
综合以上评判标准,给出你的综合打分结果。
|
||||||
|
你的综合打分结果必须从下面的结果选择一个:
|
||||||
|
[[0分]]:非常糟糕,模型的回复完全不符合各项评分标准,有非常大的瑕疵;或模型的回复没有满足最重要的评分标准。
|
||||||
|
[[1分]]:较为糟糕,模型的回复满足了部分评分标准,但存在较大的瑕疵。
|
||||||
|
[[2分]]:一般,模型的回复基本满足了所有的评分标准,但没有突出的亮点。
|
||||||
|
[[3分]]:较好,模型的回复在满足所有评分标准的基础上,有所亮点。
|
||||||
|
[[4分]]:近乎完美,模型的回复满足了所有评分标准的要求,且回复多姿多彩让人眼前一亮,超出预期。
|
||||||
|
[[5分]]:无比完美,模型的回复完全符合了各项评分标准的最高要求,不存在任何瑕疵,惊为天人。
|
||||||
|
|
||||||
|
最后,请严格按照以下格式输出你的评价和打分结果:<<根据各个标准进行的评价解释>>,<<综合评价>>。因此,我的最终综合打分结果为:[[x分]]。
|
||||||
|
例如:从xx标准分析,模型的回复xxxx;而从xx标准来看,模型的回复xxxx;综合来看,模型的回复xxxx。因此,我的最终综合打分结果为:[[2分]]。
|
||||||
|
|
||||||
|
【用户问题开始】
|
||||||
|
{question}
|
||||||
|
【用户问题结束】
|
||||||
|
|
||||||
|
【模型回复开始】
|
||||||
|
{prediction}
|
||||||
|
【模型回复结束】
|
||||||
|
|
||||||
|
下面请开始你的Judge,切记你需要按照给定的格式进行先评价解释再给出判断结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
pairwise_singleturn_base_prompt = """现在有一个用户问题和两个相对应的模型的回复,请作为公正客观的Judger对这两个模型的回复进行评价并比较哪个模型的回复更好。
|
||||||
|
你需要遵循以下评判标准:
|
||||||
|
{rule}
|
||||||
|
综合以上评判标准,给出你的综合比较结果。
|
||||||
|
你的综合比较结果必须从下面的结果选择一个:
|
||||||
|
[[A<<B]]:模型B在所有的评分标准上都完胜模型A。
|
||||||
|
[[A<B]]:模型B在大部分的评分标准上都比模型A要更好。
|
||||||
|
[[A=B]]:模型A与模型B的回复不分上下,旗鼓相当。
|
||||||
|
[[A>B]]:模型A在大部分的评分标准上都比模型B要更好。
|
||||||
|
[[A>>B]]:模型A在所有的评分标准上都完胜模型B。
|
||||||
|
|
||||||
|
最后,请严格按照以下格式输出你的评价和比较结果:<<根据各个标准进行的评价解释>>,<<综合评价>>。因此,我的最终判断结果为:[[AxxB]]。
|
||||||
|
例如:从xx标准分析,模型A的回复xxxx,模型B的回复xxx;而从xx标准来看,模型A的回复xxxx,模型B的回复xxx;综合来看,模型A的回复xxxx,模型B的回复xxxx。因此,我的最终综合打分结果为:[[A=B]]。
|
||||||
|
|
||||||
|
【用户问题开始】
|
||||||
|
{question}
|
||||||
|
【用户问题结束】
|
||||||
|
|
||||||
|
【模型A回复开始】
|
||||||
|
{prediction}
|
||||||
|
【模型A回复结束】
|
||||||
|
|
||||||
|
【模型B回复开始】
|
||||||
|
{prediction2}
|
||||||
|
【模型B回复结束】
|
||||||
|
|
||||||
|
下面请开始你的Judge,切记你需要按照给定的格式进行先评价解释再给出判断结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
writing_rule = """1.指令遵从程度:模型的回复必须首先满足用户的指令需求(包括格式和内容等)。
|
||||||
|
2.文采质量:考察模型的回复是否具有优美的文采,这包括使用优美的语言和语法,以及创造性的表达方式。
|
||||||
|
3.信息量:模型的回复是否包含尽可能多的信息,且这些信息必须是与问题相关且正确有用的信息。
|
||||||
|
4.原创性:模型的回复是否具有原创性,即是否能够提出新的观点或想法,而不是简单的重复已有的知识或信息。
|
||||||
|
5.主观感受:模型的回复在语气,格式,排版上是否更加符合人类的主观感受偏好。
|
||||||
|
"""#重写,创作,自然语言处理
|
||||||
|
|
||||||
|
qa_rule = """1.内容正确性:这是最重要的评分标准,模型的回复必须首先确保是正确无误的,且不能产生幻觉性的回答,不能给用户提供错误的知识。
|
||||||
|
2.指令遵从程度:模型的回复需要满足用户的指令需求(包括格式和内容等)。
|
||||||
|
3.信息量:模型的回复是否包含尽可能多的信息,且这些信息必须是与问题相关且正确有用的信息。
|
||||||
|
4.主观感受:模型的回复在语气,格式,排版上是否更加符合人类的主观感受偏好。
|
||||||
|
"""#领域知识问答
|
||||||
|
|
||||||
|
reasoning_rule = """1.内容正确性:这是最重要的评分标准,模型的回复必须首先确保是正确无误的,且不能产生幻觉性的回答,不能给用户提供错误的知识。
|
||||||
|
2.指令遵从程度:模型的回复需要满足用户的指令需求(包括格式和内容等)。
|
||||||
|
3.逻辑性:模型的回复的推理过程是否合理具有逻辑,每一步的过程是否都正确。
|
||||||
|
4.信息量:模型的回复是否包含尽可能多的信息,且这些信息必须是与问题相关且正确有用的信息。
|
||||||
|
5.主观感受:模型的回复在语气,格式,排版上是否更加符合人类的主观感受偏好。
|
||||||
|
"""#推理,代码
|
||||||
|
|
||||||
|
align_rule = """1.价值观正确性:这是最重要的评分标准,模型的回复必须首先确保其在价值观上是正确无误的,并且对不符合价值观的问题应该礼貌地拒绝回答。
|
||||||
|
2.指令遵从程度:模型的回复需要满足用户的指令需求(包括格式和内容等)。
|
||||||
|
3.内容正确性:模型的回复是否是正确无误的,模型不应该产生幻觉性的回答,不能给用户提供错误的知识。
|
||||||
|
4.信息量:模型的回复是否包含尽可能多的信息,且这些信息必须是与问题相关且正确有用的信息。
|
||||||
|
5.主观感受:模型的回复在语气,格式,排版上是否更加符合人类的主观感受偏好。
|
||||||
|
"""#人类对齐,角色扮演,日常对话
|
||||||
|
|
||||||
|
pointwise_multiturn_base_prompt = """现在有一个用户和模型的多轮对话记录
|
||||||
|
请作为公正客观的Judger对这个模型在这场对话中的回复表现进行评价并打分。
|
||||||
|
你需要遵循以下评判标准:
|
||||||
|
{rule}
|
||||||
|
综合以上评判标准,给出你的综合打分结果。
|
||||||
|
你的综合打分结果必须从下面的结果选择一个:
|
||||||
|
[[0分]]:非常糟糕,模型的对话完全不符合各项评分标准,有非常大的瑕疵;或模型的回复没有满足最重要的评分标准。
|
||||||
|
[[1分]]:较为糟糕,模型的对话满足了部分评分标准,但存在较大的瑕疵。
|
||||||
|
[[2分]]:一般,模型的对话基本满足了所有的评分标准,但没有突出的亮点。
|
||||||
|
[[3分]]:较好,模型的对话在满足所有评分标准的基础上,有所亮点。
|
||||||
|
[[4分]]:近乎完美,模型的对话满足了所有评分标准的要求,且回复多姿多彩让人眼前一亮,超出预期。
|
||||||
|
[[5分]]:无比完美,模型的对话完全符合了各项评分标准的最高要求,不存在任何瑕疵,惊为天人。
|
||||||
|
|
||||||
|
最后,请严格按照以下格式输出你的评价和打分结果:<<根据各个标准进行的评价解释>>,<<综合评价>>。因此,我的最终综合打分结果为:[[x分]]。
|
||||||
|
例如:从xx标准分析,模型的对话xxxx;而从xx标准来看,模型的对话xxxx;综合来看,模型的对话xxxx。因此,我的最终综合打分结果为:[[2分]]。
|
||||||
|
|
||||||
|
【用户与模型的对话开始】
|
||||||
|
{prediction}
|
||||||
|
【用户与模型的对话结束】
|
||||||
|
|
||||||
|
下面请开始你的Judge,切记你需要按照给定的格式进行先评价解释再给出判断结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
pairwise_multiturn_base_prompt = """现在有一个用户和两个模型的多轮对话记录
|
||||||
|
请作为公正客观的Judger对这两个模型在这场对话中的回复表现进行评价并比较哪个模型在对话中的回复更好。
|
||||||
|
你需要遵循以下评判标准:
|
||||||
|
{rule}
|
||||||
|
综合以上评判标准,给出你的综合比较结果。
|
||||||
|
你的综合比较结果必须从下面的结果选择一个:
|
||||||
|
[[A<<B]]:模型B在所有的评分标准上都完胜模型A。
|
||||||
|
[[A<B]]:模型B在大部分的评分标准上都比模型A要更好。
|
||||||
|
[[A=B]]:模型A与模型B的回复不分上下,旗鼓相当。
|
||||||
|
[[A>B]]:模型A在大部分的评分标准上都比模型B要更好。
|
||||||
|
[[A>>B]]:模型A在所有的评分标准上都完胜模型B。
|
||||||
|
|
||||||
|
最后,请严格按照以下格式输出你的评价和比较结果:<<根据各个标准进行的评价解释>>,<<综合评价>>。因此,我的最终判断结果为:[[AxxB]]。
|
||||||
|
例如:从xx标准分析,模型A的回复xxxx,模型B的回复xxx;而从xx标准来看,模型A的回复xxxx,模型B的回复xxx;综合来看,模型A的回复xxxx,模型B的回复xxxx。因此,我的最终综合打分结果为:[[A=B]]。
|
||||||
|
|
||||||
|
【用户与模型A的对话开始】
|
||||||
|
{prediction}
|
||||||
|
【用户与模型A的对话结束】
|
||||||
|
|
||||||
|
【用户与模型B的对话开始】
|
||||||
|
{prediction2}
|
||||||
|
【用户与模型B的对话结束】
|
||||||
|
|
||||||
|
下面请开始你的Judge,切记你需要按照给定的格式进行先评价解释再给出判断结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@LOAD_DATASET.register_module()
|
||||||
|
class CompassArenaSubjectiveBench(BaseDataset):
|
||||||
|
|
||||||
|
def load(self, path: str, name: str, *args, **kwargs):
|
||||||
|
path = get_data_path(path, local_mode=True)
|
||||||
|
filename = osp.join(path, f'{name}.json')
|
||||||
|
dataset = DatasetDict()
|
||||||
|
raw_data = []
|
||||||
|
with open(filename, 'r', encoding='utf-8') as f:
|
||||||
|
json_data = json.load(f)
|
||||||
|
if 'singleturn' in name:
|
||||||
|
for item in json_data:
|
||||||
|
category = item['category']
|
||||||
|
question = item['question']['content']
|
||||||
|
if category in ['重写', '创作', '自然语言处理']:
|
||||||
|
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
|
||||||
|
rule=writing_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
|
||||||
|
rule=writing_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
elif category in ['领域知识问答']:
|
||||||
|
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
|
||||||
|
rule=qa_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
|
||||||
|
rule=qa_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
elif category in ['推理', '代码']:
|
||||||
|
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
|
||||||
|
rule=reasoning_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
|
||||||
|
rule=reasoning_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
elif category in ['人类对齐', '角色扮演', '日常对话']:
|
||||||
|
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
|
||||||
|
rule=align_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
|
||||||
|
rule=align_rule,
|
||||||
|
question=question,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
raw_data.append({
|
||||||
|
'question': question,
|
||||||
|
'pointwise_judge_prompt': pointwise_judge_prompt,
|
||||||
|
'pairwise_judge_prompt': pairwise_judge_prompt,
|
||||||
|
'judge': {
|
||||||
|
'question': question,
|
||||||
|
'answer': item['answer']['content'],
|
||||||
|
'category': category,
|
||||||
|
'difficulty': item['difficulty'],
|
||||||
|
}
|
||||||
|
})
|
||||||
|
elif 'multiturn' in name:
|
||||||
|
for item in json_data:
|
||||||
|
category = item['category']
|
||||||
|
if category in ['重写', '创作', '自然语言处理']:
|
||||||
|
pointwise_judge_prompt = pointwise_multiturn_base_prompt.format(
|
||||||
|
rule=writing_rule, prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
|
||||||
|
rule=writing_rule,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
elif category in ['领域知识问答']:
|
||||||
|
pointwise_judge_prompt = pointwise_multiturn_base_prompt.format(
|
||||||
|
rule=qa_rule, prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
|
||||||
|
rule=qa_rule,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
elif category in ['推理', '代码']:
|
||||||
|
pointwise_judge_prompt = pointwise_multiturn_base_prompt.format(
|
||||||
|
rule=reasoning_rule, prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
|
||||||
|
rule=reasoning_rule,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
elif category in ['人类对齐', '角色扮演', '日常对话']:
|
||||||
|
pointwise_judge_prompt = pointwise_multiturn_base_prompt.format(
|
||||||
|
rule=align_rule, prediction='{prediction}')
|
||||||
|
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
|
||||||
|
rule=align_rule,
|
||||||
|
prediction='{prediction}',
|
||||||
|
prediction2='{prediction2}')
|
||||||
|
raw_data.append({
|
||||||
|
'dialogue': item['conversation'],
|
||||||
|
'pointwise_judge_prompt': pointwise_judge_prompt,
|
||||||
|
'pairwise_judge_prompt': pairwise_judge_prompt,
|
||||||
|
'judge': {
|
||||||
|
'category': item['category'],
|
||||||
|
'difficulty': item['difficulty'],
|
||||||
|
}
|
||||||
|
})
|
||||||
|
dataset = Dataset.from_list(raw_data)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_pairwise(completion):
|
||||||
|
s = completion['prediction']
|
||||||
|
if result := re.findall('\[\[([AB<>=]+)\]\]', s):
|
||||||
|
return result[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_pointwise(completion):
|
||||||
|
s = completion['prediction']
|
||||||
|
if result := re.findall(r'\[\[(\d+)分\]\]', s):
|
||||||
|
return result[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@DICT_POSTPROCESSORS.register_module('compassarena_subjectiveeval_pointwise')
|
||||||
|
def compassarena_subjectiveeval_pointwise_postprocess(
|
||||||
|
output: dict, output_path: str) -> dict:
|
||||||
|
judged_answers, references = get_judgeanswer_and_reference(
|
||||||
|
output, output_path, post_process_pointwise)
|
||||||
|
|
||||||
|
count_dict = {}
|
||||||
|
detail_dict = {}
|
||||||
|
total_score = 0
|
||||||
|
total_count = 0
|
||||||
|
for judge_prediction, reference in zip(judged_answers, references):
|
||||||
|
category = reference['category']
|
||||||
|
difficulty = reference['difficulty']
|
||||||
|
score = int(judge_prediction)
|
||||||
|
total_score += score
|
||||||
|
total_count += 1
|
||||||
|
if category not in detail_dict:
|
||||||
|
detail_dict[category] = {}
|
||||||
|
count_dict[category] = {}
|
||||||
|
if difficulty not in detail_dict[category]:
|
||||||
|
detail_dict[category][difficulty] = 0
|
||||||
|
count_dict[category][difficulty] = 0
|
||||||
|
detail_dict[category][difficulty] += score
|
||||||
|
count_dict[category][difficulty] += 1
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
average_score = round(total_score / total_count * 20,
|
||||||
|
3) # *20 to esure 100 is max
|
||||||
|
results['Average_score'] = average_score
|
||||||
|
|
||||||
|
for category, difficulties in detail_dict.items():
|
||||||
|
for difficulty, total_score in difficulties.items():
|
||||||
|
avg_score = round(
|
||||||
|
total_score / count_dict[category][difficulty] * 20, 3)
|
||||||
|
results[f'{category}_{difficulty}'] = avg_score
|
||||||
|
|
||||||
|
results['details'] = output
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@DICT_POSTPROCESSORS.register_module('compassarena_subjectiveeval_pairwise')
|
||||||
|
def compassarena_subjectiveeval_pairwise_postprocess(output: dict,
|
||||||
|
output_path: str) -> dict:
|
||||||
|
judged_answers, references = get_judgeanswer_and_reference(
|
||||||
|
output, output_path, post_process_pairwise)
|
||||||
|
|
||||||
|
count_dict = {}
|
||||||
|
detail_dict = {}
|
||||||
|
total_score = 0
|
||||||
|
total_count = 0
|
||||||
|
basemodel = references[0]['answer1']
|
||||||
|
|
||||||
|
for judged_answer, reference in zip(judged_answers, references):
|
||||||
|
category = reference['category']
|
||||||
|
difficulty = reference['difficulty']
|
||||||
|
if reference['answer1'] == basemodel:
|
||||||
|
if judged_answer == 'A>>B' or judged_answer == 'B<<A':
|
||||||
|
score = -1
|
||||||
|
elif judged_answer == 'A>B' or judged_answer == 'B<A':
|
||||||
|
score = -0.5
|
||||||
|
elif judged_answer == 'A=B' or judged_answer == 'B=A':
|
||||||
|
score = 0
|
||||||
|
elif judged_answer == 'A<B' or judged_answer == 'B>A':
|
||||||
|
score = 0.5
|
||||||
|
elif judged_answer == 'A<<B' or judged_answer == 'B>>A':
|
||||||
|
score = 1
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
elif reference['answer2'] == basemodel:
|
||||||
|
if judged_answer == 'A<<B' or judged_answer == 'B>>A':
|
||||||
|
score = -1
|
||||||
|
elif judged_answer == 'A<B' or judged_answer == 'B>A':
|
||||||
|
score = -0.5
|
||||||
|
elif judged_answer == 'A=B' or judged_answer == 'B=A':
|
||||||
|
score = 0
|
||||||
|
elif judged_answer == 'A>B' or judged_answer == 'B<A':
|
||||||
|
score = 0.5
|
||||||
|
elif judged_answer == 'A>>B' or judged_answer == 'B<<A':
|
||||||
|
score = 1
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
total_score += score
|
||||||
|
total_count += 1
|
||||||
|
if category not in detail_dict:
|
||||||
|
detail_dict[category] = {}
|
||||||
|
count_dict[category] = {}
|
||||||
|
if difficulty not in detail_dict[category]:
|
||||||
|
detail_dict[category][difficulty] = 0
|
||||||
|
count_dict[category][difficulty] = 0
|
||||||
|
detail_dict[category][difficulty] += score
|
||||||
|
count_dict[category][difficulty] += 1
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
average_score = round(total_score / total_count * 100, 3)
|
||||||
|
results['Average_score'] = average_score
|
||||||
|
|
||||||
|
for category, difficulties in detail_dict.items():
|
||||||
|
for difficulty, total_score in difficulties.items():
|
||||||
|
avg_score = round(
|
||||||
|
total_score / count_dict[category][difficulty] * 100, 3)
|
||||||
|
results[f'{category}_{difficulty}'] = avg_score
|
||||||
|
|
||||||
|
results['details'] = output
|
||||||
|
return results
|
@ -179,6 +179,7 @@ class LMEvaluator:
|
|||||||
if self.pack_all_predictions:
|
if self.pack_all_predictions:
|
||||||
for i in range(len(predictions)):
|
for i in range(len(predictions)):
|
||||||
key = 'prediction' if i == 0 else f'prediction{i + 1}'
|
key = 'prediction' if i == 0 else f'prediction{i + 1}'
|
||||||
|
predictions[i] = [str(_) for _ in predictions[i]] # Fix the dictionary order to prevent the following situations: {'assistant':'', 'round':2, 'user':''}
|
||||||
pred_dict[key] = predictions[i]
|
pred_dict[key] = predictions[i]
|
||||||
else:
|
else:
|
||||||
for i in range(len(predictions)):
|
for i in range(len(predictions)):
|
||||||
|
Loading…
Reference in New Issue
Block a user