[Sync] format (#1214)

Fengzhe Zhou 2024-05-30 00:21:58 +08:00 committed by GitHub
parent d59189b87f
commit a77b8a5cec
9 changed files with 561 additions and 9 deletions

View File

@ -0,0 +1,58 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CompassBenchDataset
subjective_reader_cfg = dict(
input_columns=['question', 'judge_prompt'],
output_column='judge',
)
data_path = 'data/subjective/compassbench'
subjective_datasets = []
versions = ['CompassbenchV1']
for version_abbr in versions:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{question}'
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=2048),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{judge_prompt}'
),
]),
),
),
pred_role='BOT',
)
subjective_datasets.append(
dict(
abbr=version_abbr,
type=CompassBenchDataset,
path=data_path,
name=version_abbr,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))
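For reference, the CompassBenchDataset loader added later in this commit resolves each version listed above to a single JSON file under data_path, so the config expects roughly the following on-disk layout (a sketch, assuming the default OpenCompass working directory):

import os.path as osp

data_path = 'data/subjective/compassbench'
for version_abbr in ['CompassbenchV1']:
    # CompassBenchDataset.load() joins the configured path with '<name>.json'
    print(osp.join(data_path, f'{version_abbr}.json'))
    # -> data/subjective/compassbench/CompassbenchV1.json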

View File

@ -0,0 +1,137 @@
from mmengine.config import read_base
with read_base():
    from .datasets.subjective.compassbench.compassbench_compare import subjective_datasets
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import CompassBenchSummarizer
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
],
reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
)
# -------------Inference Stage ----------------------------------------
from opencompass.models import HuggingFacewithChatTemplate
models = [
dict(
type=HuggingFacewithChatTemplate,
abbr='internlm2-chat-7b-hf',
path='internlm/internlm2-chat-7b',
max_out_len=1024,
batch_size=8,
run_cfg=dict(num_gpus=1),
stop_words=['</s>', '<|im_end|>'],
generation_kwargs=dict(
do_sample=True,
),
)
]
datasets = [*subjective_datasets]
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='reserved',
max_num_workers=256,
task=dict(type=OpenICLInferTask),
),
)
gpt4 = dict(
abbr='gpt4-turbo',
type=OpenAI,
path='gpt-4-1106-preview',
key='', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=2048,
max_seq_len=4096,
batch_size=4,
retry=20,
temperature=1,
)  # Re-run inference for GPT-4's predictions, or reuse the pre-committed GPT-4 predictions instead
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-1106-preview',
key='', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=1024,
max_seq_len=4096,
batch_size=2,
retry=20,
temperature=0,
)]
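# NOTE: the assignment below overrides the GPT-4 judge_models list defined above;
# keep only the judge configuration you actually intend to use.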
judge_models = [
dict(
type=HuggingFacewithChatTemplate,
abbr='internlm102b',
path='/mnt/petrelfs/caomaosong/backup_hwfile/100bjudge_6w_epoch1/hf',
max_out_len=1024,
batch_size=8,
run_cfg=dict(num_gpus=4),
stop_words=['</s>', '<|im_end|>'],
),
dict(
type=HuggingFacewithChatTemplate,
abbr='internlm102b2',
path='/mnt/petrelfs/caomaosong/backup_hwfile/100bjudge_6w_epoch1/hf',
max_out_len=1024,
batch_size=8,
run_cfg=dict(num_gpus=4),
stop_words=['</s>', '<|im_end|>'],
),
dict(
type=HuggingFacewithChatTemplate,
abbr='internlm102b3',
path='/mnt/petrelfs/caomaosong/backup_hwfile/100bjudge_6w_epoch1/hf',
max_out_len=1024,
batch_size=8,
run_cfg=dict(num_gpus=4),
stop_words=['</s>', '<|im_end|>'],
)
]
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner,
strategy='split',
max_task_size=10000000,
mode='m2n',
infer_order='double',
base_models=[gpt4],
compare_models=models,
judge_models=judge_models,
),
runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask)),
#given_pred = [{'abbr':'gpt4-turbo', 'path':''}]
)
work_dir = 'outputs/compassbench/'
summarizer = dict(type=CompassBenchSummarizer, summary_type='half_add')
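As a rough illustration of the partitioner settings above (a sketch only; the actual pairing is done inside OpenCompass's subjective partitioner and summarizer): mode='m2n' pairs every base model with every compare model, and infer_order='double' judges each comparison twice with the two answers swapped, which is what the summarizer's position-bias check relies on.

from itertools import product

base_model_abbrs = ['gpt4-turbo']
compare_model_abbrs = ['internlm2-chat-7b-hf']

# m2n: every base model is compared against every compare model
pairs = [(b, c) for b, c in product(base_model_abbrs, compare_model_abbrs) if b != c]

# infer_order='double': each comparison is judged in both answer orders
orderings = [order for pair in pairs for order in (pair, pair[::-1])]
print(orderings)
# [('gpt4-turbo', 'internlm2-chat-7b-hf'), ('internlm2-chat-7b-hf', 'gpt4-turbo')]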

View File

@ -20,16 +20,16 @@ prompts = [
]
-charm_reaso_summary_groups = []
+charm_reason_summary_groups = []
for prompt in prompts:
    for region in regions:
        subsets = ['charm-reason-' + region + '_' + task + '_' + prompt for task in charm_tasks]
-        charm_reaso_summary_groups.append({'name': 'charm-reason-' + region + '_' + prompt, 'subsets': subsets})
+        charm_reason_summary_groups.append({'name': 'charm-reason-' + region + '_' + prompt, 'subsets': subsets})
for prompt in prompts:
    subsets = ['charm-reason-' + region + '_' + prompt for region in regions]
-    charm_reaso_summary_groups.append({'name': 'charm-reason-' + prompt, 'subsets': subsets})
+    charm_reason_summary_groups.append({'name': 'charm-reason-' + prompt, 'subsets': subsets})
-charm_reaso_summary_groups.append(
+charm_reason_summary_groups.append(
    {'name': 'charm-reason-CoT', 'subsets': ['charm-reason-ZH-CoT', 'charm-reason-EN-CoT']}
)

View File

@ -1,6 +1,7 @@
from .alignbench import AlignmentBenchDataset  # noqa: F401, F403
from .arena_hard import ArenaHardDataset  # noqa: F401, F403
from .compass_arena import CompassArenaDataset  # noqa: F401, F403
+from .compassbench import CompassBenchDataset  # noqa: F401, F403
from .corev2 import Corev2Dataset  # noqa: F401, F403
from .creationbench import CreationBenchDataset  # noqa: F401, F403
from .information_retrival import IRDataset  # noqa: F401, F403

View File

@ -0,0 +1,101 @@
# flake8: noqa
import json
import os.path as osp
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
base_prompt_zh = """请根据 用户问题 以及 相应的两个回答,判断哪一个回答更好。
[用户问题]
{question}
[回答1开始]
{prediction}
[回答1结束]
[回答2开始]
{prediction2}
[回答2结束]
根据评分要求,请先对两个回答进行评价,最后在以下 3 个选项中做出选择:
A. 回答1更好
B. 回答2更好
C. 回答1、2平局
如果你认为回答1更好,你的输出应形如:
评价1:回答1 xxx
评价2:回答2 xxx
选择:[[A]]
如果你认为回答2更好,你的输出应形如:
评价1:回答1 xxx
评价2:回答2 xxx
选择:[[B]]
如果你认为回答1、2打成平手,你的输出应形如:
评价1:回答1 xxx
评价2:回答2 xxx
选择:[[C]]
"""
base_prompt_en = """Please evaluate the two responses based on the user's question and then choose from the following three options:
A. Response 1 is better
B. Response 2 is better
C. Both responses are equal
[user's question]
{question}
[Response 1 Start]
{prediction}
[Response 1 End]
[Response 2 Start]
{prediction2}
[Response 2 End]
If you believe that Response 1 is better, your output should be formatted as follows:
Evaluation 1: Response 1 xxx
Evaluation 2: Response 2 xxx
Choice: [[A]]
If you believe that Response 2 is better, your output should be formatted as follows:
Evaluation 1: Response 1 xxx
Evaluation 2: Response 2 xxx
Choice: [[B]]
If you believe that both responses are equally good, your output should be formatted as follows:
Evaluation 1: Response 1 xxx
Evaluation 2: Response 2 xxx
Choice: [[C]]
"""
@LOAD_DATASET.register_module()
class CompassBenchDataset(BaseDataset):
def load(self, path: str, name: str):
filename = osp.join(path, f'{name}.json')
raw_data = []
with open(filename, 'r', encoding='utf-8') as f:
json_data = json.load(f)
for problem in json_data:
question = problem['question']
lan = problem['language']
others = problem['others']
judge_prompt = base_prompt_zh if lan == 'zh' else base_prompt_en
raw_data.append({
'question': question,
'judge_prompt': judge_prompt,
'judge': {
'lan': lan,
'level': others['level'],
'category': problem['category'],
'question': question
}
})
dataset = Dataset.from_list(raw_data)
return dataset
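The loader above only reads a handful of fields from each record, so a minimal input file that satisfies it looks roughly like this (the values are illustrative, not taken from the real benchmark data):

example_record = {
    'question': '请简要介绍一下大语言模型的评测方法。',
    'language': 'zh',         # 'zh' selects base_prompt_zh, anything else gets base_prompt_en
    'category': 'knowledge',  # copied into the 'judge' metadata
    'others': {'level': 1},   # only others['level'] is used
}
# data/subjective/compassbench/CompassbenchV1.json is a JSON list of such records,
# e.g. json.dump([example_record], f, ensure_ascii=False)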

View File

@ -4,6 +4,7 @@ from .all_obj import AllObjSummarizer
from .alpacaeval import AlpacaSummarizer
from .arenahard import ArenaHardSummarizer
from .compass_arena import CompassArenaSummarizer
+from .compassbench import CompassBenchSummarizer
from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer
from .flames import FlamesSummarizer

View File

@ -0,0 +1,241 @@
# flake8: noqa
# yapf: disable
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
from itertools import product
import mmengine
from mmengine import ConfigDict
from tabulate import tabulate
from opencompass.partitioners.sub_naive import remove_duplicate_pairs
from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
from .utils import get_judgeanswer_and_reference, get_outdir
def model_abbr_from_cfg_used_in_summarizer(model):
if model.get('summarizer_abbr', None):
return model['summarizer_abbr']
else:
return model_abbr_from_cfg(model)
def post_process_compass_arena(s):
if result := re.findall(r'(?:选择:|Choice: )\[\[([ABC])\]\]', s):
return result[0]
else:
return None
def check_position_bias(judged_answers, references, banned_choice=['C']):
"""Check position bias for judgellm's judgement.
Args:
judged_answers: The successfully extracted judgements.
references: The references contain the original question, which is used to locate the same question across the two position-swapped judgements.
"""
position_bias_flag = 0
position_bias_dict = {}
for judge, ref in zip(judged_answers, references):
question = ref['question']
question_hash = hash(question)
if question_hash not in position_bias_dict:
position_bias_dict[question_hash] = {
'question': question,
'judge': judge
}
else:
first_judge = position_bias_dict[question_hash]['judge']
if judge == first_judge and first_judge not in banned_choice and judge not in banned_choice:
# If the judge picks the same letter for both answer orders (and it is not a banned/tie choice), the judgement shows position bias.
position_bias_flag += 1
return position_bias_flag
class CompassBenchSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self,
config: ConfigDict,
judge_type='general',
check_pos_bias=True,
summary_type='single') -> None:
self.tasks = []
self.cfg = config
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
self.judge_models = self.cfg.get('judge_models', None)
self.meta_judge_model = self.cfg.eval.partitioner.get('meta_judge_model', None)
self.judge_type = judge_type
assert self.judge_type in ['general']
self.judge_map = {'general': post_process_compass_arena}
self.judge_function = self.judge_map[self.judge_type]
self.check_pos_bias = check_pos_bias
self.summary_type = summary_type
def get_score(self, time_str):
output_dir, results_folder = get_outdir(self.cfg, time_str)
model_combinations = list(product(self.base_models, self.compare_models))
unique_combinations = remove_duplicate_pairs([combo for combo in model_combinations if combo[0] != combo[1]])
if self.meta_judge_model is not None:
self.judge_models.append(self.meta_judge_model)
scores = {}
for idx, judge_model_cfg in enumerate(self.judge_models):
judge_model = model_abbr_from_cfg(judge_model_cfg)
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
for model_pair in unique_combinations:
model1 = model_pair[0]['abbr']
model2 = model_pair[1]['abbr']
if idx == len(self.judge_models):
subdir = model1 + '_' + model2 + '_summarized-by--' + judge_model
else:
subdir = model1 + '_' + model2 + '_judged-by--' + judge_model
subdir_path = os.path.join(results_folder, subdir)
if not os.path.isdir(subdir_path):
print(subdir_path + ' does not exist! Please check it.')
continue
judged_answers, references = get_judgeanswer_and_reference(dataset, subdir_path, self.judge_function)
if self.check_pos_bias:
bias_num = check_position_bias(judged_answers, references)
else:
bias_num = 0
win_model1 = defaultdict(float)
win_model2 = defaultdict(float)
categories = defaultdict(float)
difficulties = defaultdict(float)
model1 = references[0]['answer1']
model2 = references[0]['answer2']
for prediction, reference in zip(judged_answers, references):
categories[dataset_abbr] += 1
categories[reference['category']] += 1
difficulties[reference['level']] += 1
if prediction == 'A':
if reference['answer1'] == model1:
score_1, score_2 = 1, 0
else:
score_1, score_2 = 0, 1
elif prediction == 'B':
if reference['answer1'] == model1:
score_1, score_2 = 0, 1
else:
score_1, score_2 = 1, 0
elif prediction == 'C':
if self.summary_type == 'half_add':
score_1, score_2 = 0.5, 0.5
else:
score_1, score_2 = 0, 0
win_model1[reference['category']] += score_1
win_model1[dataset_abbr] += score_1
win_model2[reference['category']] += score_2
win_model2[dataset_abbr] += score_2
for category in categories:
win_model1[category] = win_model1[category] / categories[category] * 100
win_model1[category] = round(win_model1[category], 2)
win_model2[category] = win_model2[category] / categories[category] * 100
win_model2[category] = round(win_model2[category], 2)
win_model1['position_bias'] = bias_num
win_model2['position_bias'] = bias_num
if judge_model not in scores:
scores[judge_model] = {}
if dataset_abbr not in scores[judge_model]:
scores[judge_model][dataset_abbr] = {}
scores[judge_model][dataset_abbr][model2] = win_model2
return scores
def summarize(
self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S'),
):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
    None: The summary tables are printed and written to CSV report files.
"""
scores = self.get_score(time_str)
# scores['win_' + model1] = win_model1
output_dir, results_folder = get_outdir(self.cfg, time_str)
for idx, judge_model in enumerate(self.judge_models):
judge_abbr = model_abbr_from_cfg(judge_model)
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
summarizer_model_abbrs = [model_abbr_from_cfg_used_in_summarizer(i) for i in self.compare_models]
one_column = list(scores[judge_abbr][dataset_abbr].values())[0]
row_headers = [i for i in one_column.keys() if i not in [dataset_abbr, 'position_bias']]
row_headers = [dataset_abbr, 'position_bias'] + row_headers
headers = [''] + summarizer_model_abbrs
table = []
for row_header in row_headers:
row = [row_header]
for model_cfg in self.compare_models:
model_abbr = model_abbr_from_cfg(model_cfg)
s = scores[judge_abbr][dataset_abbr][model_abbr].get(row_header, '')
if isinstance(s, float):
s = f'{s:.2f}'
if isinstance(s, int):
s = str(s)
row.append(s)
table.append(row)
txt = tabulate(table, headers=headers)
print(txt)
if idx == len(self.judge_models):
output_filename = osp.join(output_dir, 'summarized-by--' + judge_abbr + '-' + dataset_abbr + '-report.csv')
else:
output_filename = osp.join(output_dir, 'judged-by--' + judge_abbr + '-' + dataset_abbr + '-report.csv')
with open(output_filename, 'w') as f:
f.write(','.join(headers) + '\n')
for line in table:
f.write(','.join(line) + '\n')
print(output_filename)
table = []
summarizer_model_abbrs = [model_abbr_from_cfg_used_in_summarizer(i) for i in self.compare_models]
headers = [''] + summarizer_model_abbrs
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
row = [dataset_abbr]
for model_cfg in self.compare_models:
model_abbr = model_abbr_from_cfg(model_cfg)
s = scores[judge_abbr][dataset_abbr][model_abbr].get(dataset_abbr, '')
if isinstance(s, float):
s = f'{s:.2f}'
if isinstance(s, int):
s = str(s)
row.append(s)
table.append(row)
txt = tabulate(table, headers=headers)
print(txt)
if idx == len(self.judge_models):
output_filename = osp.join(output_dir, 'summarized-by--' + judge_abbr + '-overall-report.csv')
else:
output_filename = osp.join(output_dir, 'judged-by--' + judge_abbr + '-overall-report.csv')
with open(output_filename, 'w') as f:
f.write(','.join(headers) + '\n')
for line in table:
f.write(','.join(line) + '\n')
print(output_filename)
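A quick sanity check of the verdict parsing used above; this is a self-contained mirror of post_process_compass_arena's English branch (the real function above also accepts the Chinese 选择 prefix), with made-up judge outputs:

import re

def extract_choice(judgement: str):
    # mirrors post_process_compass_arena, English prefix only
    found = re.findall(r'Choice: \[\[([ABC])\]\]', judgement)
    return found[0] if found else None

assert extract_choice('Evaluation 1: ...\nEvaluation 2: ...\nChoice: [[B]]') == 'B'
assert extract_choice('The judge produced no explicit verdict.') is None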

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import hashlib
import json
+import re
from copy import deepcopy
from typing import Dict, List, Union
@ -19,9 +20,15 @@ def safe_format(input_str: str, **kwargs) -> str:
    Returns:
        str: The formatted string.
    """
-    for k, v in kwargs.items():
-        input_str = input_str.replace(f'{{{k}}}', str(v))
-    return input_str
+    segs = [input_str]
+    for k, v in kwargs.items():
+        regex = re.compile(f'(?<={{{k}}})(?={{{k}}})|({{{k}}})')
+        segs = [regex.split(seg) for seg in segs]
+        segs = sum(segs, [])
+    replace_dict = {f'{{{k}}}': str(v) for k, v in kwargs.items()}
+    segs = [replace_dict.get(seg, seg) for seg in segs]
+    output_str = ''.join(segs)
+    return output_str


def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
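The practical effect of the rewrite above: placeholders are split out of the template first and only then substituted, so known placeholders are replaced, unknown '{...}' tokens are left untouched (unlike str.format), and text inside a substituted value is no longer re-scanned for later placeholders, which is the behavioural change from the old replace() loop. A self-contained sketch mirroring the new logic (not an import from opencompass):

import re

def safe_format_sketch(input_str, **kwargs):
    segs = [input_str]
    for k in kwargs:
        regex = re.compile(f'(?<={{{k}}})(?={{{k}}})|({{{k}}})')
        segs = sum([regex.split(seg) for seg in segs], [])
    replace_dict = {f'{{{k}}}': str(v) for k, v in kwargs.items()}
    return ''.join(replace_dict.get(seg, seg) for seg in segs)

print(safe_format_sketch('Q: {question} {unknown}', question='hi'))  # Q: hi {unknown}
print(safe_format_sketch('{a}', a='literal {b}', b='X'))              # literal {b}
# The old replace()-based loop would have turned the second example into 'literal X'.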

View File

@ -86,8 +86,14 @@ def get_config_from_arg(args) -> Config:
        config['models'] = change_accelerator(config['models'], args.accelerator)
    if config.get('eval', {}).get('partitioner', {}).get('models') is not None:
        config['eval']['partitioner']['models'] = change_accelerator(config['eval']['partitioner']['models'], args.accelerator)
+    if config.get('eval', {}).get('partitioner', {}).get('base_models') is not None:
+        config['eval']['partitioner']['base_models'] = change_accelerator(config['eval']['partitioner']['base_models'], args.accelerator)
+    if config.get('eval', {}).get('partitioner', {}).get('compare_models') is not None:
+        config['eval']['partitioner']['compare_models'] = change_accelerator(config['eval']['partitioner']['compare_models'], args.accelerator)
    if config.get('eval', {}).get('partitioner', {}).get('judge_models') is not None:
        config['eval']['partitioner']['judge_models'] = change_accelerator(config['eval']['partitioner']['judge_models'], args.accelerator)
+    if config.get('judge_models', {}) is not None:
+        config['judge_models'] = change_accelerator(config['judge_models'], args.accelerator)
    return config

    # parse dataset args
@ -211,7 +217,7 @@ def change_accelerator(models, accelerator):
        mod = TurboMindModel
        acc_model = dict(
            type=f'{mod.__module__}.{mod.__name__}',
-            abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy',
+            abbr=model['abbr'].replace('hf', 'turbomind') if '-hf' in model['abbr'] else model['abbr'] + '-turbomind',
            path=model['path'],
            engine_config=dict(session_len=model['max_seq_len'],
                               max_batch_size=model['batch_size'],
@ -254,7 +260,7 @@ def change_accelerator(models, accelerator):
        mod = VLLMwithChatTemplate
        acc_model = dict(
            type=f'{mod.__module__}.{mod.__name__}',
-            abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-vllm',
+            abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
            path=model['path'],
            model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus']),
            max_out_len=model['max_out_len'],
@ -266,7 +272,7 @@ def change_accelerator(models, accelerator):
        mod = TurboMindModelwithChatTemplate
        acc_model = dict(
            type=f'{mod.__module__}.{mod.__name__}',
-            abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-turbomind',
+            abbr=model['abbr'].replace('hf', 'turbomind') if '-hf' in model['abbr'] else model['abbr'] + '-turbomind',
            path=model['path'],
            engine_config=dict(max_batch_size=model.get('batch_size', 16), tp=model['run_cfg']['num_gpus']),
            gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),