Support wildbench (#1266)

Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
klein 2024-06-24 13:16:27 +08:00 committed by GitHub
parent 83b9fd9eaa
commit 1fa62c4a42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 997 additions and 9 deletions

View File

@ -0,0 +1,30 @@
# Wildbench
## Prepare the dataset
We support the [WildBench dataset](https://github.com/allenai/WildBench), developed by Lin et al. Please refer to their repository for more details.
You need to download our preprocessed dataset. The directory should be organized as follows:
```
wildbench
---wildbench.jsonl
---gpt4
------wildbench.json
---claude
------wildbench.json
---llama2-70b
------wildbench.json
```
`wildbench.jsonl` is the preprocessed dataset, and the other three directories contain the reference predictions used for scoring.
After downloading the dataset, update the data path defined in `configs/datasets/subjective/wildbench/wildbench_pair_judge.py` and `configs/datasets/subjective/wildbench/wildbench_single_judge.py`.
## Run
We provide scripts for WildBench in `configs/eval_subjective_wildbench_pair.py` and `configs/eval_subjective_wildbench_single.py`.
Please modify the paths for `given_pred` (line 171) in `configs/eval_subjective_wildbench_pair.py` to your own paths, as shown in the snippet below.
Note that if you evaluate WildBench with other models, please set `max_out_len` to 4096.
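For reference, the paths to edit are the `given_pred` entries inside the `eval` dict of the pair config; the values below are the defaults from this commit, so replace them with the locations of your local reference predictions:

```python
# configs/eval_subjective_wildbench_pair.py (inside the `eval` dict)
given_pred = [
    {'abbr': 'gpt4-turbo', 'path': './data/WildBench/gpt4'},
    {'abbr': 'llama-2-70b-chat-hf', 'path': './data/WildBench/llama2-70b'},
    {'abbr': 'HaiKu', 'path': './data/WildBench/claude'},
    {'abbr': 'llama-2-70b-chat-turbomind', 'path': './data/WildBench/llama2-70b'},
    {'abbr': 'llama-2-70b-chat-vllm', 'path': './data/WildBench/llama2-70b'},
]
```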

View File

@ -0,0 +1,46 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import WildBenchDataset
subjective_reader_cfg = dict(
input_columns=['dialogue', 'prompt'],
output_column='judge',
)
data_path = './data/WildBench/wildbench.jsonl'
subjective_datasets = []
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="""{dialogue}"""
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512, infer_mode='last'),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template="""{prompt}"""
),
),
pred_role='BOT',
)
subjective_datasets.append(
dict(
abbr='wildbench',
type=WildBenchDataset,
path=data_path,
mode='pair',
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))

View File

@ -0,0 +1,47 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import WildBenchDataset
subjective_reader_cfg = dict(
input_columns=['dialogue', 'prompt'],
output_column='judge',
)
data_path = './data/WildBench/wildbench.jsonl'
subjective_datasets = []
# each 'dialogue' entry is a list of chat turns; ChatInferencer with infer_mode='last' generates a reply only for the final turn
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="""{dialogue}"""
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512, infer_mode='last'),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template="""{prompt}"""
),
),
pred_role='BOT',
)
subjective_datasets.append(
dict(
abbr='wildbench',
type=WildBenchDataset,
path=data_path,
mode='single',
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))

View File

@ -0,0 +1,180 @@
from mmengine.config import read_base
with read_base():
# from .datasets.subjective.multiround.mtbench_single_judge_diff_temp import subjective_datasets
from .datasets.subjective.wildbench.wildbench_pair_judge import subjective_datasets
from .models.openai.gpt_4 import models as gpt4_models
from .models.hf_llama.hf_llama2_70b_chat import models as llama2_models
# from .models.gemma.hf_gemma_2b_it import models
# from .models.hf_llama.hf_llama3_70b_instruct import models as llama3_model
# # from .models.hf_internlm.hf_internlm2_chat_7b import models
# from .models.yi.hf_yi_1_5_34b_chat import models as yi_model
# from .models.qwen.hf_qwen1_5_72b_chat import models as qwen_model
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import WildBenchPairSummarizer
from opencompass.models.claude_api.claude_api import Claude
from opencompass.models import HuggingFacewithChatTemplate
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])
api_meta_template = dict(
round=[
dict(role='SYSTEM', api_role='SYSTEM'),
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
# _meta_template = dict(
# round=[
# dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'),
# dict(role='BOT', begin='\n<|im_start|>assistant\n', end='<|im_end|>', generate=True),
# ],
# )
# -------------Inference Stage ----------------------------------------
# For subjective evaluation, we usually enable sampling (do_sample) for the models
models = [
dict(
type=HuggingFacewithChatTemplate,
abbr='llama-3-8b-instruct-hf',
path='meta-llama/Meta-Llama-3-8B-Instruct',
max_out_len=4096,
batch_size=8,
run_cfg=dict(num_gpus=1),
stop_words=['<|end_of_text|>', '<|eot_id|>'],
),
dict(
type=HuggingFacewithChatTemplate,
abbr='yi-1.5-6b-chat-hf',
path='01-ai/Yi-1.5-6B-Chat',
max_out_len=4096,
batch_size=8,
run_cfg=dict(num_gpus=1),
),
dict(
type=HuggingFacewithChatTemplate,
abbr='qwen1.5-7b-chat-hf',
path='Qwen/Qwen1.5-7B-Chat',
max_out_len=4096,
batch_size=8,
run_cfg=dict(num_gpus=1),
),
# dict(
# type=HuggingFacewithChatTemplate,
# abbr='llama-3-70b-instruct-hf',
# path='meta-llama/Meta-Llama-3-70B-Instruct',
# max_out_len=4096,
# batch_size=8,
# run_cfg=dict(num_gpus=4),
# stop_words=['<|end_of_text|>', '<|eot_id|>'],
# ),
# dict(
# type=HuggingFacewithChatTemplate,
# abbr='yi-1.5-34b-chat-hf',
# path='01-ai/Yi-1.5-34B-Chat',
# max_out_len=4096,
# batch_size=8,
# run_cfg=dict(num_gpus=2),
# ),
# dict(
# type=HuggingFacewithChatTemplate,
# abbr='qwen1.5-72b-chat-hf',
# path='Qwen/Qwen1.5-72B-Chat',
# max_out_len=4096,
# batch_size=8,
# run_cfg=dict(num_gpus=8),
# )
]
datasets = [*subjective_datasets]
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-0613', # To compare with the official leaderboard, please use gpt4-0613
key='xxxx', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=16,
max_out_len=2048,
max_seq_len=2048,
batch_size=8,
temperature=0,
)]
gpt4 = dict(
abbr='gpt4-turbo',
type=OpenAI,
path='gpt-4-0409-preview',
key='', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=2048,
max_seq_len=4096,
batch_size=4,
retry=20,
temperature=1,
) # Re-run inference for gpt4's predictions, or use the pre-committed gpt4 predictions instead
claude = dict(abbr='HaiKu',
type=Claude,
path='claude-2',
key='YOUR_CLAUDE_KEY',
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=2,
)
## single evaluation
# eval = dict(
# partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models, judge_models=judge_models),
# runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask)),
# )
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000, strategy='split'),
runner=dict(
type=SlurmSequentialRunner,
max_num_workers=64,
quotatype='reserved',
partition='llmeval',
task=dict(type=OpenICLInferTask)),
)
eval = dict(
partitioner=dict(
type=SubjectiveNaivePartitioner,
mode='m2n',  # m models are compared against n models
infer_order='random',
# In m2n mode, base_models and compare_models must be specified; (base, compare) pairs are generated for all combinations, duplicates are removed, and a model is never compared against itself
base_models = [*llama2_models, gpt4, claude],  # baseline models to compare against
compare_models = models,  # models under evaluation
judge_models=judge_models
),
runner=dict(
type=LocalRunner,
# partition='llmeval',
# quotatype='auto',
max_num_workers=3,
task=dict(
type=SubjectiveEvalTask
)),
given_pred = [{'abbr':'gpt4-turbo', 'path':'./data/WildBench/gpt4'},
{'abbr': 'llama-2-70b-chat-hf', 'path':'./data/WildBench/llama2-70b'},
{'abbr': 'HaiKu', 'path':'./data/WildBench/claude'},
{'abbr': 'llama-2-70b-chat-turbomind', 'path':'./data/WildBench/llama2-70b'},
{'abbr': 'llama-2-70b-chat-vllm', 'path':'./data/WildBench/llama2-70b'}]
)
summarizer = dict(type=WildBenchPairSummarizer)
work_dir = 'outputs/wildbench/'
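To make the m2n pairing concrete: the partitioner (and later `WildBenchPairSummarizer.get_score`) effectively builds the cross product of `base_models` and `compare_models`, drops duplicate pairs, and skips self-comparisons. A standalone sketch with hypothetical abbreviations:

```python
from itertools import product

# Hypothetical abbreviations standing in for base_models / compare_models above.
base_abbrs = ['gpt4-turbo', 'HaiKu', 'llama-2-70b-chat-hf']
compare_abbrs = ['llama-3-8b-instruct-hf', 'yi-1.5-6b-chat-hf', 'qwen1.5-7b-chat-hf']

# Each remaining (base, compare) pair becomes one judged comparison.
pairs = [(b, c) for b, c in product(base_abbrs, compare_abbrs) if b != c]
print(len(pairs))  # 9 comparisons for 3 base models x 3 compare models
```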

View File

@ -0,0 +1,135 @@
from mmengine.config import read_base
with read_base():
# from .datasets.subjective.multiround.mtbench_single_judge_diff_temp import subjective_datasets
from .datasets.subjective.wildbench.wildbench_single_judge import subjective_datasets
# from .models.gemma.hf_gemma_2b_it import models as gemma_2b_models
# from .models.hf_llama.hf_llama3_70b_instruct import models as llama3_model
# # from .models.hf_internlm.hf_internlm2_chat_7b import models
# from .models.yi.hf_yi_1_5_34b_chat import models as yi_model
# from .models.qwen.hf_qwen1_5_72b_chat import models as qwen_model
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import WildBenchSingleSummarizer
from opencompass.models import HuggingFacewithChatTemplate
# models = sum([v for k, v in locals().items() if k.endswith("_model")], [])
api_meta_template = dict(
round=[
dict(role='SYSTEM', api_role='SYSTEM'),
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
# _meta_template = dict(
# round=[
# dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'),
# dict(role='BOT', begin='\n<|im_start|>assistant\n', end='<|im_end|>', generate=True),
# ],
# )
# -------------Inference Stage ----------------------------------------
# For subjective evaluation, we usually enable sampling (do_sample) for the models
# set max_out_len to 4096.
models = [
dict(
type=HuggingFacewithChatTemplate,
abbr='llama-3-8b-instruct-hf',
path='meta-llama/Meta-Llama-3-8B-Instruct',
max_out_len=4096,
batch_size=8,
run_cfg=dict(num_gpus=1),
stop_words=['<|end_of_text|>', '<|eot_id|>'],
),
dict(
type=HuggingFacewithChatTemplate,
abbr='yi-1.5-6b-chat-hf',
path='01-ai/Yi-1.5-6B-Chat',
max_out_len=4096,
batch_size=8,
run_cfg=dict(num_gpus=1),
),
dict(
type=HuggingFacewithChatTemplate,
abbr='qwen1.5-7b-chat-hf',
path='Qwen/Qwen1.5-7B-Chat',
max_out_len=4096,
batch_size=8,
run_cfg=dict(num_gpus=1),
),
# dict(
# type=HuggingFacewithChatTemplate,
# abbr='llama-3-70b-instruct-hf',
# path='meta-llama/Meta-Llama-3-70B-Instruct',
# max_out_len=4096,
# batch_size=8,
# run_cfg=dict(num_gpus=4),
# stop_words=['<|end_of_text|>', '<|eot_id|>'],
# ),
# dict(
# type=HuggingFacewithChatTemplate,
# abbr='yi-1.5-34b-chat-hf',
# path='01-ai/Yi-1.5-34B-Chat',
# max_out_len=4096,
# batch_size=8,
# run_cfg=dict(num_gpus=2),
# ),
# dict(
# type=HuggingFacewithChatTemplate,
# abbr='qwen1.5-72b-chat-hf',
# path='Qwen/Qwen1.5-72B-Chat',
# max_out_len=4096,
# batch_size=8,
# run_cfg=dict(num_gpus=4),
# )
]
datasets = [*subjective_datasets]
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-0613', # To compare with the official leaderboard, please use gpt4-0613
key='xxxx', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=16,
max_out_len=2048,
max_seq_len=2048,
batch_size=8,
temperature=0,
)]
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000, strategy='split'),
runner=dict(
type=SlurmSequentialRunner,
max_num_workers=64,
quotatype='reserved',
partition='llmeval',
task=dict(type=OpenICLInferTask)),
)
## single evaluation
eval = dict(
partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models, judge_models=judge_models),
runner=dict(type=LocalRunner,
max_num_workers=2,
task=dict(type=SubjectiveEvalTask)),
)
summarizer = dict(type=WildBenchSingleSummarizer)
work_dir = 'outputs/wildbench/'
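The inference stage above assumes a Slurm cluster (`SlurmSequentialRunner` on the `llmeval` partition). If you run locally, a minimal sketch of swapping in the already-imported `LocalRunner` (the worker count is an assumption; tune it to your hardware):

```python
infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=1000, strategy='split'),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,  # assumption: adjust to the number of available GPUs
        task=dict(type=OpenICLInferTask),
    ),
)
```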

View File

@ -10,3 +10,4 @@ from .mtbench import MTBenchDataset # noqa: F401, F403
from .mtbench101 import MTBench101Dataset # noqa: F401, F403
from .multiround import MultiroundDataset # noqa: F401, F403
from .subjective_cmp import SubjectiveCmpDataset # noqa: F401, F403
from .wildbench import WildBenchDataset # noqa: F401, F403

View File

@ -0,0 +1,249 @@
import json
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
score_prompt = """# Instruction
You are an expert evaluator. Your task is to evaluate the quality of \
the responses generated by AI models.
We will provide you with the user query and an AI-generated response.
You should first read the user query and the conversation history \
carefully for analyzing the task, and then evaluate the quality of \
the response based on the checklist and rules provided below.
# Conversation between User and AI
## History
<|begin_of_history|>
{history}
<|end_of_history|>
## Current User Query
<|begin_of_query|>
{user_query}
<|end_of_query|>
## AI Response
<|begin_of_response|>
{prediction}
<|end_of_response|>
# Evaluation
## Checklist
<|begin_of_checklist|>
{checklist}
<|end_of_checklist|>
Please use this checklist to guide your evaluation, but do \
not limit your assessment to the checklist.
## Rules
You should evaluate the above response based on your analysis\
of the user query and the conversation history.
You should first write down your analysis and the checklist \
that you used for the evaluation, and then provide your \
assessment according to the checklist.
The scores are in the range of 1~10, where 1 means the \
response is very poor and 10 means the response is perfect.
Here are more detailed criteria for the scores:
- Score 1~2: The response is very poor and does not make sense at all.
- Score 3~4: The response is poor and does not help the user solve the problem\
in a meaningful way.
- Score 5~6: The response is fair but has some issues (e.g., factual \
errors, hallucinations, missing key information).
- Score 7~8: The response is good enough but could be improved in some ways.
- Score 9~10: The response is perfect and provides helpful information that\
can help the user solve the problem.
## Output Format
First, please output your analysis for the model response, and then summarize\
your assessment to two aspects: "strengths" and "weaknesses"; Finally, please\
write down your rating for the assessment.
Please provide your evaluation results in the following json format by filling\
in the placeholders in []:
```
{
"strengths": "[analysis for the strengths of the response]",
"weaknesses": "[analysis for the weaknesses of the response]",
"score": "[1~10]"
}
```"""
pair_prompt = """# Instruction
You are an expert evaluator. Your task is to evaluate the quality of the \
responses generated by two AI models.
We will provide you with the user query and a pair of AI-generated \
responses (Response A and Response B).
You should first read the user query and the conversation history \
carefully for analyzing the task, and then evaluate the quality of the \
responses based on the checklist and rules provided below.
# Conversation between User and AI
## History
<|begin_of_history|>
{history}
<|end_of_history|>
## Current User Query
<|begin_of_query|>
{user_query}
<|end_of_query|>
## Response A
<|begin_of_response_A|>
{prediction}
<|end_of_response_A|>
## Response B
<|begin_of_response_B|>
{prediction2}
<|end_of_response_B|>
# Evaluation
## Checklist
<|begin_of_checklist|>
{checklist}
<|end_of_checklist|>
Please use this checklist to guide your evaluation, but do not limit your \
assessment to the checklist.
## Rules
You should compare the above two responses based on your analysis of the \
user queries and the conversation history.
You should first write down your analysis and the checklist that you used \
for the evaluation, and then provide your assessment according to the \
checklist.
There are five choices to give your final assessment: ["A++", "A+", \
"A=B", "B+", "B++"], which correspond to the following meanings:
- `A++`: Response A is much better than Response B.
- `A+`: Response A is only slightly better than Response B.
- `A=B`: Response A and B are of the same quality. Please use this \
choice sparingly.
- `B+`: Response B is only slightly better than Response A.
- `B++`: Response B is much better than Response A.
## Output Format
First, please output your analysis for each model response, and \
then summarize your assessment to three aspects: "reason A=B", \
"reason A>B", and "reason B>A", and finally make your choice for \
the final assessment.
Please provide your evaluation results in the following json \
format by filling in the placeholders in []:
```
{
"analysis of A": "[analysis of Response A]",
"analysis of B": "[analysis of Response B]",
"reason of A=B": "[where Response A and B perform equally well]",
"reason of A>B": "[where Response A is better than Response B]",
"reason of B>A": "[where Response B is better than Response A]",
"choice": "[A++ or A+ or A=B or B+ or B++]",
}
```
"""
def parse_conversation(conversation):
# parse conversation into chat dialogue
role_dict = {'user': 'HUMAN', 'assistant': 'assistant'}
chat_round = []
history = ''
if len(conversation) > 0:
for x in conversation[:-1]:
if x['role'] == 'user':
history += 'USER: ' + x['content'] + '\n\n'
elif x['role'] == 'assistant':
history += 'ASSISTANT: ' + x['content'] + '\n\n'
chat_round.append({
'role': role_dict[x['role']],
'content': x['content']
})
last_query = conversation[-1]['content']
chat_round.append({
'role': role_dict[conversation[-1]['role']],
'content': conversation[-1]['content']
})
chat_round.append({'role': 'assistant', 'content': ''})
return chat_round, last_query, history
@LOAD_DATASET.register_module()
class WildBenchDataset(BaseDataset):
def load(self, path: str, K=-1, mode='pair'):
dataset = DatasetDict()
raw_data = []
with open(path, 'r', encoding='utf-8') as file:
for line in file:
item = json.loads(line)
chat_round, last_query, history = parse_conversation(
item['turn'])
checklist_markdown = ''
for checklist_item in item['checklist']:
checklist_markdown += f'- {checklist_item}\n'
if mode == 'single':
prompt = score_prompt
elif mode == 'pair':
prompt = pair_prompt
else:
raise NotImplementedError(
f'Mode {mode} not in single or pair.')
prompt = prompt.replace('{history}', history)
prompt = prompt.replace('{user_query}', last_query)
prompt = prompt.replace('{checklist}', checklist_markdown)
raw_data.append({
'dialogue': chat_round,
'history': history,
'prompt': prompt,
'judge': {
'other': None,
'primary_tag': item['primary_tag'],
'secondary_tag': item['secondary_tag'],
'question_id': item['session_id'],
}
})
dataset = Dataset.from_list(raw_data)
return dataset
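For reference, a minimal sketch of a record in the preprocessed `wildbench.jsonl` and what `parse_conversation` extracts from it. The field names follow the loader above; the concrete values are invented for illustration:

```python
example = {
    'session_id': 'demo-0001',
    'primary_tag': 'Reasoning',
    'secondary_tag': ['Planning'],
    'checklist': ['Is the arithmetic correct?'],
    'turn': [
        {'role': 'user', 'content': 'What is 12 * 7?'},
        {'role': 'assistant', 'content': '12 * 7 = 84.'},
        {'role': 'user', 'content': 'And 12 * 7 + 5?'},
    ],
}

chat_round, last_query, history = parse_conversation(example['turn'])
# history    -> 'USER: What is 12 * 7?\n\nASSISTANT: 12 * 7 = 84.\n\n'
# last_query -> 'And 12 * 7 + 5?'
# chat_round ends with {'role': 'assistant', 'content': ''}, the slot the
# evaluated model fills during inference.
```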

View File

@ -13,3 +13,4 @@ from .information_retrival import IRSummarizer
from .mtbench import MTBenchSummarizer
from .mtbench101 import MTBench101Summarizer
from .multiround import MultiroundSummarizer
from .wildbench import WildBenchPairSummarizer, WildBenchSingleSummarizer

View File

@ -0,0 +1,295 @@
# flake8: noqa
# yapf: disable
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
from itertools import product
import numpy as np
from mmengine import ConfigDict
from tabulate import tabulate
from opencompass.partitioners.sub_naive import remove_duplicate_pairs
from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
from .compass_arena import (CompassArenaSummarizer, check_position_bias,
model_abbr_from_cfg_used_in_summarizer)
from .utils import get_judgeanswer_and_reference, get_outdir
task_group_new = {
'Information seeking': 'Information/Advice seeking',
'Creative Writing': 'Creative Tasks',
'Coding & Debugging': 'Coding & Debugging',
'Reasoning': 'Planning & Reasoning',
'Editing': 'Creative Tasks',
'Math': 'Math & Data Analysis',
'Planning': 'Planning & Reasoning',
'Brainstorming': 'Creative Tasks',
'Role playing': 'Creative Tasks',
'Advice seeking': 'Information/Advice seeking',
'Data Analysis': 'Math & Data Analysis',
'Others': 'Creative Tasks'}
def post_process_wildbench_pair(judgement: str):
pattern = r'\"choice\": \"(.*?)\"'
matched_result = re.findall(pattern, judgement)
if matched_result:
return matched_result[0]
else:
return None
def post_process_wildbench_single(judgement: str):
pattern = r'\"score\": \"(.*?)\"'
matched_result = re.findall(pattern, judgement)
try:
score = float(matched_result[0])
return {'score': score}
except (ValueError, IndexError) as e:
return None
# if matched_result:
# score = float(matched_result[0])
# else:
# return None
# return {'score': score}
def get_capability_results(
judged_answers,
references,
fout,
fout_flag,
model_abbr,
):
capability_ratings = defaultdict(float)
capability_counts = defaultdict(float)
for ans, ref in zip(judged_answers, references):
# rescale
capability_ratings['total'] += ans
capability_counts['total'] += 1
tags = [ref['primary_tag']] + ref['secondary_tag']
for tag in tags:
capability_ratings[task_group_new[tag]] += ans
capability_counts[task_group_new[tag]] += 1
capability_avg_ratings = defaultdict(float)
for capability, total_score in capability_ratings.items():
s = (total_score / capability_counts[capability] - 5) * 2 * 10
s = round(s, 2)
capability_avg_ratings[capability] = s
columns = list(capability_avg_ratings.keys())
columns.insert(0, columns.pop(columns.index('total')))
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
if fout_flag == 0:
writer.writerow(['model'] + columns)
writer.writerow([model_abbr] + [capability_avg_ratings[column] for column in columns])
class WildBenchSingleSummarizer(CompassArenaSummarizer):
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict) -> None:
self.judge_type = 'single'
self.tasks = []
self.cfg = config
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
self.judge_function = post_process_wildbench_single
def summarize(self, time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
# self.judge_type == 'single'
dataset_cfgs = self.cfg['datasets']
output_dir, results_folder = get_outdir(self.cfg, time_str)
fout_flag = 0
for eval_model_cfg in self.eval_model_cfgs:
eval_model_abbr = model_abbr_from_cfg(eval_model_cfg)
show_model_abbr = model_abbr_from_cfg_used_in_summarizer(eval_model_cfg)
subdir_path = os.path.join(results_folder, eval_model_abbr + '_judged-by--' + self.judge_abbr)
if os.path.isdir(subdir_path):
fout = osp.join(output_dir, 'judged-by--' + self.judge_abbr + '-capability.csv')
overall_judged_answers, overall_references = [], []
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(dataset, subdir_path, self.judge_function)
judged_answers = [item['score'] for item in judged_answers]
overall_judged_answers += judged_answers
overall_references += references
get_capability_results(overall_judged_answers, overall_references, fout, fout_flag, show_model_abbr)
fout_flag += 1
else:
print(subdir_path + ' does not exist! Please check!')
class WildBenchPairSummarizer(CompassArenaSummarizer):
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, check_pos_bias=False) -> None:
self.tasks = []
self.cfg = config
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
self.judge_models = self.cfg.get('judge_models', None)
self.meta_judge_model = self.cfg.eval.partitioner.get('meta_judge_model', None)
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
self.judge_function = post_process_wildbench_pair
self.check_pos_bias = check_pos_bias
def get_score(self, time_str):
output_dir, results_folder = get_outdir(self.cfg, time_str)
model_combinations = list(product(self.base_models, self.compare_models))
unique_combinations = remove_duplicate_pairs([combo for combo in model_combinations if combo[0] != combo[1]])
if self.meta_judge_model is not None:
self.judge_models.append(self.meta_judge_model)
scores = {}
for idx, judge_model_cfg in enumerate(self.judge_models):
judge_model = model_abbr_from_cfg(judge_model_cfg)
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
for model_pair in unique_combinations:
base_model = model_pair[0]['abbr']
compare_model = model_pair[1]['abbr']
if idx == len(self.judge_models):
subdir = base_model + '_' + compare_model + '_summarized-by--' + judge_model
else:
subdir = base_model + '_' + compare_model + '_judged-by--' + judge_model
subdir_path = os.path.join(results_folder, subdir)
if not os.path.isdir(subdir_path):
print(subdir_path + ' does not exist! Please check!')
continue
judged_answers, references = get_judgeanswer_and_reference(dataset, subdir_path, self.judge_function)
if self.check_pos_bias:
bias_num = check_position_bias(judged_answers, references)
else:
bias_num = 0
win_base_model = defaultdict(float)
win_compare_model = defaultdict(float)
categories = defaultdict(float)
# base_model = references[0]['answer1']
# compare_model = references[0]['answer2']
score_mapping = {'A++': 1, 'A+': 0.5, 'A=B': 0, 'B+': -0.5, 'B++': -1}
for prediction, reference in zip(judged_answers, references):
if prediction not in score_mapping:
continue
categories[dataset_abbr] += 1
flag = 1 if reference['answer1'] == base_model else -1
score_1 = score_mapping[prediction]*flag
score_2 = -score_1
tags = [reference['primary_tag']] + reference['secondary_tag']
for tag in tags:
win_base_model[task_group_new[tag]] += score_1
win_compare_model[task_group_new[tag]] += score_2
categories[task_group_new[tag]] += 1
win_compare_model[dataset_abbr] += score_2
win_base_model[dataset_abbr] += score_1
for capability in categories:
win_base_model[capability] = win_base_model[capability] / categories[capability] * 100
win_base_model[capability] = round(win_base_model[capability], 2)
win_compare_model[capability] = win_compare_model[capability] / categories[capability] * 100
win_compare_model[capability] = round(win_compare_model[capability], 2)
win_base_model['position_bias'] = bias_num
win_compare_model['position_bias'] = bias_num
if judge_model not in scores:
scores[judge_model] = {}
if dataset_abbr not in scores[judge_model]:
scores[judge_model][dataset_abbr] = {}
scores[judge_model][dataset_abbr][base_model + '/' + compare_model] = win_compare_model
return scores
def summarize(
self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S'),
):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
scores = self.get_score(time_str)
output_dir, results_folder = get_outdir(self.cfg, time_str)
for idx, judge_model in enumerate(self.judge_models):
judge_abbr = model_abbr_from_cfg(judge_model)
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
summarizer_model_abbrs = [model_abbr_from_cfg_used_in_summarizer(i) for i in self.compare_models]
one_column = list(scores[judge_abbr][dataset_abbr].values())[0]
row_headers = [i for i in one_column.keys() if i not in [dataset_abbr, 'position_bias']]
row_headers = [dataset_abbr, 'position_bias'] + row_headers
table = []
for row_header in row_headers:
row = [row_header]
headers = ['']
for model_cfg in self.compare_models:
model_abbr = model_abbr_from_cfg(model_cfg)
avg = 0
for base_model_cfg in self.base_models:
base_model_abbr = model_abbr_from_cfg(base_model_cfg)
base_compare = base_model_abbr + '/' + model_abbr
headers.append(base_compare)
s = scores[judge_abbr][dataset_abbr][base_compare].get(row_header, '')
if isinstance(s, float):
avg += s
s = f'{s:.2f}'
if isinstance(s, int):
s = str(s)
row.append(s)
avg = avg/len(self.base_models)
row.append(f'{avg:.2f}')
headers.append('Avg')
table.append(row)
txt = tabulate(table, headers=headers)
print(txt)
if idx == len(self.judge_models):
output_filename = osp.join(output_dir, 'summarized-by--' + judge_abbr + '-' + dataset_abbr + '-report.csv')
else:
output_filename = osp.join(output_dir, 'judged-by--' + judge_abbr + '-' + dataset_abbr + '-report.csv')
with open(output_filename, 'w') as f:
f.write(','.join(headers) + '\n')
for line in table:
f.write(','.join(line) + '\n')
print(output_filename)
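As a quick illustration of the two post-processing helpers above and of the rescaling in `get_capability_results`, a minimal sketch; the judgement strings are invented and only mimic the JSON format the judge prompts request:

```python
# Hypothetical judge outputs in the JSON shape requested by the WildBench prompts.
single_judgement = '{"strengths": "...", "weaknesses": "...", "score": "7"}'
pair_judgement = '{"analysis of A": "...", "analysis of B": "...", "choice": "A+"}'

print(post_process_wildbench_single(single_judgement))  # {'score': 7.0}
print(post_process_wildbench_pair(pair_judgement))      # 'A+'

# get_capability_results rescales an average 1~10 rating onto roughly -80~100:
# (avg - 5) * 2 * 10, e.g. an average score of 7.0 -> (7.0 - 5) * 20 = 40.0
```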

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import hashlib
import json
import re
from copy import deepcopy
from typing import Dict, List, Union
@ -20,15 +19,20 @@ def safe_format(input_str: str, **kwargs) -> str:
Returns:
str: The formatted string.
"""
# import re
# segs = [input_str]
# for k, v in kwargs.items():
# regex = re.compile(f'(?<={{{k}}})(?={{{k}}})|({{{k}}})')
# segs = [regex.split(seg) for seg in segs]
# segs = sum(segs, [])
# replace_dict = {f'{{{k}}}': str(v) for k, v in kwargs.items()}
# segs = [replace_dict.get(seg, seg) for seg in segs]
# output_str = ''.join(segs)
# return output_str
for k, v in kwargs.items():
input_str = input_str.replace(f'{{{k}}}', str(v))
return input_str
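# Behavior sketch (illustrative, not part of the original file): placeholders with
# matching kwargs are substituted, while unrelated braces -- such as the literal JSON
# braces inside the WildBench judge prompts -- are left untouched.
# >>> safe_format('{user_query} -> {"score": "[1~10]"}', user_query='hi')
# 'hi -> {"score": "[1~10]"}'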
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str: