Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Add multi-model judge and fix some problems (#1016)
* support multi-model judge and moe judge
* test_moe
* test_moe
* test
* add moe judge
* support multi-judge-model
This commit is contained in:
parent
c220550fb9
commit
2d4e559763
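Before the file-by-file diff, here is a minimal sketch of how the multi-judge interface added by this commit is meant to be wired up in an eval config. Judge models are now passed to the partitioner as a judge_models list (optionally together with meta_judge_model and infer_order), and SubjectiveEvalTask no longer receives judge_cfg directly. The models list and the judge generation settings below are placeholders loosely based on the configs in this diff, not a verbatim excerpt.

from opencompass.models import OpenAI
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks.subjective_eval import SubjectiveEvalTask

models = []  # models under evaluation, defined elsewhere in the config

judge_models = [dict(
    abbr='GPT4-Turbo',
    type=OpenAI,
    path='gpt-4-1106-preview',
    max_seq_len=2048,
    batch_size=8,
    temperature=0,
)]

eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner,
        max_task_size=1000,
        mode='singlescore',
        models=models,
        judge_models=judge_models,   # each task is replicated once per judge
        # meta_judge_model=...,      # optional: adds a second, meta-review stage
    ),
    runner=dict(type=LocalRunner,
                max_num_workers=2,
                task=dict(type=SubjectiveEvalTask)),  # judge_cfg no longer set here
)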
@ -65,7 +65,6 @@ for _name in subjective_all_sets:
|
||||
subjective_eval_cfg = dict(
|
||||
evaluator=dict(
|
||||
type=LMEvaluator,
|
||||
infer_order='random',
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
|
@ -67,7 +67,6 @@ for _name in subjective_all_sets:
|
||||
subjective_eval_cfg = dict(
|
||||
evaluator=dict(
|
||||
type=LMEvaluator,
|
||||
infer_order='random',
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
|
@ -119,7 +119,6 @@ for _name, _prompt in sub_map.items():
|
||||
subjective_eval_cfg = dict(
|
||||
evaluator=dict(
|
||||
type=LMEvaluator,
|
||||
infer_order='double',
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
|
@ -0,0 +1,156 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||
from opencompass.datasets import CompassArenaDataset
|
||||
|
||||
subjective_reader_cfg = dict(
|
||||
input_columns=['question', 'ref'],
|
||||
output_column='judge',
|
||||
)
|
||||
|
||||
data_path ="data/subjective/compass_arena"
|
||||
|
||||
subjective_datasets = []
|
||||
|
||||
base_prompt = """
|
||||
|
||||
[回答1开始]
|
||||
{prediction}
|
||||
[回答1结束]
|
||||
|
||||
[回答2开始]
|
||||
{prediction2}
|
||||
[回答2结束]
|
||||
|
||||
根据评分要求,在以下 3 个选项中做出选择:
|
||||
A. 回答1更好
|
||||
B. 回答2更好
|
||||
C. 回答1、2平局
|
||||
并提供你的解释原因。
|
||||
|
||||
如果你认为回答1更好,你的输出应形如:
|
||||
选择:A
|
||||
原因:blahblah blahblah\n
|
||||
|
||||
如果你认为回答2更好,你的输出应形如:
|
||||
选择:B
|
||||
原因:blahblah blahblah\n
|
||||
|
||||
如果你认为回答1、2打成平手,你的输出应形如:
|
||||
选择:C
|
||||
原因:blahblah blahblah\n
|
||||
"""
|
||||
|
||||
knowledge_prompt = """
|
||||
请根据提供的 评分要求,用户问题,参考答案 以及 相应的两个回答(回答1,回答2),判断两个回答中哪一个更好。
|
||||
评分要求(重要性依次递减):
|
||||
1. 更好的回答能与参考答案吻合或表明参考答案的意思。
|
||||
2. 在都准确答对问题的前提下,更好的回答能对知识点进行额外补充,且补充的知识准确无误。
|
||||
3. 更好的回答更加符合与人类对话的习惯,包括语气、情调等。
|
||||
|
||||
[用户问题]
|
||||
{question}
|
||||
|
||||
[参考答案]
|
||||
{ref}
|
||||
""" + base_prompt
|
||||
|
||||
|
||||
language_prompt = """
|
||||
请根据提供的 评分要求,用户问题 以及 相应的两个回答(回答1,回答2),判断两个回答中哪一个更好。
|
||||
评分要求(重要性依次递减):
|
||||
1. 在有明确的参考答案的情况下,越贴近参考答案或表明了参考答案的意思的回答越好。
|
||||
2. 更好的回答在语言表达上更流畅,更加符合与人类对话的习惯,包括语气、情调等
|
||||
3. 在都准确答对问题的前提下,更好的回答能进行额外补充,且补充的内容准确无误。
|
||||
|
||||
[用户问题]
|
||||
{question}
|
||||
|
||||
[参考答案]
|
||||
{ref}
|
||||
""" + base_prompt
|
||||
|
||||
|
||||
math_prompt = """
|
||||
请根据提供的 评分要求,用户问题,参考答案 以及 相应的两个回答(回答1,回答2),判断两个回答中哪一个更好。
|
||||
评分要求(重要性依次递减):
|
||||
1. 更好的回答的答案能和参考答案一致。
|
||||
2. 若两个回答的答案都与参考答案不一致,则更好的回答的推理过程应更加合理。
|
||||
3. 更好的回答更加符合与人类对话的习惯,包括语气、情调等。
|
||||
|
||||
[用户问题]
|
||||
{question}
|
||||
|
||||
[参考答案]
|
||||
{ref}
|
||||
""" + base_prompt
|
||||
|
||||
reason_prompt = math_prompt
|
||||
|
||||
creation_prompt = """
|
||||
请根据提供的 评分要求,用户问题 以及 相应的两个回答(回答1,回答2),判断两个回答中哪一个更好。
|
||||
评分要求(重要性依次递减):
|
||||
1. 好的回答必须首先符合用户问题里的各种需求,不能跑题
|
||||
2. 好的回答必须具有逻辑连贯性,围绕一个中心进行回答
|
||||
3. 好的回答必须具有创造性的词语和表达丰富度
|
||||
|
||||
[用户问题]
|
||||
{question}
|
||||
""" + base_prompt
|
||||
|
||||
sub_map = {"knowledge": knowledge_prompt, "language": language_prompt, "math_v2": math_prompt, "reason_v2": reason_prompt, "creationv2_zh": creation_prompt}
|
||||
|
||||
meta_prompt = """
|
||||
\n你是一个评判专家,请根据提供的 评分要求,用户问题 以及 相应的两个回答(回答1,回答2),判断两个回答中哪一个更好。\n评分要求(重要性依次递减):\n1. 好的回答必须首先符合用户问题里的各种需求,不能跑题 \n2. 好的回答必须具有逻辑连贯性,围绕一个中心进行回答\n3. 好的回答必须具有创造性的词语和表达丰富度\n\n[用户问题]\n{question}\n[回答1开始]\n{prediction}\n[回答1结束]\n[回答2开始]\n{prediction2}\n[回答2结束]\n此外,还有两个其他评判专家的评判意见供你参考。\n[评判意见1]\n{judgement}\n[评判意见2]\n{judgement2}\n\n最终请你综合其他评判专家的评判意见与你自己的意见,在以下 3 个选项中做出选择:\nA. 回答1更好\nB. 回答2更好\nC. 回答1、2平局\n并提供你的解释原因。\n\n如果你认为回答1更好,你的输出应形如:\n选择:A\n原因:blahblah blahblah\n\n\n如果你认为回答2更好,你的输出应形如:\n选择:B\n原因:blahblah blahblah\n\n\n如果你认为回答1、2打成平手,你的输出应形如:\n选择:C\n原因:blahblah blahblah\n\n
|
||||
"""
|
||||
for _name, _prompt in sub_map.items():
|
||||
subjective_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt="{question}"
|
||||
),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=2048),
|
||||
)
|
||||
|
||||
subjective_eval_cfg = dict(
|
||||
evaluator=dict(
|
||||
type=LMEvaluator,
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt = _prompt
|
||||
),
|
||||
]),
|
||||
),
|
||||
meta_review_prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt = meta_prompt
|
||||
),
|
||||
]),
|
||||
),
|
||||
),
|
||||
pred_role="BOT",
|
||||
)
|
||||
|
||||
subjective_datasets.append(
|
||||
dict(
|
||||
abbr=f"{_name}",
|
||||
type=CompassArenaDataset,
|
||||
path=data_path,
|
||||
name=_name,
|
||||
reader_cfg=subjective_reader_cfg,
|
||||
infer_cfg=subjective_infer_cfg,
|
||||
eval_cfg=subjective_eval_cfg
|
||||
))
|
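The meta_prompt registered above via meta_review_prompt_template carries two extra placeholders, {judgement} and {judgement2}, on top of the usual pairwise fields. As a rough illustration only (field names follow the LMEvaluator changes later in this commit; the strings and the second judge name are made up), the meta-review stage renders it from data shaped like this:

# Illustrative shapes only, not output from a real run.
pred_dict = dict(
    prediction='answer from the first compared model',
    prediction2='answer from the second compared model',
    judgement='opinion written by the first judge model',
    judgement2='opinion written by the second judge model',
)
# each reference row also records which judge produced which opinion:
reference = dict(judge_model1='GPT4-Turbo', judge_model2='another-judge')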
@ -1,71 +0,0 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||
from opencompass.datasets import IRDataset
|
||||
|
||||
subjective_reader_cfg = dict(
|
||||
input_columns=['question', 'capability', 'ref'],
|
||||
output_column='judge',
|
||||
)
|
||||
|
||||
subjective_all_sets = [
|
||||
"information_retrieval",
|
||||
]
|
||||
data_path ="data/subjective/"
|
||||
|
||||
subjective_datasets = []
|
||||
|
||||
for _name in subjective_all_sets:
|
||||
subjective_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt="{question}"
|
||||
),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=512),
|
||||
)
|
||||
|
||||
subjective_eval_cfg = dict(
|
||||
evaluator=dict(
|
||||
type=LMEvaluator,
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt = """为上传的针对给定用户问题的回应撰写评论, 并为该回复打分:
|
||||
|
||||
[BEGIN DATA]
|
||||
***
|
||||
[用户问询]: {question}
|
||||
***
|
||||
[回应]: {prediction}
|
||||
***
|
||||
[参考答案]: {ref}
|
||||
***
|
||||
[END DATA]
|
||||
|
||||
请根据参考答案为这个回应撰写评论. 在这之后, 你应该按照如下格式给这个回应一个最终的1-10范围的评分: "[[评分]]", 例如: "评分: [[5]]"."""
|
||||
),
|
||||
]),
|
||||
),
|
||||
),
|
||||
pred_role="BOT",
|
||||
)
|
||||
|
||||
subjective_datasets.append(
|
||||
dict(
|
||||
abbr=f"{_name}",
|
||||
type=IRDataset,
|
||||
path=data_path,
|
||||
name=_name,
|
||||
reader_cfg=subjective_reader_cfg,
|
||||
infer_cfg=subjective_infer_cfg,
|
||||
eval_cfg=subjective_eval_cfg
|
||||
))
|
@ -1,59 +0,0 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import LMEvaluator
|
||||
from opencompass.datasets import IRDataset
|
||||
|
||||
subjective_reader_cfg = dict(
|
||||
input_columns=['question', 'capability', 'gpt4_prefix', 'gpt4_suffix', 'ref'],
|
||||
output_column='judge',
|
||||
)
|
||||
|
||||
subjective_all_sets = [
|
||||
"information_retrieval",
|
||||
]
|
||||
data_path ="data/subjective/"
|
||||
|
||||
subjective_datasets = []
|
||||
|
||||
for _name in subjective_all_sets:
|
||||
subjective_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt="{question}"
|
||||
),
|
||||
]),
|
||||
),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=512),
|
||||
)
|
||||
|
||||
subjective_eval_cfg = dict(
|
||||
evaluator=dict(
|
||||
type=LMEvaluator,
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(round=[
|
||||
dict(
|
||||
role='HUMAN',
|
||||
prompt = "{gpt4_prefix}{prediction}{gpt4_suffix}"
|
||||
),
|
||||
]),
|
||||
),
|
||||
),
|
||||
pred_role="BOT",
|
||||
)
|
||||
|
||||
subjective_datasets.append(
|
||||
dict(
|
||||
abbr=f"{_name}",
|
||||
type=IRDataset,
|
||||
path=data_path,
|
||||
name=_name,
|
||||
reader_cfg=subjective_reader_cfg,
|
||||
infer_cfg=subjective_infer_cfg,
|
||||
eval_cfg=subjective_eval_cfg
|
||||
))
|
@ -44,7 +44,7 @@ models = [
|
||||
meta_template=api_meta_template,
|
||||
max_out_len=2048,
|
||||
max_seq_len=4096,
|
||||
batch_size=1,
|
||||
batch_size=8,
|
||||
run_cfg=dict(num_gpus=1, num_procs=1),
|
||||
)
|
||||
]
|
||||
@ -54,7 +54,7 @@ datasets = [*subjective_datasets]
|
||||
# -------------Evaluation Stage ----------------------------------------
|
||||
|
||||
## ------------- JudgeLLM Configuration
|
||||
judge_model = dict(
|
||||
judge_models = [dict(
|
||||
abbr='GPT4-Turbo',
|
||||
type=OpenAI,
|
||||
path='gpt-4-1106-preview',
|
||||
@ -65,18 +65,14 @@ judge_model = dict(
|
||||
max_seq_len=2048,
|
||||
batch_size=8,
|
||||
temperature=0,
|
||||
)
|
||||
)]
|
||||
|
||||
## ------------- Evaluation Configuration
|
||||
eval = dict(
|
||||
partitioner=dict(
|
||||
type=SubjectiveNaivePartitioner, mode='singlescore', models=models
|
||||
),
|
||||
runner=dict(
|
||||
type=LocalRunner,
|
||||
max_num_workers=2,
|
||||
task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
|
||||
type=SubjectiveSizePartitioner, max_task_size=1000, mode='singlescore', models=models, judge_models=judge_models,
|
||||
),
|
||||
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
|
||||
)
|
||||
|
||||
summarizer = dict(type=AlignmentBenchSummarizer, judge_type='general')
|
||||
|
@ -47,7 +47,7 @@ models = [
|
||||
meta_template=api_meta_template,
|
||||
max_out_len=2048,
|
||||
max_seq_len=4096,
|
||||
batch_size=1,
|
||||
batch_size=8,
|
||||
run_cfg=dict(num_gpus=1, num_procs=1),
|
||||
)
|
||||
]
|
||||
@ -73,7 +73,7 @@ gpt4 = dict(
|
||||
# -------------Evaluation Stage ----------------------------------------
|
||||
|
||||
## ------------- JudgeLLM Configuration
|
||||
judge_model = dict(
|
||||
judge_models = [dict(
|
||||
abbr='GPT4-Turbo',
|
||||
type=OpenAI,
|
||||
path='gpt-4-1106-preview',
|
||||
@ -85,21 +85,20 @@ judge_model = dict(
|
||||
batch_size=2,
|
||||
retry=20,
|
||||
temperature=0,
|
||||
)
|
||||
)]
|
||||
|
||||
## ------------- Evaluation Configuration
|
||||
eval = dict(
|
||||
partitioner=dict(
|
||||
type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
|
||||
),
|
||||
runner=dict(
|
||||
type=SlurmSequentialRunner,
|
||||
partition='llmeval',
|
||||
quotatype='auto',
|
||||
max_num_workers=256,
|
||||
task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
|
||||
type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models,
|
||||
infer_order='random',
|
||||
judge_models=judge_models
|
||||
),
|
||||
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
|
||||
given_pred = [{'abbr':'gpt4-turbo', 'path':''}]
|
||||
)
|
||||
work_dir = 'outputs/alpaca/'
|
||||
|
||||
|
||||
|
||||
summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
|
@ -72,7 +72,7 @@ gpt4 = dict(
|
||||
# -------------Evaluation Stage ----------------------------------------
|
||||
|
||||
## ------------- JudgeLLM Configuration
|
||||
judge_model = dict(
|
||||
judge_models = [dict(
|
||||
abbr='GPT4-Turbo',
|
||||
type=OpenAI,
|
||||
path='gpt-4-1106-preview',
|
||||
@ -84,7 +84,7 @@ judge_model = dict(
|
||||
batch_size=2,
|
||||
retry=20,
|
||||
temperature=0,
|
||||
)
|
||||
)]
|
||||
|
||||
## ------------- Evaluation Configuration
|
||||
eval = dict(
|
||||
@ -93,16 +93,13 @@ eval = dict(
|
||||
strategy='split',
|
||||
max_task_size=10000,
|
||||
mode='m2n',
|
||||
infer_order='double',
|
||||
base_models=[gpt4],
|
||||
compare_models=models,
|
||||
judge_models=judge_models,
|
||||
),
|
||||
runner=dict(
|
||||
type=SlurmSequentialRunner,
|
||||
partition='llm_dev2',
|
||||
quotatype='auto',
|
||||
max_num_workers=32,
|
||||
task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
|
||||
),
|
||||
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
|
||||
given_pred = [{'abbr':'gpt4-turbo', 'path':''}]
|
||||
)
|
||||
|
||||
work_dir = 'outputs/compass_arena_debug/'
|
||||
|
@ -63,7 +63,7 @@ infer = dict(
|
||||
# -------------Evaluation Stage ----------------------------------------
|
||||
|
||||
## ------------- JudgeLLM Configuration
|
||||
judge_model = dict(
|
||||
judge_models = [dict(
|
||||
type=HuggingFaceCausalLM,
|
||||
abbr='pandalm-7b-v1-hf',
|
||||
path='WeOpenML/PandaLM-7B-v1',
|
||||
@ -79,12 +79,12 @@ judge_model = dict(
|
||||
batch_size=8,
|
||||
model_kwargs=dict(device_map='auto', trust_remote_code=True),
|
||||
run_cfg=dict(num_gpus=1, num_procs=1),
|
||||
)
|
||||
)]
|
||||
|
||||
## ------------- Evaluation Configuration
|
||||
eval = dict(
|
||||
partitioner=dict(type=SubjectiveNaivePartitioner, mode='singlescore', models=models),
|
||||
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
|
||||
partitioner=dict(type=SubjectiveNaivePartitioner, mode='singlescore', models=models, judge_models=judge_models),
|
||||
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
|
||||
)
|
||||
|
||||
summarizer = dict(type=AlignmentBenchSummarizer)
|
||||
|
@ -2,7 +2,6 @@ from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .datasets.subjective.multiround.mtbench_single_judge_diff_temp import subjective_datasets
|
||||
# from .datasets.subjective.multiround.mtbench_pair_judge import subjective_datasets
|
||||
|
||||
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
|
||||
from opencompass.models.openai_api import OpenAIAllesAPIN
|
||||
@ -62,7 +61,7 @@ datasets = [*subjective_datasets]
|
||||
# -------------Evaluation Stage ----------------------------------------
|
||||
|
||||
## ------------- JudgeLLM Configuration
|
||||
judge_model = dict(
|
||||
judge_models = [dict(
|
||||
abbr='GPT4-Turbo',
|
||||
type=OpenAI,
|
||||
path='gpt-4-0613', # To compare with the official leaderboard, please use gpt4-0613
|
||||
@ -73,23 +72,12 @@ judge_model = dict(
|
||||
max_seq_len=2048,
|
||||
batch_size=8,
|
||||
temperature=0,
|
||||
)
|
||||
## ------------- Evaluation Configuration
|
||||
# ## pair evaluation
|
||||
# eval = dict(
|
||||
# partitioner=dict(
|
||||
# type=SubjectiveSizePartitioner, max_task_size=100, mode='m2n', base_models=[gpt4], compare_models=models
|
||||
# ),
|
||||
# runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
|
||||
# )
|
||||
|
||||
# summarizer = dict(type=MTBenchSummarizer, judge_type='pair')
|
||||
|
||||
)]
|
||||
|
||||
## single evaluation
|
||||
eval = dict(
|
||||
partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models),
|
||||
runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
|
||||
partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models, judge_models=judge_models),
|
||||
runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask)),
|
||||
)
|
||||
|
||||
summarizer = dict(type=MTBenchSummarizer, judge_type='single')
|
||||
|
@ -27,10 +27,12 @@ def extract_dicts(data):
|
||||
return predictions
|
||||
|
||||
|
||||
def order_preds_and_record_references(predictions,
|
||||
references,
|
||||
infer_order,
|
||||
seed=2680):
|
||||
def order_preds_and_record_references(
|
||||
predictions,
|
||||
references,
|
||||
infer_order,
|
||||
seed=666,
|
||||
):
|
||||
"""Order predictions based on args and recording regrading references.
|
||||
|
||||
Args:
|
||||
@ -85,17 +87,19 @@ class LMEvaluator:
|
||||
prompt_template: ConfigDict,
|
||||
judge_cfg: ConfigDict,
|
||||
output_path: str,
|
||||
infer_order: Optional[str] = 'random',
|
||||
meta_review_prompt_template: Optional[ConfigDict] = None,
|
||||
dataset_cfg: Optional[ConfigDict] = None,
|
||||
postprocessor: ConfigDict = dict(type=first_number_postprocess)
|
||||
) -> None:
|
||||
assert infer_order in ['random', 'double']
|
||||
self.output_path = output_path
|
||||
out_dir, out_name = osp.split(output_path)
|
||||
if not out_dir:
|
||||
out_dir = './'
|
||||
|
||||
self.prompt_tmpl = ICL_PROMPT_TEMPLATES.build(prompt_template)
|
||||
if meta_review_prompt_template is not None:
|
||||
self.meta_review_prompt_tmpl = ICL_PROMPT_TEMPLATES.build(
|
||||
meta_review_prompt_template)
|
||||
|
||||
max_out_len = judge_cfg.get('max_out_len', None)
|
||||
batch_size = judge_cfg.get('batch_size', None)
|
||||
@ -108,16 +112,20 @@ class LMEvaluator:
|
||||
self.postprocessor = get_type_from_cfg(postprocessor)
|
||||
self.logger = get_logger()
|
||||
self.dataset_cfg = dataset_cfg
|
||||
self.infer_order = infer_order
|
||||
|
||||
def score(self, predictions, references: Optional[List] = None) -> Dict:
|
||||
def score(self,
|
||||
predictions,
|
||||
judgements: Optional[List] = None,
|
||||
references: Optional[List] = None,
|
||||
meta: Optional[bool] = False,
|
||||
infer_order: Optional[str] = 'random') -> Dict:
|
||||
dup_indices = []
|
||||
if type(predictions) == list:
|
||||
"""Apply to multi-model comparison."""
|
||||
references = [{} for _ in range(len(predictions[0]['model_preds']))
|
||||
] if references is None else references
|
||||
predictions, references = order_preds_and_record_references(
|
||||
predictions, references, self.infer_order)
|
||||
predictions, references, infer_order)
|
||||
|
||||
# calculate the number of duplicated predictions
|
||||
total_predictions_num = len(predictions[0])
|
||||
@ -135,6 +143,9 @@ class LMEvaluator:
|
||||
] if references is None else references
|
||||
predictions = [predictions['model_preds']]
|
||||
|
||||
# Due to the rarity of identical predictions, we have temporarily disabled the plagiarism detection feature.
|
||||
dup_indices = []
|
||||
|
||||
if len(dup_indices) != 0:
|
||||
# remove duplicated predictions
|
||||
for index in sorted(dup_indices, reverse=True):
|
||||
@ -149,6 +160,14 @@ class LMEvaluator:
|
||||
for i in range(len(predictions)):
|
||||
key = 'prediction' if i == 0 else f'prediction{i + 1}'
|
||||
pred_dict[key] = predictions[i]
|
||||
if judgements:
|
||||
for i in range(len(judgements)):
|
||||
key = 'judgement' if i == 0 else f'judgement{i + 1}'
|
||||
pred_dict[key] = judgements[i]['model_preds']
|
||||
for j in range(len(references)):
|
||||
references[j]['judge_model' +
|
||||
str(i + 1)] = judgements[i]['model_name']
|
||||
|
||||
elif isinstance(
|
||||
predictions[0][0], list
|
||||
): #multi round for format like [[[{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}], [{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}]]]
|
||||
@ -158,11 +177,13 @@ class LMEvaluator:
|
||||
key = 'prediction' if i == 0 else f'prediction{i}'
|
||||
key += '_r' + str(j + 1)
|
||||
pred_dict[key] = multiround_predictions[j]
|
||||
|
||||
if judgements:
|
||||
raise NotImplementedError(
|
||||
'Meta-review judge is not supported on multi-round datasets')
|
||||
if self.dataset_cfg:
|
||||
dataset = build_dataset_from_cfg(self.dataset_cfg)
|
||||
|
||||
if self.infer_order == 'double':
|
||||
if infer_order == 'double':
|
||||
new_ds = {
|
||||
k: dataset.test[k] * 2
|
||||
for k in dataset.test.column_names
|
||||
@ -179,7 +200,6 @@ class LMEvaluator:
|
||||
print(
|
||||
f'Among {total_predictions_num} total predictions, {len(dup_indices)} predictions are exactly the same and have been removed!'
|
||||
)
|
||||
|
||||
for k, v in pred_dict.items():
|
||||
dataset.reader.dataset['test'] = dataset.test.add_column(k, v)
|
||||
dataset.reader.input_columns.append(k)
|
||||
@ -201,8 +221,13 @@ class LMEvaluator:
|
||||
**pred_dict)
|
||||
dataset.reader.output_column = 'reference'
|
||||
retriever = ZeroRetriever(dataset)
|
||||
self.inferencer.inference(retriever=retriever,
|
||||
prompt_template=self.prompt_tmpl)
|
||||
if meta:
|
||||
self.inferencer.inference(
|
||||
retriever=retriever,
|
||||
prompt_template=self.meta_review_prompt_tmpl)
|
||||
else:
|
||||
self.inferencer.inference(retriever=retriever,
|
||||
prompt_template=self.prompt_tmpl)
|
||||
|
||||
output = mmengine.load(self.output_path)
|
||||
return self.postprocess(output)
|
||||
|
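For orientation, a hedged usage sketch of the updated LMEvaluator.score signature (the evaluator instance, refs, and all strings below are placeholders; the meta pass additionally requires meta_review_prompt_template to be configured):

predictions = [
    {'model_name': 'model-A', 'model_preds': ['A answer 1', 'A answer 2']},
    {'model_name': 'model-B', 'model_preds': ['B answer 1', 'B answer 2']},
]
judgements = [
    {'model_name': 'judge-1', 'model_preds': ['verdict 1', 'verdict 2']},
    {'model_name': 'judge-2', 'model_preds': ['verdict 1', 'verdict 2']},
]
# ordinary judging pass
evaluator.score(predictions, references=refs, infer_order='double')
# meta-review pass, consuming the judgements produced by the first pass
evaluator.score(predictions, judgements=judgements, references=refs,
                meta=True, infer_order='double')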
@ -1,3 +1,4 @@
|
||||
# flake8: noqa: E501
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
@ -81,11 +82,21 @@ class BasePartitioner:
|
||||
work_dir=work_dir,
|
||||
out_dir=self.out_dir,
|
||||
add_cfg=add_cfg)
|
||||
|
||||
self.logger.info(f'Partitioned into {len(tasks)} tasks.')
|
||||
for i, task in enumerate(tasks):
|
||||
self.logger.debug(f'Task {i}: {task_abbr_from_cfg(task)}')
|
||||
|
||||
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
|
||||
tasks[0], list):
|
||||
self.logger.info(
|
||||
f'Now we are in the subjective evaluation! Partitioned into 2 stages: {len(tasks[0])} tasks in the first stage and {len(tasks[1])} tasks in the second stage.'
|
||||
)
|
||||
cnt = 0
|
||||
for task_part in tasks:
|
||||
for task in task_part:
|
||||
self.logger.debug(
|
||||
f'Task {cnt}: {task_abbr_from_cfg(task)}')
|
||||
cnt += 1
|
||||
else:
|
||||
self.logger.info(f'Partitioned into {len(tasks)} tasks.')
|
||||
for i, task in enumerate(tasks):
|
||||
self.logger.debug(f'Task {i}: {task_abbr_from_cfg(task)}')
|
||||
return tasks
|
||||
|
||||
def parse_model_dataset_args(self, cfg: ConfigDict):
|
||||
|
@ -1,14 +1,20 @@
|
||||
# flake8: noqa: E501
|
||||
import copy
|
||||
import os.path as osp
|
||||
from itertools import combinations, product
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from mmengine.config import ConfigDict
|
||||
|
||||
from opencompass.registry import PARTITIONERS
|
||||
from opencompass.utils import (deal_with_judge_model_abbr,
|
||||
get_infer_output_path, model_abbr_from_cfg)
|
||||
|
||||
from .naive import NaivePartitioner
|
||||
|
||||
|
||||
def remove_duplicate_pairs(model_combinations):
|
||||
# For compare mode, we need to remove redundant pairs first
|
||||
combo_dict = {}
|
||||
for i, combo in enumerate(model_combinations):
|
||||
sorted_names = tuple(sorted((combo[0]['abbr'], combo[1]['abbr'])))
|
||||
@ -20,6 +26,82 @@ def remove_duplicate_pairs(model_combinations):
|
||||
return new_model_combinations
|
||||
|
||||
|
||||
def replicate_tasks_with_judge_models(tasks, judge_models, meta_judge_model):
|
||||
# When all tasks are already partitioned, we add judge_models and meta_judge_model as additional args.
|
||||
if meta_judge_model:
|
||||
replicated_tasks = [[], []]
|
||||
else:
|
||||
replicated_tasks = []
|
||||
for task in tasks:
|
||||
replicated_task_dicts = [task.copy() for _ in range(len(judge_models))]
|
||||
for idx, replicated_task in enumerate(replicated_task_dicts):
|
||||
replicated_task['judge_model'] = judge_models[idx]
|
||||
if meta_judge_model:
|
||||
meta_task = task.copy()
|
||||
meta_task['meta_judge_model'] = meta_judge_model
|
||||
meta_task['judge_models'] = judge_models
|
||||
replicated_tasks[1].append(meta_task)
|
||||
replicated_tasks[0].extend(replicated_task_dicts)
|
||||
else:
|
||||
replicated_tasks.extend(replicated_task_dicts)
|
||||
return replicated_tasks
|
||||
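A small, self-contained illustration of what replicate_tasks_with_judge_models returns (the task and judge dicts are bare placeholders, not real OpenCompass configs; the function is importable from opencompass.partitioners.sub_naive in this commit):

from opencompass.partitioners.sub_naive import replicate_tasks_with_judge_models

tasks = [{'models': ['model-A'], 'datasets': [['ds1']]}]
judge_models = [{'abbr': 'judge-1'}, {'abbr': 'judge-2'}]
meta_judge = {'abbr': 'meta-judge'}

flat = replicate_tasks_with_judge_models(tasks, judge_models, None)
# flat: one copy of each task per judge, each copy carrying task['judge_model']

staged = replicate_tasks_with_judge_models(tasks, judge_models, meta_judge)
# staged[0]: the per-judge copies, as above
# staged[1]: one extra copy per task carrying 'meta_judge_model' plus the full
#            'judge_models' list, to be run as the second (meta-review) stage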
|
||||
|
||||
def remove_already_tasks(tasks, work_dir, meta_judge_model):
|
||||
# Check and remove the already finished subjective evaluation tasks
|
||||
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
|
||||
tasks[0], list):
|
||||
tasks_to_keep = [[], []]
|
||||
for i in range(2):
|
||||
for task in tasks[i]:
|
||||
temp_task = copy.deepcopy(task)
|
||||
to_delete_index = [
|
||||
]  # Handle the case where the partition strategy is not 'split': a single task may then contain multiple datasets, so on restart we need to remove the ones that are already finished.
|
||||
for idx, dataset in enumerate(task['datasets'][0]):
|
||||
if i == 0:
|
||||
filename = get_infer_output_path(
|
||||
deal_with_judge_model_abbr(task['models'][0],
|
||||
task['judge_model'],
|
||||
False), dataset,
|
||||
osp.join(work_dir, 'results'))
|
||||
else:
|
||||
filename = get_infer_output_path(
|
||||
deal_with_judge_model_abbr(
|
||||
task['models'][0], task['meta_judge_model'],
|
||||
True), dataset, osp.join(work_dir, 'results'))
|
||||
if osp.exists(filename):
|
||||
to_delete_index.append(idx)
|
||||
temp_task['datasets'][0] = [
|
||||
temp_task['datasets'][0][j]
|
||||
for j in range(len(temp_task['datasets'][0]))
|
||||
if j not in to_delete_index
|
||||
]
|
||||
if len(temp_task['datasets'][0]) != 0:
|
||||
tasks_to_keep[i].append(temp_task)
|
||||
else:
|
||||
tasks_to_keep = []
|
||||
for task in tasks:
|
||||
temp_task = copy.deepcopy(task)
|
||||
to_delete_index = [
|
||||
]  # Handle the case where the partition strategy is not 'split': a single task may then contain multiple datasets, so on restart we need to remove the ones that are already finished.
|
||||
for idx, dataset in enumerate(task['datasets'][0]):
|
||||
filename = get_infer_output_path(
|
||||
deal_with_judge_model_abbr(task['models'][0],
|
||||
task['judge_model']), dataset,
|
||||
osp.join(work_dir, 'results'))
|
||||
if osp.exists(filename):
|
||||
to_delete_index.append(idx)
|
||||
# Remove the already done tasks
|
||||
temp_task['datasets'][0] = [
|
||||
temp_task['datasets'][0][j]
|
||||
for j in range(len(temp_task['datasets'][0]))
|
||||
if j not in to_delete_index
|
||||
]
|
||||
if len(temp_task['datasets'][0]) != 0:
|
||||
tasks_to_keep.append(temp_task)
|
||||
return tasks_to_keep
|
||||
|
||||
|
||||
@PARTITIONERS.register_module()
|
||||
class SubjectiveNaivePartitioner(NaivePartitioner):
|
||||
"""Naive task partitioner for subjective evaluation. Compared to
|
||||
@ -37,15 +119,22 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
|
||||
models: Optional[List[ConfigDict]] = [],
|
||||
base_models: Optional[List[ConfigDict]] = [],
|
||||
compare_models: Optional[List[ConfigDict]] = [],
|
||||
judge_models: Optional[List[ConfigDict]] = [],
|
||||
meta_judge_model: Optional[ConfigDict] = None,
|
||||
model_pairs: Optional[List[Tuple]] = None,
|
||||
keep_keys: Optional[List[str]] = None):
|
||||
keep_keys: Optional[List[str]] = None,
|
||||
infer_order: Optional[str] = 'random'):
|
||||
super().__init__(out_dir=out_dir, keep_keys=keep_keys)
|
||||
assert mode in ['singlescore', 'allpair', 'm2n', 'fixed']
|
||||
assert infer_order in ['random', 'double']
|
||||
self.mode = mode
|
||||
self.models = models
|
||||
self.base_models = base_models
|
||||
self.compare_models = compare_models
|
||||
self.model_pairs = model_pairs
|
||||
self.judge_models = judge_models
|
||||
self.meta_judge_model = meta_judge_model
|
||||
self.infer_order = infer_order
|
||||
|
||||
def get_model_combinations(
|
||||
self,
|
||||
@ -97,14 +186,35 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
|
||||
"""
|
||||
models = self.models if self.models != [] else models
|
||||
base_models, compare_models = self.base_models, self.compare_models
|
||||
judge_models, meta_judge_model = self.judge_models, self.meta_judge_model
|
||||
if self.mode == 'singlescore':
|
||||
models = models
|
||||
else:
|
||||
models = self.get_model_combinations(models, base_models,
|
||||
compare_models)
|
||||
model_dataset_combinations = [{'models': models, 'datasets': datasets}]
|
||||
return super().partition(
|
||||
tasks = super().partition(
|
||||
model_dataset_combinations=model_dataset_combinations,
|
||||
work_dir=work_dir,
|
||||
out_dir=out_dir,
|
||||
add_cfg=add_cfg)
|
||||
|
||||
# We need to add judge models and meta-judge-model as new tasks
|
||||
# When there is no meta-judge-model, we assign all judge models to each task
|
||||
# When there is a meta-judge-model, we add an additional task stage
|
||||
tasks = replicate_tasks_with_judge_models(tasks, judge_models,
|
||||
meta_judge_model)
|
||||
|
||||
# We also need to check and remove the already done tasks
|
||||
tasks = remove_already_tasks(tasks, work_dir, meta_judge_model)
|
||||
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
|
||||
tasks[0], list):
|
||||
# Two-stage tasks, i.e. with a meta-review judge
|
||||
for task_stage in tasks:
|
||||
for task in task_stage:
|
||||
task['infer_order'] = self.infer_order
|
||||
else:
|
||||
# Single-stage tasks, i.e. only the ordinary review judge
|
||||
for task in tasks:
|
||||
task['infer_order'] = self.infer_order
|
||||
return tasks
|
||||
|
@ -1,3 +1,4 @@
|
||||
# flake8: noqa: E501
|
||||
import copy
|
||||
import math
|
||||
import os.path as osp
|
||||
@ -11,7 +12,8 @@ from opencompass.registry import PARTITIONERS
|
||||
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||
get_infer_output_path)
|
||||
|
||||
from .sub_naive import SubjectiveNaivePartitioner
|
||||
from .sub_naive import (SubjectiveNaivePartitioner, remove_already_tasks,
|
||||
replicate_tasks_with_judge_models)
|
||||
|
||||
|
||||
@PARTITIONERS.register_module()
|
||||
@ -40,19 +42,25 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
|
||||
models: Optional[List[ConfigDict]] = [],
|
||||
base_models: Optional[List[ConfigDict]] = [],
|
||||
compare_models: Optional[List[ConfigDict]] = [],
|
||||
judge_models: Optional[List[ConfigDict]] = [],
|
||||
meta_judge_model: Optional[ConfigDict] = None,
|
||||
model_pairs: Optional[List[Tuple]] = None,
|
||||
max_task_size: int = 40000,
|
||||
gen_task_coef: int = 20,
|
||||
strategy: str = 'heuristic',
|
||||
dataset_size_path: str = '.cache/dataset_size.json',
|
||||
keep_keys: Optional[List[str]] = None):
|
||||
keep_keys: Optional[List[str]] = None,
|
||||
infer_order: Optional[str] = 'random'):
|
||||
super().__init__(out_dir=out_dir,
|
||||
keep_keys=keep_keys,
|
||||
mode=mode,
|
||||
models=models,
|
||||
base_models=base_models,
|
||||
compare_models=compare_models,
|
||||
model_pairs=model_pairs)
|
||||
judge_models=judge_models,
|
||||
meta_judge_model=meta_judge_model,
|
||||
model_pairs=model_pairs,
|
||||
infer_order=infer_order)
|
||||
self.max_task_size = max_task_size
|
||||
self.gen_task_coef = gen_task_coef
|
||||
self.dataset_size_path = dataset_size_path
|
||||
@ -96,13 +104,13 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
|
||||
"""
|
||||
models = self.models if self.models != [] else models
|
||||
base_models, compare_models = self.base_models, self.compare_models
|
||||
judge_models, meta_judge_model = self.judge_models, self.meta_judge_model
|
||||
if self.mode == 'singlescore':
|
||||
models = models
|
||||
else:
|
||||
models = super().get_model_combinations(models, base_models,
|
||||
compare_models)
|
||||
model_dataset_combinations = [{'models': models, 'datasets': datasets}]
|
||||
|
||||
tasks = []
|
||||
for comb in model_dataset_combinations:
|
||||
comb['datasets'] = sorted(comb['datasets'],
|
||||
@ -113,8 +121,8 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
|
||||
for dataset in comb['datasets']:
|
||||
filename = get_infer_output_path(model, dataset, out_dir)
|
||||
# skip the task if the task output exists
|
||||
if osp.exists(filename):
|
||||
continue
|
||||
# if osp.exists(filename):
|
||||
# continue
|
||||
dataset_size = self.get_cost(dataset)
|
||||
if dataset_size > self.max_task_size:
|
||||
root, ext = osp.splitext(filename)
|
||||
@ -151,6 +159,21 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
|
||||
'work_dir': work_dir,
|
||||
**add_cfg
|
||||
}))
|
||||
|
||||
tasks = replicate_tasks_with_judge_models(tasks, judge_models,
|
||||
meta_judge_model)
|
||||
tasks = remove_already_tasks(tasks, work_dir, meta_judge_model)
|
||||
|
||||
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
|
||||
tasks[0], list):
|
||||
# Two-stage tasks, i.e. with a meta-review judge
|
||||
for task_stage in tasks:
|
||||
for task in task_stage:
|
||||
task['infer_order'] = self.infer_order
|
||||
else:
|
||||
# Single-stage tasks, i.e. only the ordinary review judge
|
||||
for task in tasks:
|
||||
task['infer_order'] = self.infer_order
|
||||
return tasks
|
||||
|
||||
@property
|
||||
|
@ -309,7 +309,7 @@ class AlignmentBenchSummarizer:
|
||||
self.eval_model_abbrs = [
|
||||
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
|
||||
]
|
||||
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
|
||||
self.judge_models = self.cfg.get('judge_models', None)
|
||||
self.judge_type = judge_type
|
||||
assert self.judge_type in [
|
||||
'general', 'autoj', 'judgelm', 'general_plus'
|
||||
@ -333,33 +333,36 @@ class AlignmentBenchSummarizer:
|
||||
Returns:
|
||||
pd.DataFrame: The summary results.
|
||||
"""
|
||||
dataset_cfgs = self.cfg['datasets']
|
||||
output_dir, results_folder = get_outdir(self.cfg, time_str)
|
||||
fout_flag, fout_flag2 = 0, 0
|
||||
for eval_model_abbr in self.eval_model_abbrs:
|
||||
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
|
||||
subdir_path = os.path.join(results_folder, subdir)
|
||||
if os.path.isdir(subdir_path):
|
||||
model, judge_model = eval_model_abbr, self.judge_abbr
|
||||
if self.judge_type == 'general':
|
||||
fout = osp.join(
|
||||
output_dir,
|
||||
'judged-by--' + judge_model + '-dimension.csv')
|
||||
fout2 = osp.join(
|
||||
output_dir,
|
||||
'judged-by--' + judge_model + '-capability.csv')
|
||||
for dataset in dataset_cfgs:
|
||||
judged_answers, references = get_judgeanswer_and_reference(
|
||||
dataset, subdir_path, self.judge_function)
|
||||
for judge_model in self.judge_models:
|
||||
judge_abbr = model_abbr_from_cfg(judge_model)
|
||||
dataset_cfgs = self.cfg['datasets']
|
||||
output_dir, results_folder = get_outdir(self.cfg, time_str)
|
||||
fout_flag, fout_flag2 = 0, 0
|
||||
for eval_model_abbr in self.eval_model_abbrs:
|
||||
subdir = eval_model_abbr + '_judged-by--' + judge_abbr
|
||||
subdir_path = os.path.join(results_folder, subdir)
|
||||
if os.path.isdir(subdir_path):
|
||||
model = eval_model_abbr
|
||||
if self.judge_type == 'general':
|
||||
get_dimension_results(judged_answers, references, fout,
|
||||
fout_flag, model)
|
||||
fout_flag += 1
|
||||
get_capability_results(judged_answers, references, fout2,
|
||||
fout_flag2, model, self.category)
|
||||
fout_flag2 += 1
|
||||
else:
|
||||
print(subdir_path + ' does not exist! Please check!')
|
||||
fout = osp.join(
|
||||
output_dir,
|
||||
'judged-by--' + judge_abbr + '-dimension.csv')
|
||||
fout2 = osp.join(
|
||||
output_dir,
|
||||
'judged-by--' + judge_abbr + '-capability.csv')
|
||||
for dataset in dataset_cfgs:
|
||||
judged_answers, references = get_judgeanswer_and_reference(
|
||||
dataset, subdir_path, self.judge_function)
|
||||
if self.judge_type == 'general':
|
||||
get_dimension_results(judged_answers, references,
|
||||
fout, fout_flag, model)
|
||||
fout_flag += 1
|
||||
get_capability_results(judged_answers, references,
|
||||
fout2, fout_flag2, model,
|
||||
self.category)
|
||||
fout_flag2 += 1
|
||||
else:
|
||||
print(subdir_path + ' does not exist! Please check!')
|
||||
if self.judge_type == 'general':
|
||||
with open(fout, 'r') as f:
|
||||
x = from_csv(f)
|
||||
|
@ -82,7 +82,8 @@ class AlpacaSummarizer:
|
||||
self.cfg = config
|
||||
self.base_models = self.cfg['eval']['partitioner']['base_models']
|
||||
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
|
||||
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
|
||||
self.judge_abbr = model_abbr_from_cfg(
|
||||
self.cfg['judge_models'][0]) # We will reorganize the summarizers
|
||||
self.judge_type = judge_type
|
||||
assert self.judge_type in ['v1', 'v2']
|
||||
self.judge_map = {
|
||||
|
@ -67,7 +67,9 @@ class CompassArenaSummarizer:
|
||||
self.cfg = config
|
||||
self.base_models = self.cfg['eval']['partitioner']['base_models']
|
||||
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
|
||||
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
|
||||
self.judge_models = self.cfg.get('judge_models', None)
|
||||
self.meta_judge_model = self.cfg.eval.partitioner.get(
|
||||
'meta_judge_model', None)
|
||||
self.judge_type = judge_type
|
||||
assert self.judge_type in ['general']
|
||||
self.judge_map = {
|
||||
@ -95,109 +97,135 @@ class CompassArenaSummarizer:
|
||||
product(self.base_models, self.compare_models))
|
||||
unique_combinations = remove_duplicate_pairs(
|
||||
[combo for combo in model_combinations if combo[0] != combo[1]])
|
||||
judge_model = self.judge_abbr
|
||||
|
||||
fout_list = []
|
||||
for dataset in dataset_cfgs:
|
||||
dataset_abbr = dataset_abbr_from_cfg(dataset)
|
||||
fout = osp.join(
|
||||
output_dir, 'judged-by--' + judge_model + '-' + dataset_abbr +
|
||||
'-report.csv')
|
||||
fout_list.append(fout)
|
||||
for model_pair in unique_combinations:
|
||||
model1, model2, = model_pair[0]['abbr'], model_pair[1]['abbr'],
|
||||
subdir = model1 + '_' + model2 + '_judged-by--' + judge_model
|
||||
subdir_path = os.path.join(results_folder, subdir)
|
||||
if os.path.isdir(subdir_path):
|
||||
judged_answers, references = get_judgeanswer_and_reference(
|
||||
dataset,
|
||||
subdir_path,
|
||||
self.judge_function,
|
||||
)
|
||||
if self.check_pos_bias:
|
||||
bias_num = check_position_bias(judged_answers,
|
||||
references)
|
||||
else:
|
||||
bias_num = 0
|
||||
win_model1, win_model2, categories = defaultdict(
|
||||
float), defaultdict(float), defaultdict(float)
|
||||
model1, model2 = references[0]['answer1'], references[0][
|
||||
'answer2']
|
||||
for prediction, reference in zip(judged_answers,
|
||||
references):
|
||||
if self.summary_type == 'single':
|
||||
if prediction == 'A':
|
||||
categories['total'] += 1
|
||||
categories[reference['capability']] += 1
|
||||
if reference['answer1'] == model1:
|
||||
win_model1[reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
else:
|
||||
win_model2[reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
elif prediction == 'B':
|
||||
categories['total'] += 1
|
||||
categories[reference['capability']] += 1
|
||||
if reference['answer1'] == model1:
|
||||
win_model2[reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
else:
|
||||
win_model1[reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
elif self.summary_type == 'half_add':
|
||||
categories['total'] += 1
|
||||
categories[reference['capability']] += 1
|
||||
if prediction == 'A':
|
||||
if reference['answer1'] == model1:
|
||||
win_model1[reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
else:
|
||||
win_model2[reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
elif prediction == 'B':
|
||||
if reference['answer1'] == model1:
|
||||
win_model2[reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
else:
|
||||
win_model1[reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
elif prediction == 'C':
|
||||
win_model1[reference['capability']] += 0.5
|
||||
win_model1['total'] += 0.5
|
||||
win_model2[reference['capability']] += 0.5
|
||||
win_model2['total'] += 0.5
|
||||
for capability in categories:
|
||||
if capability not in win_model1:
|
||||
win_model1[capability] = 0.0
|
||||
else:
|
||||
win_model1[capability] = round(
|
||||
(win_model1[capability] /
|
||||
categories[capability]) * 100, 2)
|
||||
if capability not in win_model2:
|
||||
win_model2[capability] = 0.0
|
||||
else:
|
||||
win_model2[capability] = round(
|
||||
(win_model2[capability] /
|
||||
categories[capability]) * 100, 2)
|
||||
win_model1['position_bias'] = bias_num
|
||||
win_model2['position_bias'] = bias_num
|
||||
scores = {
|
||||
'win_' + model1: win_model1,
|
||||
'win_' + model2: win_model2
|
||||
}
|
||||
rows = list(scores.keys())
|
||||
columns = list(scores[rows[0]].keys())
|
||||
columns.insert(0, columns.pop(columns.index('total')))
|
||||
columns.insert(1,
|
||||
columns.pop(columns.index('position_bias')))
|
||||
with open(fout, 'a+', newline='') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow([model1 + '_vs_' + model2] + columns)
|
||||
for row in rows:
|
||||
writer.writerow(
|
||||
[row] +
|
||||
[scores[row][column] for column in columns])
|
||||
pre_len = len(self.judge_models)
|
||||
if self.meta_judge_model is not None:
|
||||
self.judge_models.append(self.meta_judge_model)
|
||||
meta_judge_model_abbr = model_abbr_from_cfg(self.meta_judge_model)
|
||||
else:
|
||||
meta_judge_model_abbr = None
|
||||
for idx, judge_model in enumerate(self.judge_models):
|
||||
judge_model = model_abbr_from_cfg(judge_model)
|
||||
for dataset in dataset_cfgs:
|
||||
dataset_abbr = dataset_abbr_from_cfg(dataset)
|
||||
if idx == pre_len:
|
||||
fout = osp.join(
|
||||
output_dir, 'summarized-by--' + judge_model + '-' +
|
||||
dataset_abbr + '-report.csv')
|
||||
else:
|
||||
print(subdir_path + ' does not exist! Please check!')
|
||||
fout = osp.join(
|
||||
output_dir, 'judged-by--' + judge_model + '-' +
|
||||
dataset_abbr + '-report.csv')
|
||||
fout_list.append(fout)
|
||||
for model_pair in unique_combinations:
|
||||
model1, model2, = model_pair[0]['abbr'], model_pair[1][
|
||||
'abbr'],
|
||||
if idx == pre_len:
|
||||
subdir = model1 + '_' + model2 + '_summarized-by--' + judge_model
|
||||
else:
|
||||
subdir = model1 + '_' + model2 + '_judged-by--' + judge_model
|
||||
subdir_path = os.path.join(results_folder, subdir)
|
||||
if os.path.isdir(subdir_path):
|
||||
judged_answers, references = get_judgeanswer_and_reference(
|
||||
dataset,
|
||||
subdir_path,
|
||||
self.judge_function,
|
||||
)
|
||||
if self.check_pos_bias:
|
||||
bias_num = check_position_bias(
|
||||
judged_answers, references)
|
||||
else:
|
||||
bias_num = 0
|
||||
win_model1, win_model2, categories = defaultdict(
|
||||
float), defaultdict(float), defaultdict(float)
|
||||
model1, model2 = references[0]['answer1'], references[
|
||||
0]['answer2']
|
||||
for prediction, reference in zip(
|
||||
judged_answers, references):
|
||||
if self.summary_type == 'single':
|
||||
if prediction == 'A':
|
||||
categories['total'] += 1
|
||||
categories[reference['capability']] += 1
|
||||
if reference['answer1'] == model1:
|
||||
win_model1[
|
||||
reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
else:
|
||||
win_model2[
|
||||
reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
elif prediction == 'B':
|
||||
categories['total'] += 1
|
||||
categories[reference['capability']] += 1
|
||||
if reference['answer1'] == model1:
|
||||
win_model2[
|
||||
reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
else:
|
||||
win_model1[
|
||||
reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
elif self.summary_type == 'half_add':
|
||||
categories['total'] += 1
|
||||
categories[reference['capability']] += 1
|
||||
if prediction == 'A':
|
||||
if reference['answer1'] == model1:
|
||||
win_model1[
|
||||
reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
else:
|
||||
win_model2[
|
||||
reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
elif prediction == 'B':
|
||||
if reference['answer1'] == model1:
|
||||
win_model2[
|
||||
reference['capability']] += 1
|
||||
win_model2['total'] += 1
|
||||
else:
|
||||
win_model1[
|
||||
reference['capability']] += 1
|
||||
win_model1['total'] += 1
|
||||
elif prediction == 'C':
|
||||
win_model1[reference['capability']] += 0.5
|
||||
win_model1['total'] += 0.5
|
||||
win_model2[reference['capability']] += 0.5
|
||||
win_model2['total'] += 0.5
|
||||
for capability in categories:
|
||||
if capability not in win_model1:
|
||||
win_model1[capability] = 0.0
|
||||
else:
|
||||
win_model1[capability] = round(
|
||||
(win_model1[capability] /
|
||||
categories[capability]) * 100, 2)
|
||||
if capability not in win_model2:
|
||||
win_model2[capability] = 0.0
|
||||
else:
|
||||
win_model2[capability] = round(
|
||||
(win_model2[capability] /
|
||||
categories[capability]) * 100, 2)
|
||||
win_model1['position_bias'] = bias_num
|
||||
win_model2['position_bias'] = bias_num
|
||||
scores = {
|
||||
'win_' + model1: win_model1,
|
||||
'win_' + model2: win_model2
|
||||
}
|
||||
rows = list(scores.keys())
|
||||
columns = list(scores[rows[0]].keys())
|
||||
columns.insert(0, columns.pop(columns.index('total')))
|
||||
columns.insert(
|
||||
1, columns.pop(columns.index('position_bias')))
|
||||
with open(fout, 'a+', newline='') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow([model1 + '_vs_' + model2] +
|
||||
columns)
|
||||
for row in rows:
|
||||
writer.writerow([row] + [
|
||||
scores[row][column] for column in columns
|
||||
])
|
||||
else:
|
||||
print(subdir_path + ' does not exist! Please check!')
|
||||
for fout in fout_list:
|
||||
with open(fout, 'r') as f:
|
||||
x = from_csv(f)
|
||||
|
@ -98,7 +98,7 @@ class MTBenchSummarizer(CompassArenaSummarizer):
|
||||
self.base_models = self.cfg['eval']['partitioner']['base_models']
|
||||
self.compare_models = self.cfg['eval']['partitioner'][
|
||||
'compare_models']
|
||||
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
|
||||
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
|
||||
self.judge_map = {
|
||||
'single': post_process_mtbench_single,
|
||||
'pair': post_process_mtbench_pair
|
||||
|
@ -1,10 +1,11 @@
|
||||
# flake8: noqa: E501
|
||||
import argparse
|
||||
import copy
|
||||
import fnmatch
|
||||
import os.path as osp
|
||||
import random
|
||||
import time
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
@ -14,6 +15,7 @@ from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
|
||||
from opencompass.tasks.base import BaseTask
|
||||
from opencompass.tasks.openicl_eval import extract_role_pred
|
||||
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
|
||||
deal_with_judge_model_abbr,
|
||||
get_infer_output_path, get_logger,
|
||||
model_abbr_from_cfg, task_abbr_from_cfg)
|
||||
|
||||
@ -35,21 +37,25 @@ class SubjectiveEvalTask(BaseTask):
|
||||
def __init__(self, cfg: ConfigDict):
|
||||
super().__init__(cfg)
|
||||
self.logger = get_logger()
|
||||
judge_cfg = cfg.eval.runner.task.get('judge_cfg', {})
|
||||
if type(judge_cfg) != ConfigDict:
|
||||
print('*' * 100)
|
||||
print('Due to different Judge model needs different summarizer and'
|
||||
" prompts, we don't support multi judge model evaluation at "
|
||||
'one time, please do not use list to set your judge cfg, jus'
|
||||
't use a dict or list[0] should be fine. If you want to eval'
|
||||
'uation multi judge model in one script, we suggest you to u'
|
||||
'se a bash or bat script to start multi configs evaluation!')
|
||||
print('*' * 100)
|
||||
assert type(judge_cfg) == ConfigDict
|
||||
judge_cfg = cfg.get('judge_model', None)
|
||||
meta_judge_cfg = cfg.get('meta_judge_model', None)
|
||||
judge_models = cfg.get('judge_models', None)
|
||||
|
||||
if judge_cfg is None and meta_judge_cfg is None:
|
||||
assert judge_cfg is not None, 'Both judge_cfg and meta_judge_cfg are None, but judge_models must be provided.'
|
||||
|
||||
if meta_judge_cfg is not None:
|
||||
assert judge_models is not None, 'meta_judge_cfg is provided, but judge_models are missing.'
|
||||
judge_cfg = meta_judge_cfg  # Replace judge_cfg with meta_judge_cfg when it is not None
|
||||
self.meta = True
|
||||
else:
|
||||
self.meta = False
|
||||
run_cfg = judge_cfg.get('run_cfg', {})
|
||||
self.num_gpus = run_cfg.get('num_gpus', 0)
|
||||
self.num_procs = run_cfg.get('num_procs', 1)
|
||||
self.judge_cfg = copy.deepcopy(judge_cfg)
|
||||
self.judge_models = judge_models
|
||||
self.infer_order = cfg.get('infer_order')
|
||||
self.given_pred = cfg.eval.get('given_pred', [])
|
||||
|
||||
def get_command(self, cfg_path, template):
|
||||
@ -78,17 +84,15 @@ class SubjectiveEvalTask(BaseTask):
|
||||
# Load Dataset
|
||||
eval_cfg = dataset_cfg.get('eval_cfg')
|
||||
output_column = dataset_cfg['reader_cfg']['output_column']
|
||||
if type(model_cfg) == ConfigDict:
|
||||
model_cfg = (model_cfg, )
|
||||
model_cfg += ({
|
||||
'abbr':
|
||||
'judged-by--' + model_abbr_from_cfg(self.judge_cfg)
|
||||
}, )
|
||||
out_path = get_infer_output_path(
|
||||
model_cfg, dataset_cfg, osp.join(self.work_dir, 'results'))
|
||||
deal_with_judge_model_abbr(model_cfg, self.judge_cfg,
|
||||
self.meta), dataset_cfg,
|
||||
osp.join(self.work_dir, 'results'))
|
||||
if osp.exists(out_path):
|
||||
continue
|
||||
self._score(model_cfg, dataset_cfg, eval_cfg, output_column)
|
||||
|
||||
self._score(model_cfg, dataset_cfg, eval_cfg, output_column,
|
||||
self.meta)
|
||||
|
||||
def _load_model_pred(
|
||||
self,
|
||||
@ -194,7 +198,139 @@ class SubjectiveEvalTask(BaseTask):
|
||||
'model_preds': pred_strs
|
||||
}
|
||||
|
||||
def _score(self, model_cfg, dataset_cfg, eval_cfg, output_column):
|
||||
def _load_model_judgements(
|
||||
self,
|
||||
model_cfg: Union[ConfigDict, List[ConfigDict]],
|
||||
dataset_cfg: ConfigDict,
|
||||
eval_cfg: ConfigDict,
|
||||
judge_cfg: Union[ConfigDict, List[ConfigDict]],
|
||||
) -> Union[None, List[str]]:
|
||||
|
||||
if isinstance(judge_cfg, (tuple, list)):
|
||||
return [
|
||||
self._load_model_judgements(model_cfg, dataset_cfg, eval_cfg,
|
||||
j) for j in judge_cfg
|
||||
]
|
||||
|
||||
pred_strs = None
|
||||
model_cfg = [model_cfg] if isinstance(model_cfg,
|
||||
ConfigDict) else model_cfg
|
||||
# There are 5 situations that we need to handle:
|
||||
# 1.There are no partitions in infer and judge stage
|
||||
# 2.No partition in infer stage, but use partition in judge stage
|
||||
# 3.Use partition in infer stage, but not use partition in judge stage
|
||||
# 4.Use both partition, with same partition size
|
||||
# 5.Use both partition, but different partition size
|
||||
|
||||
# If take SubjectSizePartition, get new filename without _0
|
||||
if 'test_range' in dataset_cfg['reader_cfg']:
|
||||
filename = get_infer_output_path(
|
||||
deal_with_judge_model_abbr([m for m in model_cfg], judge_cfg),
|
||||
dataset_cfg, osp.join(self.work_dir, 'results'))
|
||||
root, ext = osp.splitext(filename)
|
||||
last_underscore_index = root.rfind('_')
|
||||
root = root[:last_underscore_index]
|
||||
filename = root + ext
|
||||
# If take SubjectNaivePartition, get filename
|
||||
else:
|
||||
filename = get_infer_output_path(
|
||||
deal_with_judge_model_abbr([m for m in model_cfg], judge_cfg),
|
||||
dataset_cfg, osp.join(self.work_dir, 'results'))
|
||||
# Get partition name
|
||||
root, ext = osp.splitext(filename)
|
||||
partial_filename = root + '_0' + ext
|
||||
|
||||
# If no judgements are found in the results dir
|
||||
if not osp.exists(osp.realpath(filename)) and not osp.exists(
|
||||
osp.realpath(partial_filename)):
|
||||
return {'error': 'No judgements found.'}
|
||||
else:
|
||||
# If use Naive partition in infer stage
|
||||
if osp.exists(osp.realpath(filename)):
|
||||
preds = mmengine.load(filename)
|
||||
pred_strs = [
|
||||
preds[str(i)]['prediction'] for i in range(len(preds))
|
||||
]
|
||||
# If use Size partition in infer stage
|
||||
else:
|
||||
filename = partial_filename
|
||||
pred_strs = []
|
||||
i = 1
|
||||
while osp.exists(osp.realpath(filename)):
|
||||
preds = mmengine.load(filename)
|
||||
filename = root + f'_{i}' + ext
|
||||
i += 1
|
||||
pred_strs += [
|
||||
preds[str(i)]['prediction'] for i in range(len(preds))
|
||||
]
|
||||
# Get all judgements in pred_strs
|
||||
# If take SubjectSizePartition, get new pred_strs based on test_range
|
||||
if 'test_range' in dataset_cfg['reader_cfg']:
|
||||
test_range = dataset_cfg['reader_cfg']['test_range']
|
||||
if self.infer_order == 'double':
|
||||
# When set infer_order as double, we need to select the judgements to meet the predctions which will be doubled later
|
||||
start = 0
|
||||
end = None
|
||||
pred_strs_length = len(pred_strs)
|
||||
# Split the string on ':'; test_range is a string shaped like '[0:15]'
|
||||
parts = test_range.strip('[]').split(':')
|
||||
# Check if the start index is provided
|
||||
if parts[0]:
|
||||
start = int(parts[0])
|
||||
# Check if the end index is provided
|
||||
if len(parts) > 1 and parts[1]:
|
||||
end = int(parts[1])
|
||||
else:
|
||||
# If the end is not provided, determine the default end based on the length of 'pred_strs'
|
||||
end = int(pred_strs_length / 2)
|
||||
assert pred_strs_length % 2 == 0, "Since you have set the infer_order as 'double', the length of 'pred_strs' must be even."
|
||||
assert end <= pred_strs_length / 2, "The 'end' value must not exceed half of the 'pred_strs' length."
|
||||
# Reset the newly start and end
|
||||
start *= 2
|
||||
end *= 2
|
||||
pred_strs = eval('pred_strs[' + str(start) + ':' + str(end) +
|
||||
']')
|
||||
else:
|
||||
pred_strs = eval('pred_strs' + test_range)
|
||||
# If take SubjectNaivePartition, get all pred_strs
|
||||
else:
|
||||
pred_strs = pred_strs
|
||||
if ('pred_role' in eval_cfg and 'meta_template' in judge_cfg
|
||||
and not MODELS.get(judge_cfg['type']).is_api
|
||||
and isinstance(pred_strs[0], str)):
|
||||
# Create a prompt template for role config parsing
|
||||
from opencompass.models.base import LMTemplateParser
|
||||
parser = LMTemplateParser(judge_cfg['meta_template'])
|
||||
role = parser.roles[eval_cfg['pred_role']]
|
||||
pred_strs = [
|
||||
extract_role_pred(pred, role.get('begin', None),
|
||||
role.get('end', None)) for pred in pred_strs
|
||||
]
|
||||
|
||||
# Postprocess predictions if necessary
|
||||
ds_abbr = dataset_abbr_from_cfg(dataset_cfg)
|
||||
model_postprocessors = judge_cfg.get('pred_postprocessor', {})
|
||||
pred_postprocessor = None
|
||||
for pattern in model_postprocessors.keys():
|
||||
if fnmatch.fnmatch(ds_abbr, pattern):
|
||||
pred_postprocessor = model_postprocessors[pattern]
|
||||
break
|
||||
if 'pred_postprocessor' in eval_cfg or pred_postprocessor:
|
||||
kwargs = pred_postprocessor or eval_cfg['pred_postprocessor']
|
||||
proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
|
||||
pred_strs = [proc(s, **kwargs) for s in pred_strs]
|
||||
|
||||
return {
|
||||
'model_name': model_abbr_from_cfg(judge_cfg),
|
||||
'model_preds': pred_strs
|
||||
}
|
||||
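A quick worked example (hypothetical numbers) of the 'double' re-slicing above: a partition whose reader_cfg has test_range='[0:15]' covers 15 original samples, but with infer_order='double' every sample appears twice in pred_strs, so the matching judgements are pred_strs[0:30].

test_range = '[0:15]'
start, end = 0, 15              # parsed from the '[start:end]' string
start, end = start * 2, end * 2
# equivalent to pred_strs = pred_strs[0:30]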
|
||||
def _score(self,
|
||||
model_cfg,
|
||||
dataset_cfg,
|
||||
eval_cfg,
|
||||
output_column,
|
||||
meta=False):
|
||||
test_set = build_dataset_from_cfg(dataset_cfg).test
|
||||
# Postprocess dataset if necessary
|
||||
if 'dataset_postprocessor' in eval_cfg:
|
||||
@ -208,27 +344,32 @@ class SubjectiveEvalTask(BaseTask):
|
||||
|
||||
test_set = test_set.map(postprocess)
|
||||
# Get out_path
|
||||
out_path = get_infer_output_path(model_cfg, dataset_cfg,
|
||||
osp.join(self.work_dir, 'results'))
|
||||
new_model_cfg = []
|
||||
for m_cfg in model_cfg:
|
||||
if len(m_cfg) > 1:
|
||||
new_model_cfg.append(m_cfg)
|
||||
if len(new_model_cfg) == 1:
|
||||
new_model_cfg = new_model_cfg[0]
|
||||
model_preds = self._load_model_pred(new_model_cfg, dataset_cfg,
|
||||
eval_cfg, self.given_pred)
|
||||
out_path = get_infer_output_path(
|
||||
deal_with_judge_model_abbr(model_cfg, self.judge_cfg, self.meta),
|
||||
dataset_cfg, osp.join(self.work_dir, 'results'))
|
||||
if meta:
|
||||
model_preds = self._load_model_pred(model_cfg, dataset_cfg,
|
||||
eval_cfg, self.given_pred)
|
||||
model_judges = self._load_model_judgements(model_cfg, dataset_cfg,
|
||||
eval_cfg,
|
||||
self.judge_models)
|
||||
else:
|
||||
model_preds = self._load_model_pred(model_cfg, dataset_cfg,
|
||||
eval_cfg, self.given_pred)
|
||||
model_judges = None
|
||||
if not self.judge_cfg:
|
||||
raise ValueError('missing "eval.runner.task.judge_cfg"')
|
||||
raise ValueError('missing "eval.judge_cfg"')
|
||||
eval_cfg['evaluator']['judge_cfg'] = self.judge_cfg
|
||||
eval_cfg['evaluator']['dataset_cfg'] = dataset_cfg
|
||||
eval_cfg['evaluator']['output_path'] = out_path
|
||||
icl_evaluator = ICL_EVALUATORS.build(eval_cfg['evaluator'])
|
||||
references = (test_set[output_column] if output_column else None)
|
||||
|
||||
if 'error' not in model_preds:
|
||||
result = icl_evaluator.score(predictions=model_preds,
|
||||
references=references)
|
||||
judgements=model_judges,
|
||||
references=references,
|
||||
meta=meta,
|
||||
infer_order=self.infer_order)
|
||||
else:
|
||||
result = model_preds
|
||||
|
||||
@ -259,17 +400,24 @@ class SubjectiveEvalTask(BaseTask):
|
||||
output_paths = []
|
||||
for model, datasets in zip(self.model_cfgs, self.dataset_cfgs):
|
||||
for dataset in datasets:
|
||||
if type(model) == ConfigDict:
|
||||
if isinstance(model, ConfigDict):
|
||||
model = (model, )
|
||||
model += ({
|
||||
'abbr':
|
||||
'judged-by--' + model_abbr_from_cfg(self.judge_cfg)
|
||||
}, )
|
||||
if self.meta:
|
||||
model += ({
|
||||
'abbr':
|
||||
'summarized-by--' + model_abbr_from_cfg(self.judge_cfg)
|
||||
}, )
|
||||
else:
|
||||
model += ({
|
||||
'abbr':
|
||||
'judged-by--' + model_abbr_from_cfg(self.judge_cfg)
|
||||
}, )
|
||||
output_paths.append(
|
||||
get_infer_output_path(
|
||||
model, dataset,
|
||||
osp.join(self.work_dir, self.output_subdir),
|
||||
file_extension))
|
||||
model = model[:-1]
|
||||
return output_paths
|
||||
|
||||
|
||||
|
@ -46,3 +46,25 @@ def get_infer_output_path(model_cfg: ConfigDict,
|
||||
model_abbr = model_abbr_from_cfg(model_cfg)
|
||||
dataset_abbr = dataset_abbr_from_cfg(dataset_cfg)
|
||||
return osp.join(root_path, model_abbr, f'{dataset_abbr}.{file_extension}')
|
||||
|
||||
|
||||
def deal_with_judge_model_abbr(model_cfg, judge_model_cfg, meta=False):
|
||||
if isinstance(model_cfg, ConfigDict):
|
||||
model_cfg = (model_cfg, )
|
||||
if meta:
|
||||
for m_cfg in model_cfg:
|
||||
if 'summarized-by--' in m_cfg['abbr']:
|
||||
return model_cfg
|
||||
model_cfg += ({
|
||||
'abbr':
|
||||
'summarized-by--' + model_abbr_from_cfg(judge_model_cfg)
|
||||
}, )
|
||||
else:
|
||||
for m_cfg in model_cfg:
|
||||
if 'judged-by--' in m_cfg['abbr']:
|
||||
return model_cfg
|
||||
model_cfg += ({
|
||||
'abbr':
|
||||
'judged-by--' + model_abbr_from_cfg(judge_model_cfg)
|
||||
}, )
|
||||
return model_cfg
|
||||
|
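A hedged sketch of how the new helper shapes result directories (the model and judge abbreviations are placeholders; the directory name follows the '<model>_judged-by--<judge>' pattern used by the summarizers above):

from mmengine.config import ConfigDict
from opencompass.utils import deal_with_judge_model_abbr

model_cfg = ConfigDict(abbr='internlm-7b-chat', type='HuggingFaceCausalLM')
judge_cfg = ConfigDict(abbr='GPT4-Turbo', type='OpenAI')

cfgs = deal_with_judge_model_abbr(model_cfg, judge_cfg)
# -> (model_cfg, {'abbr': 'judged-by--GPT4-Turbo'})
# get_infer_output_path(cfgs, dataset_cfg, '<work_dir>/results') then points at
# something like results/internlm-7b-chat_judged-by--GPT4-Turbo/<dataset>.json

cfgs_meta = deal_with_judge_model_abbr(model_cfg, judge_cfg, meta=True)
# -> (model_cfg, {'abbr': 'summarized-by--GPT4-Turbo'})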
run.py
@ -341,7 +341,14 @@ def main():
|
||||
if args.dry_run:
|
||||
return
|
||||
runner = RUNNERS.build(cfg.eval.runner)
|
||||
runner(tasks)
|
||||
|
||||
# For meta-review-judge in subjective evaluation
|
||||
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
|
||||
tasks[0], list):
|
||||
for task_part in tasks:
|
||||
runner(task_part)
|
||||
else:
|
||||
runner(tasks)
|
||||
|
||||
# visualize
|
||||
if args.mode in ['all', 'eval', 'viz']:
|
||||
|