[Feature] Add CompassArena (#828)

* add compass arena

* add compass_arena

* add compass arena

* Update opencompass/summarizers/subjective/compass_arena.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update opencompass/summarizers/subjective/__init__.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update opencompass/datasets/subjective/compass_arena.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update opencompass/datasets/subjective/__init__.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update configs/eval_subjective_compassarena.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update configs/datasets/subjective/compassarena/compassarena_compare.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update configs/eval_subjective_compassarena.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update configs/datasets/subjective/compassarena/compassarena_compare.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* fix check position bias

---------

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
This commit is contained in:
bittersweet1999 2024-01-23 15:12:46 +08:00 committed by GitHub
parent 40a2441deb
commit 2d4da8dd02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 461 additions and 1 deletions

View File

@ -0,0 +1,160 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CompassArenaDataset
subjective_reader_cfg = dict(
input_columns=['question', 'ref'],
output_column='judge',
)
data_path ="data/subjective/"
subjective_datasets = []
base_prompt = """
[回答1开始]
{prediction}
[回答1结束]
[回答2开始]
{prediction2}
[回答2结束]
根据评分要求在以下 3 个选项中做出选择:
A. 回答1更好
B. 回答2更好
C. 回答12平局
并提供你的解释原因
如果你认为回答1更好你的输出应形如
选择A
原因blahblah blahblah\n
如果你认为回答2更好你的输出应形如
选择B
原因blahblah blahblah\n
如果你认为回答12打成平手你的输出应形如
选择C
原因blahblah blahblah\n
"""
knowledge_prompt = """
请根据提供的 评分要求用户问题参考答案 以及 相应的两个回答回答1回答2判断两个回答中哪一个更好
评分要求重要性依次递减:
1. 更好的回答能与参考答案吻合或表明参考答案的意思
2. 在都准确答对问题的前提下更好的回答能对知识点进行额外补充且补充的知识准确无误
3. 更好的回答更加符合与人类对话的习惯包括语气情调等
[用户问题]
{question}
[参考答案]
{ref}
""" + base_prompt
language_prompt = """
请根据提供的 评分要求用户问题 以及 相应的两个回答回答1回答2判断两个回答中哪一个更好
评分要求重要性依次递减:
1. 在有明确的参考答案的情况下越贴近参考答案或表明了参考答案的意思的回答越好
2. 更好的回答在语言表达上更流畅更加符合与人类对话的习惯包括语气情调等
3. 在都准确答对问题的前提下更好的回答能进行额外补充且补充的内容准确无误
[用户问题]
{question}
[参考答案]
{ref}
""" + base_prompt
math_prompt = """
请根据提供的 评分要求用户问题参考答案 以及 相应的两个回答回答1回答2判断两个回答中哪一个更好
评分要求重要性依次递减:
1. 更好的回答的答案能和参考答案一致
2. 若两个回答的答案都与参考答案不一致则更好的回答的推理过程应更加合理
3. 更好的回答更加符合与人类对话的习惯包括语气情调等
[用户问题]
{question}
[参考答案]
{ref}
""" + base_prompt
reason_prompt = math_prompt
qa_prompt = """
请根据提供的 评分要求用户问题 以及 相应的两个回答回答1回答2判断两个回答中哪一个更好
评分要求重要性依次递减:
1. 好的回答必须首先具有事实正确性即除了想象的内容外所引用或阐述的各种信息都是真实正确的
2. 好的回答必须具有逻辑连贯性围绕一个中心进行回答且前后连贯逻辑没有问题
3. 在都准确答对问题的前提下更好的回答能进行额外补充且补充的内容准确无误
[用户问题]
{question}
""" + base_prompt
creation_prompt = """
请根据提供的 评分要求用户问题 以及 相应的两个回答回答1回答2判断两个回答中哪一个更好
评分要求重要性依次递减:
1. 好的回答必须首先符合用户问题里的各种需求不能跑题
2. 好的回答必须具有逻辑连贯性围绕一个中心进行回答
3. 好的回答必须具有创造性的词语和表达丰富度
[用户问题]
{question}
""" + base_prompt
subjective_all_sets = ["knowledge", "language", "math", "reason", "qa", "creationv2_zh"]
prompt_all_sets = [knowledge_prompt, language_prompt, math_prompt, reason_prompt, qa_prompt, creation_prompt]
for _name,_prompt in zip(subjective_all_sets, prompt_all_sets):
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_seq_len=4096, max_out_len=2048),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
infer_order='double',
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = _prompt
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=CompassArenaDataset,
path=data_path,
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))

View File

@ -0,0 +1,95 @@
from os import getenv as gv
from opencompass.models import HuggingFaceCausalLM
from mmengine.config import read_base
with read_base():
from .models.chatglm.hf_chatglm3_6b_32k import models as chatglm3_6b_32k_model
from .models.yi.hf_yi_6b_chat import models as yi_6b_chat_model
from .datasets.subjective.compassarena.compassarena_compare import subjective_datasets
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.models.openai_api import OpenAIAllesAPIN
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import CompassArenaSummarizer
infer = dict(
#partitioner=dict(type=NaivePartitioner),
partitioner=dict(type=SizePartitioner, max_task_size=10000),
runner=dict(
type=SlurmSequentialRunner,
partition='llm_dev2',
quotatype='auto',
max_num_workers=256,
task=dict(type=OpenICLInferTask)),
)
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
gpt4 = dict(
abbr='gpt4-turbo',
type=OpenAI, path='gpt-4-1106-preview',
key='', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=2048,
max_seq_len=4096,
batch_size=4,
retry=20,
temperature = 1
)
models = [*chatglm3_6b_32k_model, *yi_6b_chat_model]
datasets = [*subjective_datasets]
work_dir = 'outputs/compass_arena/'
# -------------Inferen Stage ----------------------------------------
judge_model = dict(
abbr='GPT4-Turbo',
type=OpenAI, path='gpt-4-1106-preview',
key='', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=1024,
max_seq_len=4096,
batch_size=2,
retry=20,
temperature = 0
)
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner,
strategy='split',
max_task_size=10000,
mode='m2n',
base_models = [gpt4],
compare_models = [*chatglm3_6b_32k_model, *yi_6b_chat_model, ]
),
runner=dict(
type=SlurmSequentialRunner,
partition='llm_dev2',
quotatype='auto',
max_num_workers=32,
task=dict(
type=SubjectiveEvalTask,
judge_cfg=judge_model
)),
)
summarizer = dict(
type=CompassArenaSummarizer
)

View File

@ -1,4 +1,5 @@
from .alignbench import AlignmentBenchDataset # noqa: F401, F403 from .alignbench import AlignmentBenchDataset # noqa: F401, F403
from .compass_arena import CompassArenaDataset # noqa: F401, F403
from .corev2 import Corev2Dataset # noqa: F401, F403 from .corev2 import Corev2Dataset # noqa: F401, F403
from .creationbench import CreationBenchDataset # noqa: F401, F403 from .creationbench import CreationBenchDataset # noqa: F401, F403
from .information_retrival import IRDataset # noqa: F401, F403 from .information_retrival import IRDataset # noqa: F401, F403

View File

@ -0,0 +1,28 @@
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from .subjective_cmp import SubjectiveCmpDataset
@LOAD_DATASET.register_module()
class CompassArenaDataset(SubjectiveCmpDataset):
def load(
self,
path: str,
name: str,
):
dataset = list(super().load(path, name))
creation_dataset = []
for data in dataset:
if 'reference' in data['others']:
if data['others']['reference'] is not None:
data['ref'] = data['others']['reference']
else:
data['ref'] = '满足用户需求,言之有理即可'
else:
data['ref'] = '满足用户需求,言之有理即可'
creation_dataset.append(data)
dataset = Dataset.from_list(creation_dataset)
return dataset

View File

@ -26,7 +26,8 @@ class SubjectiveCmpDataset(BaseDataset):
'capability': capability, 'capability': capability,
'others': others, 'others': others,
'judge': { 'judge': {
'capability': capability 'capability': capability,
'question': question
} }
}) })
dataset = Dataset.from_list(raw_data) dataset = Dataset.from_list(raw_data)

View File

@ -1,5 +1,6 @@
# flake8: noqa: F401, E501 # flake8: noqa: F401, E501
from .alignmentbench import AlignmentBenchSummarizer from .alignmentbench import AlignmentBenchSummarizer
from .compass_arena import CompassArenaSummarizer
from .corev2 import Corev2Summarizer from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer from .creationbench import CreationBenchSummarizer
from .information_retrival import IRSummarizer from .information_retrival import IRSummarizer

View File

@ -0,0 +1,174 @@
# flake8: noqa: E501
import ast
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
from itertools import product
import mmengine
from mmengine import ConfigDict
from prettytable import from_csv
from opencompass.partitioners.sub_naive import remove_duplicate_pairs
from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
from .utils import get_judgeanswer_and_reference, get_outdir
def post_process_compass_arena(s):
if result := re.findall('(?:选择:|Choice: )([ABC])', s):
return result[0]
else:
return None
def check_position_bias(judged_answers, references, banned_choice=['C']):
"""Check position bias for judgellm's judgement.
Args:
judged_answers: The successfully extracted judgement.
references: The references contains original question, which is used to located the same question for different position judgement.
"""
position_bias_flag = 0
position_bias_dict = {}
for judge, ref in zip(judged_answers, references):
question = ref['others']['question']
question_hash = hash(question)
if question_hash not in position_bias_dict:
position_bias_dict[question_hash] = {
'question': question,
'judge': judge
}
else:
first_judge = position_bias_dict[question_hash]['judge']
if judge == first_judge and first_judge not in banned_choice and judge not in banned_choice:
# If second choice is same with first choice, there has position bias.
position_bias_flag += 1
return position_bias_flag
class CompassArenaSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, judge_type='general') -> None:
self.tasks = []
self.cfg = config
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_type = judge_type
assert self.judge_type in ['general']
self.judge_map = {
'general': post_process_compass_arena,
}
self.judge_function = self.judge_map[self.judge_type]
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S'),
check_pos_bias=True):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
dataset_cfgs = self.cfg['datasets']
output_dir, results_folder = get_outdir(self.cfg, time_str)
model_combinations = list(
product(self.base_models, self.compare_models))
unique_combinations = remove_duplicate_pairs(
[combo for combo in model_combinations if combo[0] != combo[1]])
fout_list = []
for model_pair in unique_combinations:
model1, model2, judge_model = model_pair[0]['abbr'], model_pair[1][
'abbr'], self.judge_abbr
subdir = model1 + '_' + model2 + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
for dataset in dataset_cfgs:
dataset_abbr = dataset_abbr_from_cfg(dataset)
fout = osp.join(
output_dir, 'judged-by--' + judge_model + '-' +
dataset_abbr + '-report.csv')
fout_list.append(fout)
judged_answers, references = get_judgeanswer_and_reference(
dataset,
subdir_path,
self.judge_function,
)
if check_pos_bias:
bias_num = check_position_bias(judged_answers,
references)
else:
bias_num = 0
win_model1, win_model2, categories = defaultdict(
float), defaultdict(float), defaultdict(float)
model1, model2 = references[0]['answer1'], references[0][
'answer2']
for prediction, reference in zip(judged_answers,
references):
if dataset_abbr == 'zhihu_hot_0113':
reference['capability'] = 'QA'
categories['total'] += 1
categories[reference['capability']] += 1
if prediction == 'A':
if reference['answer1'] == model1:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
else:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
elif prediction == 'B':
if reference['answer1'] == model1:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
else:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
for capability in categories:
if capability not in win_model1:
win_model1[capability] = 0.0
else:
win_model1[capability] = round(
(win_model1[capability] /
categories[capability]) * 100, 2)
if capability not in win_model2:
win_model2[capability] = 0.0
else:
win_model2[capability] = round(
(win_model2[capability] /
categories[capability]) * 100, 2)
win_model1['position_bias'] = bias_num
win_model2['position_bias'] = bias_num
scores = {
'win_' + model1: win_model1,
'win_' + model2: win_model2
}
rows = list(scores.keys())
columns = list(scores[rows[0]].keys())
columns.insert(0, columns.pop(columns.index('total')))
columns.insert(1,
columns.pop(columns.index('position_bias')))
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([model1 + '_vs_' + model2] + columns)
for row in rows:
writer.writerow(
[row] +
[scores[row][column] for column in columns])
else:
print(subdir_path + ' is not exist! please check!')
for fout in fout_list:
with open(fout, 'r') as f:
x = from_csv(f)
print(x)