[Feature] Add multi-model judge and fix some problems (#1016)

* support multi-model judge and moe judge

* test_moe

* test_moe

* test

* add moe judge

* support multi-judge-model
bittersweet1999 2024-04-02 11:52:06 +08:00 committed by GitHub
parent c220550fb9
commit 2d4e559763
22 changed files with 761 additions and 380 deletions
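At a glance, the change moves the judge configuration out of eval.runner.task.judge_cfg and into the partitioner, which now takes a list of judge models (plus an optional meta judge model for a second review stage). The following condensed sketch is pieced together from the config diffs below; it is illustrative only and assumes that gpt4 and models are defined elsewhere in the config, as in the original files.

from opencompass.models import OpenAI
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks.subjective_eval import SubjectiveEvalTask

# One or more judge models; API key and meta_template fields are omitted here.
judge_models = [dict(
    abbr='GPT4-Turbo',
    type=OpenAI,
    path='gpt-4-1106-preview',
    max_seq_len=2048,
    batch_size=8,
    temperature=0,
)]

eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner,
        max_task_size=1000,
        mode='m2n',
        infer_order='random',   # moved here from the dataset-level LMEvaluator
        base_models=[gpt4],     # assumed defined earlier in the config
        compare_models=models,  # assumed defined earlier in the config
        judge_models=judge_models,
        # meta_judge_model=...  # optional second-stage (meta-review) judge
    ),
    # judge_cfg is no longer passed to the task; the partitioner attaches one
    # judge model to each replicated task instead.
    runner=dict(type=LocalRunner, max_num_workers=2,
                task=dict(type=SubjectiveEvalTask)),
)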


@ -65,7 +65,6 @@ for _name in subjective_all_sets:
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
infer_order='random',
prompt_template=dict(
type=PromptTemplate,
template=dict(


@ -67,7 +67,6 @@ for _name in subjective_all_sets:
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
infer_order='random',
prompt_template=dict(
type=PromptTemplate,
template=dict(


@ -119,7 +119,6 @@ for _name, _prompt in sub_map.items():
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
infer_order='double',
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[


@ -0,0 +1,156 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CompassArenaDataset
subjective_reader_cfg = dict(
input_columns=['question', 'ref'],
output_column='judge',
)
data_path ="data/subjective/compass_arena"
subjective_datasets = []
base_prompt = """
[回答1开始]
{prediction}
[回答1结束]
[回答2开始]
{prediction2}
[回答2结束]
根据评分要求，在以下 3 个选项中做出选择:
A. 回答1更好
B. 回答2更好
C. 回答1、2平局
并提供你的解释原因。
如果你认为回答1更好，你的输出应形如：
选择：A
原因：blahblah blahblah\n
如果你认为回答2更好，你的输出应形如：
选择：B
原因：blahblah blahblah\n
如果你认为回答1、2打成平手，你的输出应形如：
选择：C
原因：blahblah blahblah\n
"""
knowledge_prompt = """
请根据提供的 评分要求，用户问题，参考答案 以及 相应的两个回答（回答1，回答2），判断两个回答中哪一个更好。
评分要求（重要性依次递减）:
1. 更好的回答能与参考答案吻合或表明参考答案的意思。
2. 在都准确答对问题的前提下，更好的回答能对知识点进行额外补充，且补充的知识准确无误。
3. 更好的回答更加符合与人类对话的习惯，包括语气、情调等。
[用户问题]
{question}
[参考答案]
{ref}
""" + base_prompt
language_prompt = """
请根据提供的 评分要求，用户问题 以及 相应的两个回答（回答1，回答2），判断两个回答中哪一个更好。
评分要求（重要性依次递减）:
1. 在有明确的参考答案的情况下，越贴近参考答案或表明了参考答案的意思的回答越好。
2. 更好的回答在语言表达上更流畅，更加符合与人类对话的习惯，包括语气、情调等。
3. 在都准确答对问题的前提下，更好的回答能进行额外补充，且补充的内容准确无误。
[用户问题]
{question}
[参考答案]
{ref}
""" + base_prompt
math_prompt = """
请根据提供的 评分要求，用户问题，参考答案 以及 相应的两个回答（回答1，回答2），判断两个回答中哪一个更好。
评分要求（重要性依次递减）:
1. 更好的回答的答案能和参考答案一致。
2. 若两个回答的答案都与参考答案不一致，则更好的回答的推理过程应更加合理。
3. 更好的回答更加符合与人类对话的习惯，包括语气、情调等。
[用户问题]
{question}
[参考答案]
{ref}
""" + base_prompt
reason_prompt = math_prompt
creation_prompt = """
请根据提供的 评分要求，用户问题 以及 相应的两个回答（回答1，回答2），判断两个回答中哪一个更好。
评分要求（重要性依次递减）:
1. 好的回答必须首先符合用户问题里的各种需求，不能跑题。
2. 好的回答必须具有逻辑连贯性，围绕一个中心进行回答。
3. 好的回答必须具有创造性的词语和表达丰富度。
[用户问题]
{question}
""" + base_prompt
sub_map = {"knowledge": knowledge_prompt, "language": language_prompt, "math_v2": math_prompt, "reason_v2": reason_prompt, "creationv2_zh": creation_prompt}
meta_prompt = """
\n你是一个评判专家，请根据提供的 评分要求，用户问题 以及 相应的两个回答（回答1，回答2），判断两个回答中哪一个更好。\n评分要求（重要性依次递减）:\n1. 好的回答必须首先符合用户问题里的各种需求，不能跑题。\n2. 好的回答必须具有逻辑连贯性，围绕一个中心进行回答。\n3. 好的回答必须具有创造性的词语和表达丰富度。\n\n[用户问题]\n{question}\n[回答1开始]\n{prediction}\n[回答1结束]\n[回答2开始]\n{prediction2}\n[回答2结束]\n此外，还有两个其他评判专家的评判意见供你参考。\n[评判意见1]\n{judgement}\n[评判意见2]\n{judgement2}\n\n最终请你综合其他评判专家的评判意见与你自己的意见，在以下 3 个选项中做出选择:\nA. 回答1更好\nB. 回答2更好\nC. 回答1、2平局\n并提供你的解释原因。\n\n如果你认为回答1更好，你的输出应形如：\n选择：A\n原因：blahblah blahblah\n\n\n如果你认为回答2更好，你的输出应形如：\n选择：B\n原因：blahblah blahblah\n\n\n如果你认为回答1、2打成平手，你的输出应形如：\n选择：C\n原因：blahblah blahblah\n\n
"""
for _name, _prompt in sub_map.items():
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=2048),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = _prompt
),
]),
),
meta_review_prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = meta_prompt
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=CompassArenaDataset,
path=data_path,
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))
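For orientation, the meta_prompt above consumes both answers and the two first-stage verdicts. A rough sketch of the placeholder mapping that LMEvaluator assembles in meta mode (the verdict texts are invented):

# Illustrative only: these keys mirror the pred_dict built in LMEvaluator.score()
# when meta=True; real values come from the saved first-stage judge outputs.
pred_dict = {
    'prediction': '模型A的回答……',     # answer from the first model
    'prediction2': '模型B的回答……',    # answer from the second model
    'judgement': '选择：A\n原因：……',   # verdict from the first judge model
    'judgement2': '选择：B\n原因：……',  # verdict from the second judge model
}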


@ -1,71 +0,0 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import IRDataset
subjective_reader_cfg = dict(
input_columns=['question', 'capability', 'ref'],
output_column='judge',
)
subjective_all_sets = [
"information_retrieval",
]
data_path ="data/subjective/"
subjective_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=512),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = """为上传的针对给定用户问题的回应撰写评论, 并为该回复打分:
[BEGIN DATA]
***
[用户问询]: {question}
***
[回应]: {prediction}
***
[参考答案]: {ref}
***
[END DATA]
请根据参考答案为这个回应撰写评论. 在这之后, 你应该按照如下格式给这个回应一个最终的1-10范围的评分: "[[评分]]", 例如: "评分: [[5]]"."""
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=IRDataset,
path=data_path,
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))


@ -1,59 +0,0 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import IRDataset
subjective_reader_cfg = dict(
input_columns=['question', 'capability', 'gpt4_prefix', 'gpt4_suffix', 'ref'],
output_column='judge',
)
subjective_all_sets = [
"information_retrieval",
]
data_path ="data/subjective/"
subjective_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=512),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = "{gpt4_prefix}{prediction}{gpt4_suffix}"
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=IRDataset,
path=data_path,
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))


@ -44,7 +44,7 @@ models = [
meta_template=api_meta_template,
max_out_len=2048,
max_seq_len=4096,
batch_size=1,
batch_size=8,
run_cfg=dict(num_gpus=1, num_procs=1),
)
]
@ -54,7 +54,7 @@ datasets = [*subjective_datasets]
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_model = dict(
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-1106-preview',
@ -65,18 +65,14 @@ judge_model = dict(
max_seq_len=2048,
batch_size=8,
temperature=0,
)
)]
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveNaivePartitioner, mode='singlescore', models=models
),
runner=dict(
type=LocalRunner,
max_num_workers=2,
task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
type=SubjectiveSizePartitioner, max_task_size=1000, mode='singlescore', models=models, judge_models=judge_models,
),
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
)
summarizer = dict(type=AlignmentBenchSummarizer, judge_type='general')


@ -47,7 +47,7 @@ models = [
meta_template=api_meta_template,
max_out_len=2048,
max_seq_len=4096,
batch_size=1,
batch_size=8,
run_cfg=dict(num_gpus=1, num_procs=1),
)
]
@ -73,7 +73,7 @@ gpt4 = dict(
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_model = dict(
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-1106-preview',
@ -85,21 +85,20 @@ judge_model = dict(
batch_size=2,
retry=20,
temperature=0,
)
)]
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='auto',
max_num_workers=256,
task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models,
infer_order='random',
judge_models=judge_models
),
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
given_pred = [{'abbr':'gpt4-turbo', 'path':''}]
)
work_dir = 'outputs/alpaca/'
summarizer = dict(type=AlpacaSummarizer, judge_type='v2')


@ -72,7 +72,7 @@ gpt4 = dict(
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_model = dict(
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-1106-preview',
@ -84,7 +84,7 @@ judge_model = dict(
batch_size=2,
retry=20,
temperature=0,
)
)]
## ------------- Evaluation Configuration
eval = dict(
@ -93,16 +93,13 @@ eval = dict(
strategy='split',
max_task_size=10000,
mode='m2n',
infer_order='double',
base_models=[gpt4],
compare_models=models,
judge_models=judge_models,
),
runner=dict(
type=SlurmSequentialRunner,
partition='llm_dev2',
quotatype='auto',
max_num_workers=32,
task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
),
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
given_pred = [{'abbr':'gpt4-turbo', 'path':''}]
)
work_dir = 'outputs/compass_arena_debug/'


@ -63,7 +63,7 @@ infer = dict(
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_model = dict(
judge_models = [dict(
type=HuggingFaceCausalLM,
abbr='pandalm-7b-v1-hf',
path='WeOpenML/PandaLM-7B-v1',
@ -79,12 +79,12 @@ judge_model = dict(
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
)]
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(type=SubjectiveNaivePartitioner, mode='singlescore', models=models),
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
partitioner=dict(type=SubjectiveNaivePartitioner, mode='singlescore', models=models, judge_models=judge_models),
runner=dict(type=LocalRunner, max_num_workers=2, task=dict(type=SubjectiveEvalTask)),
)
summarizer = dict(type=AlignmentBenchSummarizer)


@ -2,7 +2,6 @@ from mmengine.config import read_base
with read_base():
from .datasets.subjective.multiround.mtbench_single_judge_diff_temp import subjective_datasets
# from .datasets.subjective.multiround.mtbench_pair_judge import subjective_datasets
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.models.openai_api import OpenAIAllesAPIN
@ -62,7 +61,7 @@ datasets = [*subjective_datasets]
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_model = dict(
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-0613', # To compare with the official leaderboard, please use gpt4-0613
@ -73,23 +72,12 @@ judge_model = dict(
max_seq_len=2048,
batch_size=8,
temperature=0,
)
## ------------- Evaluation Configuration
# ## pair evaluation
# eval = dict(
# partitioner=dict(
# type=SubjectiveSizePartitioner, max_task_size=100, mode='m2n', base_models=[gpt4], compare_models=models
# ),
# runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
# )
# summarizer = dict(type=MTBenchSummarizer, judge_type='pair')
)]
## single evaluation
eval = dict(
partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models),
runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models, judge_models=judge_models),
runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask)),
)
summarizer = dict(type=MTBenchSummarizer, judge_type='single')


@ -27,10 +27,12 @@ def extract_dicts(data):
return predictions
def order_preds_and_record_references(predictions,
references,
infer_order,
seed=2680):
def order_preds_and_record_references(
predictions,
references,
infer_order,
seed=666,
):
"""Order predictions based on args and recording regrading references.
Args:
@ -85,17 +87,19 @@ class LMEvaluator:
prompt_template: ConfigDict,
judge_cfg: ConfigDict,
output_path: str,
infer_order: Optional[str] = 'random',
meta_review_prompt_template: Optional[ConfigDict] = None,
dataset_cfg: Optional[ConfigDict] = None,
postprocessor: ConfigDict = dict(type=first_number_postprocess)
) -> None:
assert infer_order in ['random', 'double']
self.output_path = output_path
out_dir, out_name = osp.split(output_path)
if not out_dir:
out_dir = './'
self.prompt_tmpl = ICL_PROMPT_TEMPLATES.build(prompt_template)
if meta_review_prompt_template is not None:
self.meta_review_prompt_tmpl = ICL_PROMPT_TEMPLATES.build(
meta_review_prompt_template)
max_out_len = judge_cfg.get('max_out_len', None)
batch_size = judge_cfg.get('batch_size', None)
@ -108,16 +112,20 @@ class LMEvaluator:
self.postprocessor = get_type_from_cfg(postprocessor)
self.logger = get_logger()
self.dataset_cfg = dataset_cfg
self.infer_order = infer_order
def score(self, predictions, references: Optional[List] = None) -> Dict:
def score(self,
predictions,
judgements: Optional[List] = None,
references: Optional[List] = None,
meta: Optional[bool] = False,
infer_order: Optional[str] = 'random') -> Dict:
dup_indices = []
if type(predictions) == list:
"""Apply to multi-model comparison."""
references = [{} for _ in range(len(predictions[0]['model_preds']))
] if references is None else references
predictions, references = order_preds_and_record_references(
predictions, references, self.infer_order)
predictions, references, infer_order)
# calculate the number of duplicated predictions
total_predictions_num = len(predictions[0])
@ -135,6 +143,9 @@ class LMEvaluator:
] if references is None else references
predictions = [predictions['model_preds']]
# Due to the rarity of identical predictions, we have temporarily disabled the plagiarism detection feature.
dup_indices = []
if len(dup_indices) != 0:
# remove duplicated predictions
for index in sorted(dup_indices, reverse=True):
@ -149,6 +160,14 @@ class LMEvaluator:
for i in range(len(predictions)):
key = 'prediction' if i == 0 else f'prediction{i + 1}'
pred_dict[key] = predictions[i]
if judgements:
for i in range(len(judgements)):
key = 'judgement' if i == 0 else f'judgement{i + 1}'
pred_dict[key] = judgements[i]['model_preds']
for j in range(len(references)):
references[j]['judge_model' +
str(i + 1)] = judgements[i]['model_name']
elif isinstance(
predictions[0][0], list
): #multi round for format like [[[{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}], [{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}]]]
@ -158,11 +177,13 @@ class LMEvaluator:
key = 'prediction' if i == 0 else f'prediction{i}'
key += '_r' + str(j + 1)
pred_dict[key] = multiround_predictions[j]
if judgements:
raise NotImplementedError(
'Meta-review judge is not supported on multi-round datasets')
if self.dataset_cfg:
dataset = build_dataset_from_cfg(self.dataset_cfg)
if self.infer_order == 'double':
if infer_order == 'double':
new_ds = {
k: dataset.test[k] * 2
for k in dataset.test.column_names
@ -179,7 +200,6 @@ class LMEvaluator:
print(
f'Among {total_predictions_num} total predictions, {len(dup_indices)} were exactly the same and have been removed!'
)
for k, v in pred_dict.items():
dataset.reader.dataset['test'] = dataset.test.add_column(k, v)
dataset.reader.input_columns.append(k)
@ -201,8 +221,13 @@ class LMEvaluator:
**pred_dict)
dataset.reader.output_column = 'reference'
retriever = ZeroRetriever(dataset)
self.inferencer.inference(retriever=retriever,
prompt_template=self.prompt_tmpl)
if meta:
self.inferencer.inference(
retriever=retriever,
prompt_template=self.meta_review_prompt_tmpl)
else:
self.inferencer.inference(retriever=retriever,
prompt_template=self.prompt_tmpl)
output = mmengine.load(self.output_path)
return self.postprocess(output)
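Putting the new signature together, here is a hypothetical call (model names and texts are invented; in practice SubjectiveEvalTask builds these dicts from the saved prediction and judgement files, and evaluator is an LMEvaluator built from eval_cfg['evaluator']):

predictions = [
    {'model_name': 'model-a', 'model_preds': ['answer to q1', 'answer to q2']},
    {'model_name': 'model-b', 'model_preds': ['answer to q1', 'answer to q2']},
]
judgements = [
    {'model_name': 'judge-a', 'model_preds': ['选择：A\n原因：……', '选择：C\n原因：……']},
    {'model_name': 'judge-b', 'model_preds': ['选择：B\n原因：……', '选择：A\n原因：……']},
]
result = evaluator.score(
    predictions=predictions,
    judgements=judgements,   # only passed in the meta-review stage
    references=None,
    meta=True,               # use meta_review_prompt_template instead of prompt_template
    infer_order='random',    # now a call-time argument rather than an __init__ one
)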


@ -1,3 +1,4 @@
# flake8: noqa: E501
import inspect
from abc import abstractmethod
from copy import deepcopy
@ -81,11 +82,21 @@ class BasePartitioner:
work_dir=work_dir,
out_dir=self.out_dir,
add_cfg=add_cfg)
self.logger.info(f'Partitioned into {len(tasks)} tasks.')
for i, task in enumerate(tasks):
self.logger.debug(f'Task {i}: {task_abbr_from_cfg(task)}')
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
tasks[0], list):
self.logger.info(
f'Subjective evaluation: partitioned into 2 stages, with {len(tasks[0])} tasks in the first stage and {len(tasks[1])} tasks in the second stage.'
)
cnt = 0
for task_part in tasks:
for task in task_part:
self.logger.debug(
f'Task {cnt}: {task_abbr_from_cfg(task)}')
cnt += 1
else:
self.logger.info(f'Partitioned into {len(tasks)} tasks.')
for i, task in enumerate(tasks):
self.logger.debug(f'Task {i}: {task_abbr_from_cfg(task)}')
return tasks
def parse_model_dataset_args(self, cfg: ConfigDict):


@ -1,14 +1,20 @@
# flake8: noqa: E501
import copy
import os.path as osp
from itertools import combinations, product
from typing import Dict, List, Optional, Tuple
from mmengine.config import ConfigDict
from opencompass.registry import PARTITIONERS
from opencompass.utils import (deal_with_judge_model_abbr,
get_infer_output_path, model_abbr_from_cfg)
from .naive import NaivePartitioner
def remove_duplicate_pairs(model_combinations):
# For compare mode, we need to remove redundant pairs first
combo_dict = {}
for i, combo in enumerate(model_combinations):
sorted_names = tuple(sorted((combo[0]['abbr'], combo[1]['abbr'])))
@ -20,6 +26,82 @@ def remove_duplicate_pairs(model_combinations):
return new_model_combinations
def replicate_tasks_with_judge_models(tasks, judge_models, meta_judge_model):
# After all tasks have been partitioned, attach judge_models and meta_judge_model to them as additional args.
if meta_judge_model:
replicated_tasks = [[], []]
else:
replicated_tasks = []
for task in tasks:
replicated_task_dicts = [task.copy() for _ in range(len(judge_models))]
for idx, replicated_task in enumerate(replicated_task_dicts):
replicated_task['judge_model'] = judge_models[idx]
if meta_judge_model:
meta_task = task.copy()
meta_task['meta_judge_model'] = meta_judge_model
meta_task['judge_models'] = judge_models
replicated_tasks[1].append(meta_task)
replicated_tasks[0].extend(replicated_task_dicts)
else:
replicated_tasks.extend(replicated_task_dicts)
return replicated_tasks
def remove_already_tasks(tasks, work_dir, meta_judge_model):
# Check and remove the already finished subjective evaluation tasks
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
tasks[0], list):
tasks_to_keep = [[], []]
for i in range(2):
for task in tasks[i]:
temp_task = copy.deepcopy(task)
to_delete_index = [
] # When the partition strategy is not 'split', a single task may contain multiple datasets, so on restart we need to remove the datasets that are already finished.
for idx, dataset in enumerate(task['datasets'][0]):
if i == 0:
filename = get_infer_output_path(
deal_with_judge_model_abbr(task['models'][0],
task['judge_model'],
False), dataset,
osp.join(work_dir, 'results'))
else:
filename = get_infer_output_path(
deal_with_judge_model_abbr(
task['models'][0], task['meta_judge_model'],
True), dataset, osp.join(work_dir, 'results'))
if osp.exists(filename):
to_delete_index.append(idx)
temp_task['datasets'][0] = [
temp_task['datasets'][0][j]
for j in range(len(temp_task['datasets'][0]))
if j not in to_delete_index
]
if len(temp_task['datasets'][0]) != 0:
tasks_to_keep[i].append(temp_task)
else:
tasks_to_keep = []
for task in tasks:
temp_task = copy.deepcopy(task)
to_delete_index = [
] # When the partition strategy is not 'split', a single task may contain multiple datasets, so on restart we need to remove the datasets that are already finished.
for idx, dataset in enumerate(task['datasets'][0]):
filename = get_infer_output_path(
deal_with_judge_model_abbr(task['models'][0],
task['judge_model']), dataset,
osp.join(work_dir, 'results'))
if osp.exists(filename):
to_delete_index.append(idx)
# Remove the already done tasks
temp_task['datasets'][0] = [
temp_task['datasets'][0][j]
for j in range(len(temp_task['datasets'][0]))
if j not in to_delete_index
]
if len(temp_task['datasets'][0]) != 0:
tasks_to_keep.append(temp_task)
return tasks_to_keep
@PARTITIONERS.register_module()
class SubjectiveNaivePartitioner(NaivePartitioner):
"""Naive task partitioner for subjective evaluation. Compared to
@ -37,15 +119,22 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
models: Optional[List[ConfigDict]] = [],
base_models: Optional[List[ConfigDict]] = [],
compare_models: Optional[List[ConfigDict]] = [],
judge_models: Optional[List[ConfigDict]] = [],
meta_judge_model: Optional[ConfigDict] = None,
model_pairs: Optional[List[Tuple]] = None,
keep_keys: Optional[List[str]] = None):
keep_keys: Optional[List[str]] = None,
infer_order: Optional[str] = 'random'):
super().__init__(out_dir=out_dir, keep_keys=keep_keys)
assert mode in ['singlescore', 'allpair', 'm2n', 'fixed']
assert infer_order in ['random', 'double']
self.mode = mode
self.models = models
self.base_models = base_models
self.compare_models = compare_models
self.model_pairs = model_pairs
self.judge_models = judge_models
self.meta_judge_model = meta_judge_model
self.infer_order = infer_order
def get_model_combinations(
self,
@ -97,14 +186,35 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
"""
models = self.models if self.models != [] else models
base_models, compare_models = self.base_models, self.compare_models
judge_models, meta_judge_model = self.judge_models, self.meta_judge_model
if self.mode == 'singlescore':
models = models
else:
models = self.get_model_combinations(models, base_models,
compare_models)
model_dataset_combinations = [{'models': models, 'datasets': datasets}]
return super().partition(
tasks = super().partition(
model_dataset_combinations=model_dataset_combinations,
work_dir=work_dir,
out_dir=out_dir,
add_cfg=add_cfg)
# We need to attach the judge models and the meta judge model to the partitioned tasks.
# When there is no meta judge model, each task is replicated once per judge model.
# When there is a meta judge model, an additional task stage is appended for the meta review.
tasks = replicate_tasks_with_judge_models(tasks, judge_models,
meta_judge_model)
# We also need to check and remove the already done tasks
tasks = remove_already_tasks(tasks, work_dir, meta_judge_model)
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
tasks[0], list):
# Two-stage tasks: the meta-review judge is enabled
for task_stage in tasks:
for task in task_stage:
task['infer_order'] = self.infer_order
else:
# Single-stage tasks: only the review judges are used
for task in tasks:
task['infer_order'] = self.infer_order
return tasks
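To make the resulting layout concrete, here is a standalone illustration (all abbrs invented) of the structure replicate_tasks_with_judge_models produces, which BasePartitioner then logs as two stages and run.py executes in order:

# One partitioned task, two judge models, one meta judge model.
base_task = {'models': [{'abbr': 'model-x'}], 'datasets': [['dataset-1']]}
judge_models = [{'abbr': 'judge-a'}, {'abbr': 'judge-b'}]
meta_judge_model = {'abbr': 'meta-judge'}

# Stage 1: the task is replicated once per judge model.
stage_one = [dict(base_task, judge_model=j) for j in judge_models]
# Stage 2: a single meta-review task that also carries the first-stage judges.
stage_two = [dict(base_task, meta_judge_model=meta_judge_model,
                  judge_models=judge_models)]

tasks = [stage_one, stage_two]  # with no meta judge, this is just a flat list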


@ -1,3 +1,4 @@
# flake8: noqa: E501
import copy
import math
import os.path as osp
@ -11,7 +12,8 @@ from opencompass.registry import PARTITIONERS
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_infer_output_path)
from .sub_naive import SubjectiveNaivePartitioner
from .sub_naive import (SubjectiveNaivePartitioner, remove_already_tasks,
replicate_tasks_with_judge_models)
@PARTITIONERS.register_module()
@ -40,19 +42,25 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
models: Optional[List[ConfigDict]] = [],
base_models: Optional[List[ConfigDict]] = [],
compare_models: Optional[List[ConfigDict]] = [],
judge_models: Optional[List[ConfigDict]] = [],
meta_judge_model: Optional[ConfigDict] = None,
model_pairs: Optional[List[Tuple]] = None,
max_task_size: int = 40000,
gen_task_coef: int = 20,
strategy: str = 'heuristic',
dataset_size_path: str = '.cache/dataset_size.json',
keep_keys: Optional[List[str]] = None):
keep_keys: Optional[List[str]] = None,
infer_order: Optional[str] = 'random'):
super().__init__(out_dir=out_dir,
keep_keys=keep_keys,
mode=mode,
models=models,
base_models=base_models,
compare_models=compare_models,
model_pairs=model_pairs)
judge_models=judge_models,
meta_judge_model=meta_judge_model,
model_pairs=model_pairs,
infer_order=infer_order)
self.max_task_size = max_task_size
self.gen_task_coef = gen_task_coef
self.dataset_size_path = dataset_size_path
@ -96,13 +104,13 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
"""
models = self.models if self.models != [] else models
base_models, compare_models = self.base_models, self.compare_models
judge_models, meta_judge_model = self.judge_models, self.meta_judge_model
if self.mode == 'singlescore':
models = models
else:
models = super().get_model_combinations(models, base_models,
compare_models)
model_dataset_combinations = [{'models': models, 'datasets': datasets}]
tasks = []
for comb in model_dataset_combinations:
comb['datasets'] = sorted(comb['datasets'],
@ -113,8 +121,8 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
for dataset in comb['datasets']:
filename = get_infer_output_path(model, dataset, out_dir)
# skip the task if the task output exists
if osp.exists(filename):
continue
# if osp.exists(filename):
# continue
dataset_size = self.get_cost(dataset)
if dataset_size > self.max_task_size:
root, ext = osp.splitext(filename)
@ -151,6 +159,21 @@ class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
'work_dir': work_dir,
**add_cfg
}))
tasks = replicate_tasks_with_judge_models(tasks, judge_models,
meta_judge_model)
tasks = remove_already_tasks(tasks, work_dir, meta_judge_model)
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
tasks[0], list):
# Two-stage tasks: the meta-review judge is enabled
for task_stage in tasks:
for task in task_stage:
task['infer_order'] = self.infer_order
else:
# Single-stage tasks: only the review judges are used
for task in tasks:
task['infer_order'] = self.infer_order
return tasks
@property


@ -309,7 +309,7 @@ class AlignmentBenchSummarizer:
self.eval_model_abbrs = [
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_models = self.cfg.get('judge_models', None)
self.judge_type = judge_type
assert self.judge_type in [
'general', 'autoj', 'judgelm', 'general_plus'
@ -333,33 +333,36 @@ class AlignmentBenchSummarizer:
Returns:
pd.DataFrame: The summary results.
"""
dataset_cfgs = self.cfg['datasets']
output_dir, results_folder = get_outdir(self.cfg, time_str)
fout_flag, fout_flag2 = 0, 0
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model, judge_model = eval_model_abbr, self.judge_abbr
if self.judge_type == 'general':
fout = osp.join(
output_dir,
'judged-by--' + judge_model + '-dimension.csv')
fout2 = osp.join(
output_dir,
'judged-by--' + judge_model + '-capability.csv')
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, self.judge_function)
for judge_model in self.judge_models:
judge_abbr = model_abbr_from_cfg(judge_model)
dataset_cfgs = self.cfg['datasets']
output_dir, results_folder = get_outdir(self.cfg, time_str)
fout_flag, fout_flag2 = 0, 0
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model = eval_model_abbr
if self.judge_type == 'general':
get_dimension_results(judged_answers, references, fout,
fout_flag, model)
fout_flag += 1
get_capability_results(judged_answers, references, fout2,
fout_flag2, model, self.category)
fout_flag2 += 1
else:
print(subdir_path + ' does not exist! Please check!')
fout = osp.join(
output_dir,
'judged-by--' + judge_abbr + '-dimension.csv')
fout2 = osp.join(
output_dir,
'judged-by--' + judge_abbr + '-capability.csv')
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, self.judge_function)
if self.judge_type == 'general':
get_dimension_results(judged_answers, references,
fout, fout_flag, model)
fout_flag += 1
get_capability_results(judged_answers, references,
fout2, fout_flag2, model,
self.category)
fout_flag2 += 1
else:
print(subdir_path + ' does not exist! Please check!')
if self.judge_type == 'general':
with open(fout, 'r') as f:
x = from_csv(f)


@ -82,7 +82,8 @@ class AlpacaSummarizer:
self.cfg = config
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_abbr = model_abbr_from_cfg(
self.cfg['judge_models'][0]) # We will reorganize the summarizers
self.judge_type = judge_type
assert self.judge_type in ['v1', 'v2']
self.judge_map = {


@ -67,7 +67,9 @@ class CompassArenaSummarizer:
self.cfg = config
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_models = self.cfg.get('judge_models', None)
self.meta_judge_model = self.cfg.eval.partitioner.get(
'meta_judge_model', None)
self.judge_type = judge_type
assert self.judge_type in ['general']
self.judge_map = {
@ -95,109 +97,135 @@ class CompassArenaSummarizer:
product(self.base_models, self.compare_models))
unique_combinations = remove_duplicate_pairs(
[combo for combo in model_combinations if combo[0] != combo[1]])
judge_model = self.judge_abbr
fout_list = []
for dataset in dataset_cfgs:
dataset_abbr = dataset_abbr_from_cfg(dataset)
fout = osp.join(
output_dir, 'judged-by--' + judge_model + '-' + dataset_abbr +
'-report.csv')
fout_list.append(fout)
for model_pair in unique_combinations:
model1, model2, = model_pair[0]['abbr'], model_pair[1]['abbr'],
subdir = model1 + '_' + model2 + '_judged-by--' + judge_model
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
judged_answers, references = get_judgeanswer_and_reference(
dataset,
subdir_path,
self.judge_function,
)
if self.check_pos_bias:
bias_num = check_position_bias(judged_answers,
references)
else:
bias_num = 0
win_model1, win_model2, categories = defaultdict(
float), defaultdict(float), defaultdict(float)
model1, model2 = references[0]['answer1'], references[0][
'answer2']
for prediction, reference in zip(judged_answers,
references):
if self.summary_type == 'single':
if prediction == 'A':
categories['total'] += 1
categories[reference['capability']] += 1
if reference['answer1'] == model1:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
else:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
elif prediction == 'B':
categories['total'] += 1
categories[reference['capability']] += 1
if reference['answer1'] == model1:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
else:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
elif self.summary_type == 'half_add':
categories['total'] += 1
categories[reference['capability']] += 1
if prediction == 'A':
if reference['answer1'] == model1:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
else:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
elif prediction == 'B':
if reference['answer1'] == model1:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
else:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
elif prediction == 'C':
win_model1[reference['capability']] += 0.5
win_model1['total'] += 0.5
win_model2[reference['capability']] += 0.5
win_model2['total'] += 0.5
for capability in categories:
if capability not in win_model1:
win_model1[capability] = 0.0
else:
win_model1[capability] = round(
(win_model1[capability] /
categories[capability]) * 100, 2)
if capability not in win_model2:
win_model2[capability] = 0.0
else:
win_model2[capability] = round(
(win_model2[capability] /
categories[capability]) * 100, 2)
win_model1['position_bias'] = bias_num
win_model2['position_bias'] = bias_num
scores = {
'win_' + model1: win_model1,
'win_' + model2: win_model2
}
rows = list(scores.keys())
columns = list(scores[rows[0]].keys())
columns.insert(0, columns.pop(columns.index('total')))
columns.insert(1,
columns.pop(columns.index('position_bias')))
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([model1 + '_vs_' + model2] + columns)
for row in rows:
writer.writerow(
[row] +
[scores[row][column] for column in columns])
pre_len = len(self.judge_models)
if self.meta_judge_model is not None:
self.judge_models.append(self.meta_judge_model)
meta_judge_model_abbr = model_abbr_from_cfg(self.meta_judge_model)
else:
meta_judge_model_abbr = None
for idx, judge_model in enumerate(self.judge_models):
judge_model = model_abbr_from_cfg(judge_model)
for dataset in dataset_cfgs:
dataset_abbr = dataset_abbr_from_cfg(dataset)
if idx == pre_len:
fout = osp.join(
output_dir, 'summarized-by--' + judge_model + '-' +
dataset_abbr + '-report.csv')
else:
print(subdir_path + ' does not exist! Please check!')
fout = osp.join(
output_dir, 'judged-by--' + judge_model + '-' +
dataset_abbr + '-report.csv')
fout_list.append(fout)
for model_pair in unique_combinations:
model1, model2, = model_pair[0]['abbr'], model_pair[1][
'abbr'],
if idx == pre_len:
subdir = model1 + '_' + model2 + '_summarized-by--' + judge_model
else:
subdir = model1 + '_' + model2 + '_judged-by--' + judge_model
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
judged_answers, references = get_judgeanswer_and_reference(
dataset,
subdir_path,
self.judge_function,
)
if self.check_pos_bias:
bias_num = check_position_bias(
judged_answers, references)
else:
bias_num = 0
win_model1, win_model2, categories = defaultdict(
float), defaultdict(float), defaultdict(float)
model1, model2 = references[0]['answer1'], references[
0]['answer2']
for prediction, reference in zip(
judged_answers, references):
if self.summary_type == 'single':
if prediction == 'A':
categories['total'] += 1
categories[reference['capability']] += 1
if reference['answer1'] == model1:
win_model1[
reference['capability']] += 1
win_model1['total'] += 1
else:
win_model2[
reference['capability']] += 1
win_model2['total'] += 1
elif prediction == 'B':
categories['total'] += 1
categories[reference['capability']] += 1
if reference['answer1'] == model1:
win_model2[
reference['capability']] += 1
win_model2['total'] += 1
else:
win_model1[
reference['capability']] += 1
win_model1['total'] += 1
elif self.summary_type == 'half_add':
categories['total'] += 1
categories[reference['capability']] += 1
if prediction == 'A':
if reference['answer1'] == model1:
win_model1[
reference['capability']] += 1
win_model1['total'] += 1
else:
win_model2[
reference['capability']] += 1
win_model2['total'] += 1
elif prediction == 'B':
if reference['answer1'] == model1:
win_model2[
reference['capability']] += 1
win_model2['total'] += 1
else:
win_model1[
reference['capability']] += 1
win_model1['total'] += 1
elif prediction == 'C':
win_model1[reference['capability']] += 0.5
win_model1['total'] += 0.5
win_model2[reference['capability']] += 0.5
win_model2['total'] += 0.5
for capability in categories:
if capability not in win_model1:
win_model1[capability] = 0.0
else:
win_model1[capability] = round(
(win_model1[capability] /
categories[capability]) * 100, 2)
if capability not in win_model2:
win_model2[capability] = 0.0
else:
win_model2[capability] = round(
(win_model2[capability] /
categories[capability]) * 100, 2)
win_model1['position_bias'] = bias_num
win_model2['position_bias'] = bias_num
scores = {
'win_' + model1: win_model1,
'win_' + model2: win_model2
}
rows = list(scores.keys())
columns = list(scores[rows[0]].keys())
columns.insert(0, columns.pop(columns.index('total')))
columns.insert(
1, columns.pop(columns.index('position_bias')))
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([model1 + '_vs_' + model2] +
columns)
for row in rows:
writer.writerow([row] + [
scores[row][column] for column in columns
])
else:
print(subdir_path + ' does not exist! Please check!')
for fout in fout_list:
with open(fout, 'r') as f:
x = from_csv(f)


@ -98,7 +98,7 @@ class MTBenchSummarizer(CompassArenaSummarizer):
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner'][
'compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
self.judge_map = {
'single': post_process_mtbench_single,
'pair': post_process_mtbench_pair


@ -1,10 +1,11 @@
# flake8: noqa: E501
import argparse
import copy
import fnmatch
import os.path as osp
import random
import time
from typing import List, Union
from typing import List, Optional, Union
import mmengine
from mmengine.config import Config, ConfigDict
@ -14,6 +15,7 @@ from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
from opencompass.tasks.base import BaseTask
from opencompass.tasks.openicl_eval import extract_role_pred
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
deal_with_judge_model_abbr,
get_infer_output_path, get_logger,
model_abbr_from_cfg, task_abbr_from_cfg)
@ -35,21 +37,25 @@ class SubjectiveEvalTask(BaseTask):
def __init__(self, cfg: ConfigDict):
super().__init__(cfg)
self.logger = get_logger()
judge_cfg = cfg.eval.runner.task.get('judge_cfg', {})
if type(judge_cfg) != ConfigDict:
print('*' * 100)
print('Due to different Judge model needs different summarizer and'
" prompts, we don't support multi judge model evaluation at "
'one time, please do not use list to set your judge cfg, jus'
't use a dict or list[0] should be fine. If you want to eval'
'uation multi judge model in one script, we suggest you to u'
'se a bash or bat script to start multi configs evaluation!')
print('*' * 100)
assert type(judge_cfg) == ConfigDict
judge_cfg = cfg.get('judge_model', None)
meta_judge_cfg = cfg.get('meta_judge_model', None)
judge_models = cfg.get('judge_models', None)
if judge_cfg is None and meta_judge_cfg is None:
assert judge_cfg is not None, 'Both judge_model and meta_judge_model are missing; at least one must be provided.'
if meta_judge_cfg is not None:
assert judge_models is not None, 'meta_judge_cfg is provided, but judge_models are missing.'
judge_cfg = meta_judge_cfg # Replace judge_cfg with meta_judge_cfg when the latter is provided
self.meta = True
else:
self.meta = False
run_cfg = judge_cfg.get('run_cfg', {})
self.num_gpus = run_cfg.get('num_gpus', 0)
self.num_procs = run_cfg.get('num_procs', 1)
self.judge_cfg = copy.deepcopy(judge_cfg)
self.judge_models = judge_models
self.infer_order = cfg.get('infer_order')
self.given_pred = cfg.eval.get('given_pred', [])
def get_command(self, cfg_path, template):
@ -78,17 +84,15 @@ class SubjectiveEvalTask(BaseTask):
# Load Dataset
eval_cfg = dataset_cfg.get('eval_cfg')
output_column = dataset_cfg['reader_cfg']['output_column']
if type(model_cfg) == ConfigDict:
model_cfg = (model_cfg, )
model_cfg += ({
'abbr':
'judged-by--' + model_abbr_from_cfg(self.judge_cfg)
}, )
out_path = get_infer_output_path(
model_cfg, dataset_cfg, osp.join(self.work_dir, 'results'))
deal_with_judge_model_abbr(model_cfg, self.judge_cfg,
self.meta), dataset_cfg,
osp.join(self.work_dir, 'results'))
if osp.exists(out_path):
continue
self._score(model_cfg, dataset_cfg, eval_cfg, output_column)
self._score(model_cfg, dataset_cfg, eval_cfg, output_column,
self.meta)
def _load_model_pred(
self,
@ -194,7 +198,139 @@ class SubjectiveEvalTask(BaseTask):
'model_preds': pred_strs
}
def _score(self, model_cfg, dataset_cfg, eval_cfg, output_column):
def _load_model_judgements(
self,
model_cfg: Union[ConfigDict, List[ConfigDict]],
dataset_cfg: ConfigDict,
eval_cfg: ConfigDict,
judge_cfg: Union[ConfigDict, List[ConfigDict]],
) -> Union[None, List[str]]:
if isinstance(judge_cfg, (tuple, list)):
return [
self._load_model_judgements(model_cfg, dataset_cfg, eval_cfg,
j) for j in judge_cfg
]
pred_strs = None
model_cfg = [model_cfg] if isinstance(model_cfg,
ConfigDict) else model_cfg
# There are 5 situations to handle:
# 1. No partition in either the infer or the judge stage
# 2. No partition in the infer stage, but partition in the judge stage
# 3. Partition in the infer stage, but not in the judge stage
# 4. Partition in both stages, with the same partition size
# 5. Partition in both stages, with different partition sizes
# If SubjectiveSizePartitioner was used, strip the trailing '_0' partition suffix from the filename
if 'test_range' in dataset_cfg['reader_cfg']:
filename = get_infer_output_path(
deal_with_judge_model_abbr([m for m in model_cfg], judge_cfg),
dataset_cfg, osp.join(self.work_dir, 'results'))
root, ext = osp.splitext(filename)
last_underscore_index = root.rfind('_')
root = root[:last_underscore_index]
filename = root + ext
# If SubjectiveNaivePartitioner was used, take the filename as-is
else:
filename = get_infer_output_path(
deal_with_judge_model_abbr([m for m in model_cfg], judge_cfg),
dataset_cfg, osp.join(self.work_dir, 'results'))
# Get partition name
root, ext = osp.splitext(filename)
partial_filename = root + '_0' + ext
# If no judgements are found in the results dir
if not osp.exists(osp.realpath(filename)) and not osp.exists(
osp.realpath(partial_filename)):
return {'error': 'No judgements found.'}
else:
# If use Naive partition in infer stage
if osp.exists(osp.realpath(filename)):
preds = mmengine.load(filename)
pred_strs = [
preds[str(i)]['prediction'] for i in range(len(preds))
]
# If use Size partition in infer stage
else:
filename = partial_filename
pred_strs = []
i = 1
while osp.exists(osp.realpath(filename)):
preds = mmengine.load(filename)
filename = root + f'_{i}' + ext
i += 1
pred_strs += [
preds[str(i)]['prediction'] for i in range(len(preds))
]
# Get all judgements in pred_strs
# If SubjectiveSizePartitioner was used, slice pred_strs according to test_range
if 'test_range' in dataset_cfg['reader_cfg']:
test_range = dataset_cfg['reader_cfg']['test_range']
if self.infer_order == 'double':
# When infer_order is 'double', select the judgements matching the predictions, which will be doubled later
start = 0
end = None
pred_strs_length = len(pred_strs)
# Split the string on ':'; test_range is a string shaped like '[0:15]'
parts = test_range.strip('[]').split(':')
# Check if the start index is provided
if parts[0]:
start = int(parts[0])
# Check if the end index is provided
if len(parts) > 1 and parts[1]:
end = int(parts[1])
else:
# If the end is not provided, determine the default end based on the length of 'pred_strs'
end = int(pred_strs_length / 2)
assert pred_strs_length % 2 == 0, "Since you have set the infer_order as 'double', the length of 'pred_strs' must be even."
assert end <= pred_strs_length / 2, "The 'end' value must not exceed half of the 'pred_strs' length."
# Reset the newly start and end
start *= 2
end *= 2
pred_strs = eval('pred_strs[' + str(start) + ':' + str(end) +
']')
else:
pred_strs = eval('pred_strs' + test_range)
# If SubjectiveNaivePartitioner was used, keep all pred_strs
else:
pred_strs = pred_strs
if ('pred_role' in eval_cfg and 'meta_template' in judge_cfg
and not MODELS.get(judge_cfg['type']).is_api
and isinstance(pred_strs[0], str)):
# Create a prompt template for role config parsing
from opencompass.models.base import LMTemplateParser
parser = LMTemplateParser(judge_cfg['meta_template'])
role = parser.roles[eval_cfg['pred_role']]
pred_strs = [
extract_role_pred(pred, role.get('begin', None),
role.get('end', None)) for pred in pred_strs
]
# Postprocess predictions if necessary
ds_abbr = dataset_abbr_from_cfg(dataset_cfg)
model_postprocessors = judge_cfg.get('pred_postprocessor', {})
pred_postprocessor = None
for pattern in model_postprocessors.keys():
if fnmatch.fnmatch(ds_abbr, pattern):
pred_postprocessor = model_postprocessors[pattern]
break
if 'pred_postprocessor' in eval_cfg or pred_postprocessor:
kwargs = pred_postprocessor or eval_cfg['pred_postprocessor']
proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
pred_strs = [proc(s, **kwargs) for s in pred_strs]
return {
'model_name': model_abbr_from_cfg(judge_cfg),
'model_preds': pred_strs
}
def _score(self,
model_cfg,
dataset_cfg,
eval_cfg,
output_column,
meta=False):
test_set = build_dataset_from_cfg(dataset_cfg).test
# Postprocess dataset if necessary
if 'dataset_postprocessor' in eval_cfg:
@ -208,27 +344,32 @@ class SubjectiveEvalTask(BaseTask):
test_set = test_set.map(postprocess)
# Get out_path
out_path = get_infer_output_path(model_cfg, dataset_cfg,
osp.join(self.work_dir, 'results'))
new_model_cfg = []
for m_cfg in model_cfg:
if len(m_cfg) > 1:
new_model_cfg.append(m_cfg)
if len(new_model_cfg) == 1:
new_model_cfg = new_model_cfg[0]
model_preds = self._load_model_pred(new_model_cfg, dataset_cfg,
eval_cfg, self.given_pred)
out_path = get_infer_output_path(
deal_with_judge_model_abbr(model_cfg, self.judge_cfg, self.meta),
dataset_cfg, osp.join(self.work_dir, 'results'))
if meta:
model_preds = self._load_model_pred(model_cfg, dataset_cfg,
eval_cfg, self.given_pred)
model_judges = self._load_model_judgements(model_cfg, dataset_cfg,
eval_cfg,
self.judge_models)
else:
model_preds = self._load_model_pred(model_cfg, dataset_cfg,
eval_cfg, self.given_pred)
model_judges = None
if not self.judge_cfg:
raise ValueError('missing "eval.runner.task.judge_cfg"')
raise ValueError('missing "eval.judge_cfg"')
eval_cfg['evaluator']['judge_cfg'] = self.judge_cfg
eval_cfg['evaluator']['dataset_cfg'] = dataset_cfg
eval_cfg['evaluator']['output_path'] = out_path
icl_evaluator = ICL_EVALUATORS.build(eval_cfg['evaluator'])
references = (test_set[output_column] if output_column else None)
if 'error' not in model_preds:
result = icl_evaluator.score(predictions=model_preds,
references=references)
judgements=model_judges,
references=references,
meta=meta,
infer_order=self.infer_order)
else:
result = model_preds
@ -259,17 +400,24 @@ class SubjectiveEvalTask(BaseTask):
output_paths = []
for model, datasets in zip(self.model_cfgs, self.dataset_cfgs):
for dataset in datasets:
if type(model) == ConfigDict:
if isinstance(model, ConfigDict):
model = (model, )
model += ({
'abbr':
'judged-by--' + model_abbr_from_cfg(self.judge_cfg)
}, )
if self.meta:
model += ({
'abbr':
'summarized-by--' + model_abbr_from_cfg(self.judge_cfg)
}, )
else:
model += ({
'abbr':
'judged-by--' + model_abbr_from_cfg(self.judge_cfg)
}, )
output_paths.append(
get_infer_output_path(
model, dataset,
osp.join(self.work_dir, self.output_subdir),
file_extension))
model = model[:-1]
return output_paths


@ -46,3 +46,25 @@ def get_infer_output_path(model_cfg: ConfigDict,
model_abbr = model_abbr_from_cfg(model_cfg)
dataset_abbr = dataset_abbr_from_cfg(dataset_cfg)
return osp.join(root_path, model_abbr, f'{dataset_abbr}.{file_extension}')
def deal_with_judge_model_abbr(model_cfg, judge_model_cfg, meta=False):
if isinstance(model_cfg, ConfigDict):
model_cfg = (model_cfg, )
if meta:
for m_cfg in model_cfg:
if 'summarized-by--' in m_cfg['abbr']:
return model_cfg
model_cfg += ({
'abbr':
'summarized-by--' + model_abbr_from_cfg(judge_model_cfg)
}, )
else:
for m_cfg in model_cfg:
if 'judged-by--' in m_cfg['abbr']:
return model_cfg
model_cfg += ({
'abbr':
'judged-by--' + model_abbr_from_cfg(judge_model_cfg)
}, )
return model_cfg
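Roughly, the effect on output paths (abbrs are examples only): the helper appends a pseudo-model entry whose abbr is prefixed with 'judged-by--' (or 'summarized-by--' for the meta stage), and get_infer_output_path then joins the abbrs, which yields the per-judge result subdirectories the summarizers look up.

# Conceptual example (abbrs invented), not an exact trace of the helper.
model_cfg = {'abbr': 'internlm2-chat-7b'}
judge_cfg = {'abbr': 'GPT4-Turbo'}
# deal_with_judge_model_abbr(model_cfg, judge_cfg) conceptually yields:
cfgs = ({'abbr': 'internlm2-chat-7b'}, {'abbr': 'judged-by--GPT4-Turbo'})
# get_infer_output_path(cfgs, dataset_cfg, 'results') then points at a path like
#   results/internlm2-chat-7b_judged-by--GPT4-Turbo/<dataset_abbr>.json
# and the meta stage uses 'summarized-by--GPT4-Turbo' instead.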

run.py

@ -341,7 +341,14 @@ def main():
if args.dry_run:
return
runner = RUNNERS.build(cfg.eval.runner)
runner(tasks)
# For the meta-review judge in subjective evaluation, run the task stages sequentially
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
tasks[0], list):
for task_part in tasks:
runner(task_part)
else:
runner(tasks)
# visualize
if args.mode in ['all', 'eval', 'viz']: