diff --git a/configs/datasets/subjective/wildbench/wildbench.md b/configs/datasets/subjective/wildbench/wildbench.md
new file mode 100644
index 00000000..e4567ba1
--- /dev/null
+++ b/configs/datasets/subjective/wildbench/wildbench.md
@@ -0,0 +1,30 @@
+# WildBench
+
+## Prepare the dataset
+
+We support the [WildBench dataset](https://github.com/allenai/WildBench) developed by Lin et al. Please refer to their repository for more details.
+
+You need to download our preprocessed dataset first. The directory should be organized as follows:
+
+```
+wildbench
+---wildbench.jsonl
+---gpt4
+------wildbench.json
+---claude
+------wildbench.json
+---llama2-70b
+------wildbench.json
+```
+
+`wildbench.jsonl` is the preprocessed dataset, and the other three directories hold the reference predictions used for scoring.
+
+Once you have downloaded the dataset, modify the data path defined in `configs/datasets/subjective/wildbench/wildbench_pair_judge.py` and `configs/datasets/subjective/wildbench/wildbench_single_judge.py`.
+
+## Run
+
+We provide evaluation scripts for WildBench in `configs/eval_subjective_wildbench_pair.py` and `configs/eval_subjective_wildbench_single.py`.
+
+Please modify the paths in `given_pred` (line 171) of `configs/eval_subjective_wildbench_pair.py` to point to your reference directories.
+
+Note that if you evaluate WildBench with other models, please set `max_out_len` to 4096.
diff --git a/configs/datasets/subjective/wildbench/wildbench_pair_judge.py b/configs/datasets/subjective/wildbench/wildbench_pair_judge.py
new file mode 100644
index 00000000..e0a34c70
--- /dev/null
+++ b/configs/datasets/subjective/wildbench/wildbench_pair_judge.py
@@ -0,0 +1,46 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
+from opencompass.openicl.icl_evaluator import LMEvaluator
+from opencompass.datasets import WildBenchDataset
+
+
+subjective_reader_cfg = dict(
+    input_columns=['dialogue', 'prompt'],
+    output_column='judge',
+    )
+
+
+data_path = './data/WildBench/wildbench.jsonl'
+
+subjective_datasets = []
+subjective_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template="""{dialogue}"""
+    ),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512, infer_mode='last'),
+    )
+
+subjective_eval_cfg = dict(
+    evaluator=dict(
+        type=LMEvaluator,
+        prompt_template=dict(
+            type=PromptTemplate,
+            template="""{prompt}"""
+        ),
+    ),
+    pred_role='BOT',
+)
+
+subjective_datasets.append(
+    dict(
+        abbr='wildbench',
+        type=WildBenchDataset,
+        path=data_path,
+        mode='pair',
+        reader_cfg=subjective_reader_cfg,
+        infer_cfg=subjective_infer_cfg,
+        eval_cfg=subjective_eval_cfg
+    ))
diff --git a/configs/datasets/subjective/wildbench/wildbench_single_judge.py b/configs/datasets/subjective/wildbench/wildbench_single_judge.py
new file mode 100644
index 00000000..be11abcb
--- /dev/null
+++ b/configs/datasets/subjective/wildbench/wildbench_single_judge.py
@@ -0,0 +1,47 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
+from opencompass.openicl.icl_evaluator import LMEvaluator
+from opencompass.datasets import WildBenchDataset
+
+
+subjective_reader_cfg = dict(
+    input_columns=['dialogue', 'prompt'],
+    output_column='judge',
+    )
+
+data_path = './data/WildBench/wildbench.jsonl'
+
+subjective_datasets = []
+
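+# Each line of wildbench.jsonl is expected to provide at least the fields read
+# by WildBenchDataset.load (opencompass/datasets/subjective/wildbench.py). The
+# record below is an illustrative sketch with hypothetical values, not real data:
+#
+# {
+#     "session_id": "wild_000001",
+#     "turn": [
+#         {"role": "user", "content": "How do I reverse a list in Python?"},
+#         {"role": "assistant", "content": "Use slicing: my_list[::-1]."},
+#         {"role": "user", "content": "Can I do it in place instead?"}
+#     ],
+#     "checklist": ["Mentions list.reverse()", "Distinguishes in-place vs. copy"],
+#     "primary_tag": "Coding & Debugging",
+#     "secondary_tag": ["Information seeking"]
+# }
+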
+# the question is a list, how to process it +subjective_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template="""{dialogue}""" + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512, infer_mode='last'), + ) + +subjective_eval_cfg = dict( + evaluator=dict( + type=LMEvaluator, + prompt_template=dict( + type=PromptTemplate, + template="""{prompt}""" + ), + ), + pred_role='BOT', +) + +subjective_datasets.append( + dict( + abbr='wildbench', + type=WildBenchDataset, + path=data_path, + mode='single', + reader_cfg=subjective_reader_cfg, + infer_cfg=subjective_infer_cfg, + eval_cfg=subjective_eval_cfg + )) diff --git a/configs/eval_subjective_wildbench_pair.py b/configs/eval_subjective_wildbench_pair.py new file mode 100644 index 00000000..652793cf --- /dev/null +++ b/configs/eval_subjective_wildbench_pair.py @@ -0,0 +1,180 @@ +from mmengine.config import read_base + +with read_base(): + # from .datasets.subjective.multiround.mtbench_single_judge_diff_temp import subjective_datasets + from .datasets.subjective.wildbench.wildbench_pair_judge import subjective_datasets + from .models.openai.gpt_4 import models as gpt4_models + from .models.hf_llama.hf_llama2_70b_chat import models as llama2_models + # from .models.gemma.hf_gemma_2b_it import models + # from .models.hf_llama.hf_llama3_70b_instruct import models as llama3_model + # # from .models.hf_internlm.hf_internlm2_chat_7b import models + # from .models.yi.hf_yi_1_5_34b_chat import models as yi_model + # from .models.qwen.hf_qwen1_5_72b_chat import models as qwen_model + +from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI +from opencompass.partitioners import NaivePartitioner, SizePartitioner +from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner +from opencompass.partitioners.sub_size import SubjectiveSizePartitioner +from opencompass.runners import LocalRunner +from opencompass.runners import SlurmSequentialRunner +from opencompass.tasks import OpenICLInferTask +from opencompass.tasks.subjective_eval import SubjectiveEvalTask +from opencompass.summarizers import WildBenchPairSummarizer +from opencompass.models.claude_api.claude_api import Claude +from opencompass.models import HuggingFacewithChatTemplate + + +models = sum([v for k, v in locals().items() if k.endswith('_model')], []) + +api_meta_template = dict( + round=[ + dict(role='SYSTEM', api_role='SYSTEM'), + dict(role='HUMAN', api_role='HUMAN'), + dict(role='BOT', api_role='BOT', generate=True), + ] +) + +# _meta_template = dict( +# round=[ +# dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'), +# dict(role='BOT', begin='\n<|im_start|>assistant\n', end='<|im_end|>', generate=True), +# ], +# ) +# -------------Inference Stage ---------------------------------------- +# For subjective evaluation, we often set do sample for models + +models = [ + dict( + type=HuggingFacewithChatTemplate, + abbr='llama-3-8b-instruct-hf', + path='meta-llama/Meta-Llama-3-8B-Instruct', + max_out_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1), + stop_words=['<|end_of_text|>', '<|eot_id|>'], + ), + dict( + type=HuggingFacewithChatTemplate, + abbr='yi-1.5-6b-chat-hf', + path='01-ai/Yi-1.5-6B-Chat', + max_out_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1), + ), + dict( + type=HuggingFacewithChatTemplate, + abbr='qwen1.5-7b-chat-hf', + path='Qwen/Qwen1.5-7B-Chat', + max_out_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1), + ), + # dict( + # 
type=HuggingFacewithChatTemplate,
+    #     abbr='llama-3-70b-instruct-hf',
+    #     path='meta-llama/Meta-Llama-3-70B-Instruct',
+    #     max_out_len=4096,
+    #     batch_size=8,
+    #     run_cfg=dict(num_gpus=4),
+    #     stop_words=['<|end_of_text|>', '<|eot_id|>'],
+    # ),
+    # dict(
+    #     type=HuggingFacewithChatTemplate,
+    #     abbr='yi-1.5-34b-chat-hf',
+    #     path='01-ai/Yi-1.5-34B-Chat',
+    #     max_out_len=4096,
+    #     batch_size=8,
+    #     run_cfg=dict(num_gpus=2),
+    # ),
+    # dict(
+    #     type=HuggingFacewithChatTemplate,
+    #     abbr='qwen1.5-72b-chat-hf',
+    #     path='Qwen/Qwen1.5-72B-Chat',
+    #     max_out_len=4096,
+    #     batch_size=8,
+    #     run_cfg=dict(num_gpus=8),
+    # )
+]
+
+datasets = [*subjective_datasets]
+
+# -------------Evaluation Stage ----------------------------------------
+
+## ------------- JudgeLLM Configuration
+judge_models = [dict(
+    abbr='GPT4-Turbo',
+    type=OpenAI,
+    path='gpt-4-0613',  # To compare with the official leaderboard, please use gpt-4-0613
+    key='xxxx',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
+    meta_template=api_meta_template,
+    query_per_second=16,
+    max_out_len=2048,
+    max_seq_len=2048,
+    batch_size=8,
+    temperature=0,
+)]
+
+gpt4 = dict(
+    abbr='gpt4-turbo',
+    type=OpenAI,
+    path='gpt-4-0409-preview',
+    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
+    meta_template=api_meta_template,
+    query_per_second=1,
+    max_out_len=2048,
+    max_seq_len=4096,
+    batch_size=4,
+    retry=20,
+    temperature=1,
+)  # Re-run inference for GPT-4's predictions, or use the pre-committed GPT-4 predictions instead
+
+claude = dict(abbr='HaiKu',
+              type=Claude,
+              path='claude-2',
+              key='YOUR_CLAUDE_KEY',
+              query_per_second=1,
+              max_out_len=2048, max_seq_len=2048, batch_size=2,
+              )
+## single evaluation
+# eval = dict(
+#     partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models, judge_models=judge_models),
+#     runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask)),
+# )
+infer = dict(
+    partitioner=dict(type=SizePartitioner, max_task_size=1000, strategy='split'),
+    runner=dict(
+        type=SlurmSequentialRunner,
+        max_num_workers=64,
+        quotatype='reserved',
+        partition='llmeval',
+        task=dict(type=OpenICLInferTask)),
+)
+
+eval = dict(
+    partitioner=dict(
+        type=SubjectiveNaivePartitioner,
+        mode='m2n',  # pit the m base models against the n compare models
+        infer_order='random',
+        # In m2n mode, base_models and compare_models must be specified; every base/compare pair is generated (deduplicated, and a model is never compared with itself).
+        base_models = [*llama2_models, gpt4, claude],  # baseline models used for comparison
+        compare_models = models,  # models to be evaluated
+        judge_models=judge_models
+    ),
+    runner=dict(
+        type=LocalRunner,
+        # partition='llmeval',
+        # quotatype='auto',
+        max_num_workers=3,
+        task=dict(
+            type=SubjectiveEvalTask
+        )),
+    given_pred = [{'abbr':'gpt4-turbo', 'path':'./data/WildBench/gpt4'},
+                  {'abbr': 'llama-2-70b-chat-hf', 'path':'./data/WildBench/llama2-70b'},
+                  {'abbr': 'HaiKu', 'path':'./data/WildBench/claude'},
+                  {'abbr': 'llama-2-70b-chat-turbomind', 'path':'./data/WildBench/llama2-70b'},
+                  {'abbr': 'llama-2-70b-chat-vllm', 'path':'./data/WildBench/llama2-70b'}]
+)
+
+summarizer = dict(type=WildBenchPairSummarizer)
+
+work_dir = 'outputs/wildbench/'
diff --git a/configs/eval_subjective_wildbench_single.py b/configs/eval_subjective_wildbench_single.py
new file mode 100644
index 00000000..5e053488
--- /dev/null
+++ b/configs/eval_subjective_wildbench_single.py
@@ -0,0 +1,135 @@
+from mmengine.config import read_base
+
+with read_base():
+    # from 
.datasets.subjective.multiround.mtbench_single_judge_diff_temp import subjective_datasets + from .datasets.subjective.wildbench.wildbench_single_judge import subjective_datasets + # from .models.gemma.hf_gemma_2b_it import models as gemma_2b_models + # from .models.hf_llama.hf_llama3_70b_instruct import models as llama3_model + # # from .models.hf_internlm.hf_internlm2_chat_7b import models + # from .models.yi.hf_yi_1_5_34b_chat import models as yi_model + # from .models.qwen.hf_qwen1_5_72b_chat import models as qwen_model + +from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI +from opencompass.partitioners import NaivePartitioner, SizePartitioner +from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner +from opencompass.partitioners.sub_size import SubjectiveSizePartitioner +from opencompass.runners import LocalRunner +from opencompass.runners import SlurmSequentialRunner +from opencompass.tasks import OpenICLInferTask +from opencompass.tasks.subjective_eval import SubjectiveEvalTask +from opencompass.summarizers import WildBenchSingleSummarizer +from opencompass.models import HuggingFacewithChatTemplate + + +# models = sum([v for k, v in locals().items() if k.endswith("_model")], []) + +api_meta_template = dict( + round=[ + dict(role='SYSTEM', api_role='SYSTEM'), + dict(role='HUMAN', api_role='HUMAN'), + dict(role='BOT', api_role='BOT', generate=True), + ] +) + +# _meta_template = dict( +# round=[ +# dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'), +# dict(role='BOT', begin='\n<|im_start|>assistant\n', end='<|im_end|>', generate=True), +# ], +# ) +# -------------Inference Stage ---------------------------------------- +# For subjective evaluation, we often set do sample for models +# set max_out_len to 4096. 
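+# The models listed below are the candidates whose responses will be scored by
+# the judge. To add your own HuggingFace chat model, append another entry that
+# follows the same pattern; the abbr/path values in this sketch are hypothetical:
+#
+# dict(
+#     type=HuggingFacewithChatTemplate,
+#     abbr='my-chat-model-hf',        # name used for output dirs and summary tables
+#     path='my-org/My-Chat-Model',    # HF repo id or local checkpoint path
+#     max_out_len=4096,               # keep 4096 for WildBench (see the dataset README)
+#     batch_size=8,
+#     run_cfg=dict(num_gpus=1),
+# ),
+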
+models = [ + dict( + type=HuggingFacewithChatTemplate, + abbr='llama-3-8b-instruct-hf', + path='meta-llama/Meta-Llama-3-8B-Instruct', + max_out_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1), + stop_words=['<|end_of_text|>', '<|eot_id|>'], + ), + dict( + type=HuggingFacewithChatTemplate, + abbr='yi-1.5-6b-chat-hf', + path='01-ai/Yi-1.5-6B-Chat', + max_out_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1), + ), + dict( + type=HuggingFacewithChatTemplate, + abbr='qwen1.5-7b-chat-hf', + path='Qwen/Qwen1.5-7B-Chat', + max_out_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1), + ), + # dict( + # type=HuggingFacewithChatTemplate, + # abbr='llama-3-70b-instruct-hf', + # path='meta-llama/Meta-Llama-3-70B-Instruct', + # max_out_len=4096, + # batch_size=8, + # run_cfg=dict(num_gpus=4), + # stop_words=['<|end_of_text|>', '<|eot_id|>'], + # ), + # dict( + # type=HuggingFacewithChatTemplate, + # abbr='yi-1.5-34b-chat-hf', + # path='01-ai/Yi-1.5-34B-Chat', + # max_out_len=4096, + # batch_size=8, + # run_cfg=dict(num_gpus=2), + # ), + # dict( + # type=HuggingFacewithChatTemplate, + # abbr='qwen1.5-72b-chat-hf', + # path='Qwen/Qwen1.5-72B-Chat', + # max_out_len=4096, + # batch_size=8, + # run_cfg=dict(num_gpus=4), + # ) +] + +datasets = [*subjective_datasets] + +# -------------Evalation Stage ---------------------------------------- + +## ------------- JudgeLLM Configuration +judge_models = [dict( + abbr='GPT4-Turbo', + type=OpenAI, + path='gpt-4-0613', # To compare with the official leaderboard, please use gpt4-0613 + key='xxxx', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well + meta_template=api_meta_template, + query_per_second=16, + max_out_len=2048, + max_seq_len=2048, + batch_size=8, + temperature=0, +)] + + +infer = dict( + partitioner=dict(type=SizePartitioner, max_task_size=1000, strategy='split'), + runner=dict( + type=SlurmSequentialRunner, + max_num_workers=64, + quotatype='reserved', + partition='llmeval', + task=dict(type=OpenICLInferTask)), +) + +## single evaluation +eval = dict( + partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models, judge_models=judge_models), + runner=dict(type=LocalRunner, + max_num_workers=2, + task=dict(type=SubjectiveEvalTask)), +) + +summarizer = dict(type=WildBenchSingleSummarizer) + +work_dir = 'outputs/wildbench/' diff --git a/opencompass/datasets/subjective/__init__.py b/opencompass/datasets/subjective/__init__.py index 5219700c..03acd545 100644 --- a/opencompass/datasets/subjective/__init__.py +++ b/opencompass/datasets/subjective/__init__.py @@ -10,3 +10,4 @@ from .mtbench import MTBenchDataset # noqa: F401, F403 from .mtbench101 import MTBench101Dataset # noqa: F401, F403 from .multiround import MultiroundDataset # noqa: F401, F403 from .subjective_cmp import SubjectiveCmpDataset # noqa: F401, F403 +from .wildbench import WildBenchDataset # noqa: F401, F403 diff --git a/opencompass/datasets/subjective/wildbench.py b/opencompass/datasets/subjective/wildbench.py new file mode 100644 index 00000000..8f0995f5 --- /dev/null +++ b/opencompass/datasets/subjective/wildbench.py @@ -0,0 +1,249 @@ +import json + +from datasets import Dataset, DatasetDict + +from opencompass.registry import LOAD_DATASET + +from ..base import BaseDataset + +score_prompt = """# Instruction + +You are an expert evaluator. Your task is to evaluate the quality of \ +the responses generated by AI models. 
+We will provide you with the user query and an AI-generated responses. +You should first read the user query and the conversation history \ +carefully for analyzing the task, and then evaluate the quality of \ +the responses based on and rules provided below. + +# Conversation between User and AI + +## History +<|begin_of_history|> + +{history} + +<|end_of_history|> + +## Current User Query +<|begin_of_query|> + +{user_query} + +<|end_of_query|> + +## AI Response +<|begin_of_response|> + +{prediction} + +<|end_of_response|> + + +# Evaluation + +## Checklist + +<|begin_of_checklist|> + +{checklist} + +<|end_of_checklist|> + +Please use this checklist to guide your evaluation, but do \ +not limit your assessment to the checklist. + +## Rules + +You should compare the above response based on your analysis\ + of the user queries and the conversation history. +You should first write down your analysis and the checklist \ +that you used for the evaluation, and then provide your \ +assessment according to the checklist. +The scores are in the range of 1~10, where 1 means the \ +response is very poor and 10 means the response is perfect. +Here are more detailed criteria for the scores: + +- Score 1~2: The response is very poor and does not make sense at all. +- Score 3~4: The response is poor and does help user solve the problem\ + in a meaningful way. +- Score 5~6: The response is fair but has some issues (e.g., factual \ +errors, hallucinations, missing key information). +- Score 7~8: The response is good enough but could be improved in some ways. +- Score 9~10: The response is perfect and provides helpful information that\ + can help user solve the problem. + +## Output Format +First, please output your analysis for the model response, and then summarize\ + your assessment to two aspects: "strengths" and "weaknesses"; Finally, please\ + write down your rating for the assessment. + +Please provide your evaluation results in the following json format by filling\ + in the placeholders in []: +``` +{ + "strengths": "[analysis for the strengths of the response]", + "weaknesses": "[analysis for the weaknesses of the response]", + "score": "[1~10]" +} +```""" + +pair_prompt = """# Instruction + +You are an expert evaluator. Your task is to evaluate the quality of the \ +responses generated by two AI models. +We will provide you with the user query and a pair of AI-generated \ +responses (Response A and Response B). +You should first read the user query and the conversation history \ +carefully for analyzing the task, and then evaluate the quality of the \ +responses based on and rules provided below. + +# Conversation between User and AI + +## History +<|begin_of_history|> + +{history} + +<|end_of_history|> + +## Current User Query +<|begin_of_query|> + +{user_query} + +<|end_of_query|> + +## Response A +<|begin_of_response_A|> + +{prediction} + +<|end_of_response_A|> + +## Response B +<|begin_of_response_B|> + +{prediction2} + +<|end_of_response_B|> + +# Evaluation + +## Checklist + +<|begin_of_checklist|> + +{checklist} + +<|end_of_checklist|> + +Please use this checklist to guide your evaluation, but do not limit your \ +assessment to the checklist. + +## Rules + +You should compare the above two responses based on your analysis of the \ +user queries and the conversation history. +You should first write down your analysis and the checklist that you used \ +for the evaluation, and then provide your assessment according to the \ +checklist. 
+There are five choices to give your final assessment: ["A++", "A+", \
+"A=B", "B+", "B++"], which correspond to the following meanings:
+
+- `A++`: Response A is much better than Response B.
+- `A+`: Response A is only slightly better than Response B.
+- `A=B`: Response A and B are of the same quality. Please use this \
+choice sparingly.
+- `B+`: Response B is only slightly better than Response A.
+- `B++`: Response B is much better than Response A.
+
+
+## Output Format
+First, please output your analysis for each model response, and \
+then summarize your assessment to three aspects: "reason A=B", \
+"reason A>B", and "reason B>A", and finally make your choice for \
+the final assessment.
+
+Please provide your evaluation results in the following json \
+format by filling in the placeholders in []:
+```
+{
+    "analysis of A": "[analysis of Response A]",
+    "analysis of B": "[analysis of Response B]",
+    "reason of A=B": "[where Response A and B perform equally well]",
+    "reason of A>B": "[where Response A is better than Response B]",
+    "reason of B>A": "[where Response B is better than Response A]",
+    "choice": "[A++ or A+ or A=B or B+ or B++]",
+}
+```
+"""
+
+
+def parse_conversation(conversation):
+    # Split the conversation into chat rounds, the latest user query and a flattened history string.
+    role_dict = {'user': 'HUMAN', 'assistant': 'assistant'}
+    chat_round = []
+    history = ''
+    if len(conversation) > 0:
+        for x in conversation[:-1]:
+            if x['role'] == 'user':
+                history += 'USER: ' + x['content'] + '\n\n'
+            elif x['role'] == 'assistant':
+                history += 'ASSISTANT: ' + x['content'] + '\n\n'
+
+            chat_round.append({
+                'role': role_dict[x['role']],
+                'content': x['content']
+            })
+
+    last_query = conversation[-1]['content']
+    chat_round.append({
+        'role': role_dict[conversation[-1]['role']],
+        'content': conversation[-1]['content']
+    })
+    chat_round.append({'role': 'assistant', 'content': ''})
+
+    return chat_round, last_query, history
+
+
+@LOAD_DATASET.register_module()
+class WildBenchDataset(BaseDataset):
+
+    def load(self, path: str, K=-1, mode='pair'):
+        dataset = DatasetDict()
+        raw_data = []
+        with open(path, 'r', encoding='utf-8') as file:
+            for line in file:
+                item = json.loads(line)
+                chat_round, last_query, history = parse_conversation(
+                    item['turn'])
+
+                checklist_markdown = ''
+                for checklist_item in item['checklist']:
+                    checklist_markdown += f'- {checklist_item}\n'
+
+                if mode == 'single':
+                    prompt = score_prompt
+                elif mode == 'pair':
+                    prompt = pair_prompt
+                else:
+                    raise NotImplementedError(
+                        f'Mode {mode} not in single or pair.')
+
+                prompt = prompt.replace('{history}', history)
+                prompt = prompt.replace('{user_query}', last_query)
+                prompt = prompt.replace('{checklist}', checklist_markdown)
+
+                raw_data.append({
+                    'dialogue': chat_round,
+                    'history': history,
+                    'prompt': prompt,
+                    'judge': {
+                        'other': None,
+                        'primary_tag': item['primary_tag'],
+                        'secondary_tag': item['secondary_tag'],
+                        'question_id': item['session_id'],
+                    }
+                })
+        dataset = Dataset.from_list(raw_data)
+        return dataset
diff --git a/opencompass/summarizers/subjective/__init__.py b/opencompass/summarizers/subjective/__init__.py
index 7457f14f..157d1713 100644
--- a/opencompass/summarizers/subjective/__init__.py
+++ b/opencompass/summarizers/subjective/__init__.py
@@ -13,3 +13,4 @@ from .information_retrival import IRSummarizer
 from .mtbench import MTBenchSummarizer
 from .mtbench101 import MTBench101Summarizer
 from .multiround import MultiroundSummarizer
+from .wildbench import WildBenchPairSummarizer, WildBenchSingleSummarizer
diff --git 
a/opencompass/summarizers/subjective/wildbench.py b/opencompass/summarizers/subjective/wildbench.py new file mode 100644 index 00000000..875b2c3f --- /dev/null +++ b/opencompass/summarizers/subjective/wildbench.py @@ -0,0 +1,295 @@ +# flake8: noqa +# yapf: disable +import csv +import os +import os.path as osp +import re +from collections import defaultdict +from datetime import datetime +from itertools import product + +import numpy as np +from mmengine import ConfigDict +from tabulate import tabulate + +from opencompass.partitioners.sub_naive import remove_duplicate_pairs +from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg + +from .compass_arena import (CompassArenaSummarizer, check_position_bias, + model_abbr_from_cfg_used_in_summarizer) +from .utils import get_judgeanswer_and_reference, get_outdir + +task_group_new = { + 'Information seeking': 'Information/Advice seeking', + 'Creative Writing': 'Creative Tasks', + 'Coding & Debugging': 'Coding & Debugging', + 'Reasoning': 'Planning & Reasoning', + 'Editing': 'Creative Tasks', + 'Math': 'Math & Data Analysis', + 'Planning': 'Planning & Reasoning', + 'Brainstorming': 'Creative Tasks', + 'Role playing': 'Creative Tasks', + 'Advice seeking': 'Information/Advice seeking', + 'Data Analysis': 'Math & Data Analysis', + 'Others': 'Creative Tasks'} + + +def post_process_wildbench_pair(judgement: str): + pattern = r'\"choice\": \"(.*?)\"' + matched_result = re.findall(pattern, judgement) + if matched_result: + return matched_result[0] + else: + return None + + +def post_process_wildbench_single(judgement: str): + pattern = r'\"score\": \"(.*?)\"' + matched_result = re.findall(pattern, judgement) + try: + score = float(matched_result[0]) + return {'score': score} + except (ValueError, IndexError) as e: + return None + + # if matched_result: + # score = float(matched_result[0]) + # else: + # return None + # return {'score': score} + + +def get_capability_results( + judged_answers, + references, + fout, + fout_flag, + model_abbr, +): + capability_ratings = defaultdict(float) + capability_counts = defaultdict(float) + + for ans, ref in zip(judged_answers, references): + # rescale + capability_ratings['total'] += ans + capability_counts['total'] += 1 + tags = [ref['primary_tag']] + ref['secondary_tag'] + for tag in tags: + capability_ratings[task_group_new[tag]] += ans + capability_counts[task_group_new[tag]] += 1 + + capability_avg_ratings = defaultdict(float) + + for capability, total_score in capability_ratings.items(): + s = (total_score / capability_counts[capability] - 5) * 2 * 10 + s = round(s, 2) + capability_avg_ratings[capability] = s + columns = list(capability_avg_ratings.keys()) + columns.insert(0, columns.pop(columns.index('total'))) + + with open(fout, 'a+', newline='') as csvfile: + writer = csv.writer(csvfile) + if fout_flag == 0: + writer.writerow(['model'] + columns) + writer.writerow([model_abbr] + [capability_avg_ratings[column] for column in columns]) + + +class WildBenchSingleSummarizer(CompassArenaSummarizer): + """Do the subjectivity analyze based on evaluation results. + + Args: + config (ConfigDict): The configuration object of the evaluation task. + It's expected to be filled out at runtime. 
+ """ + + def __init__(self, config: ConfigDict) -> None: + self.judge_type = 'single' + self.tasks = [] + self.cfg = config + + self.eval_model_cfgs = self.cfg['eval']['partitioner']['models'] + self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0]) + self.judge_function = post_process_wildbench_single + + def summarize(self, time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')): + """Summarize the subjectivity analysis based on evaluation results. + + Args: + time_str (str): Timestamp for file naming. + + Returns: + pd.DataFrame: The summary results. + """ + + # self.judge_type == 'single' + dataset_cfgs = self.cfg['datasets'] + output_dir, results_folder = get_outdir(self.cfg, time_str) + fout_flag = 0 + for eval_model_cfg in self.eval_model_cfgs: + eval_model_abbr = model_abbr_from_cfg(eval_model_cfg) + show_model_abbr = model_abbr_from_cfg_used_in_summarizer(eval_model_cfg) + subdir_path = os.path.join(results_folder, eval_model_abbr + '_judged-by--' + self.judge_abbr) + if os.path.isdir(subdir_path): + fout = osp.join(output_dir, 'judged-by--' + self.judge_abbr + '-capability.csv') + overall_judged_answers, overall_references = [], [] + for dataset in dataset_cfgs: + judged_answers, references = get_judgeanswer_and_reference(dataset, subdir_path, self.judge_function) + judged_answers = [item['score'] for item in judged_answers] + overall_judged_answers += judged_answers + overall_references += references + + get_capability_results(overall_judged_answers, overall_references, fout, fout_flag, show_model_abbr) + fout_flag += 1 + else: + print(subdir_path + ' is not exist! please check!') + + +class WildBenchPairSummarizer(CompassArenaSummarizer): + """Do the subjectivity analyze based on evaluation results. + + Args: + config (ConfigDict): The configuration object of the evaluation task. + It's expected to be filled out at runtime. + """ + + def __init__(self, config: ConfigDict, check_pos_bias=False) -> None: + self.tasks = [] + self.cfg = config + + self.base_models = self.cfg['eval']['partitioner']['base_models'] + self.compare_models = self.cfg['eval']['partitioner']['compare_models'] + self.judge_models = self.cfg.get('judge_models', None) + self.meta_judge_model = self.cfg.eval.partitioner.get('meta_judge_model', None) + self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0]) + self.judge_function = post_process_wildbench_pair + self.check_pos_bias = check_pos_bias + + def get_score(self, time_str): + output_dir, results_folder = get_outdir(self.cfg, time_str) + model_combinations = list(product(self.base_models, self.compare_models)) + unique_combinations = remove_duplicate_pairs([combo for combo in model_combinations if combo[0] != combo[1]]) + + if self.meta_judge_model is not None: + self.judge_models.append(self.meta_judge_model) + + scores = {} + for idx, judge_model_cfg in enumerate(self.judge_models): + judge_model = model_abbr_from_cfg(judge_model_cfg) + for dataset in self.cfg['datasets']: + dataset_abbr = dataset_abbr_from_cfg(dataset) + for model_pair in unique_combinations: + base_model = model_pair[0]['abbr'] + compare_model = model_pair[1]['abbr'] + if idx == len(self.judge_models): + subdir = base_model + '_' + compare_model + '_summarized-by--' + judge_model + else: + subdir = base_model + '_' + compare_model + '_judged-by--' + judge_model + subdir_path = os.path.join(results_folder, subdir) + if not os.path.isdir(subdir_path): + print(subdir_path + ' is not exist! 
please check!') + continue + judged_answers, references = get_judgeanswer_and_reference(dataset, subdir_path, self.judge_function) + if self.check_pos_bias: + bias_num = check_position_bias(judged_answers, references) + else: + bias_num = 0 + win_base_model = defaultdict(float) + win_compare_model = defaultdict(float) + categories = defaultdict(float) + # base_model = references[0]['answer1'] + # compare_model = references[0]['answer2'] + score_mapping = {'A++': 1, 'A+': 0.5, 'A=B': 0, 'B+': -0.5, 'B++': -1} + for prediction, reference in zip(judged_answers, references): + if prediction not in score_mapping: + continue + + categories[dataset_abbr] += 1 + flag = 1 if reference['answer1'] == base_model else -1 + score_1 = score_mapping[prediction]*flag + score_2 = -score_1 + + tags = [reference['primary_tag']] + reference['secondary_tag'] + for tag in tags: + win_base_model[task_group_new[tag]] += score_1 + win_compare_model[task_group_new[tag]] += score_2 + categories[task_group_new[tag]] += 1 + + win_compare_model[dataset_abbr] += score_2 + win_base_model[dataset_abbr] += score_1 + + for capability in categories: + win_base_model[capability] = win_base_model[capability] / categories[capability] * 100 + win_base_model[capability] = round(win_base_model[capability], 2) + win_compare_model[capability] = win_compare_model[capability] / categories[capability] * 100 + win_compare_model[capability] = round(win_compare_model[capability], 2) + + win_base_model['position_bias'] = bias_num + win_compare_model['position_bias'] = bias_num + + if judge_model not in scores: + scores[judge_model] = {} + if dataset_abbr not in scores[judge_model]: + scores[judge_model][dataset_abbr] = {} + scores[judge_model][dataset_abbr][base_model + '/' + compare_model] = win_compare_model + + return scores + + def summarize( + self, + time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S'), + ): + """Summarize the subjectivity analysis based on evaluation results. + + Args: + time_str (str): Timestamp for file naming. + + Returns: + pd.DataFrame: The summary results. 
+ """ + scores = self.get_score(time_str) + output_dir, results_folder = get_outdir(self.cfg, time_str) + for idx, judge_model in enumerate(self.judge_models): + judge_abbr = model_abbr_from_cfg(judge_model) + for dataset in self.cfg['datasets']: + dataset_abbr = dataset_abbr_from_cfg(dataset) + summarizer_model_abbrs = [model_abbr_from_cfg_used_in_summarizer(i) for i in self.compare_models] + one_column = list(scores[judge_abbr][dataset_abbr].values())[0] + row_headers = [i for i in one_column.keys() if i not in [dataset_abbr, 'position_bias']] + row_headers = [dataset_abbr, 'position_bias'] + row_headers + + table = [] + for row_header in row_headers: + row = [row_header] + headers = [''] + for model_cfg in self.compare_models: + model_abbr = model_abbr_from_cfg(model_cfg) + avg = 0 + for base_model_cfg in self.base_models: + base_model_abbr = model_abbr_from_cfg(base_model_cfg) + base_compare = base_model_abbr + '/' + model_abbr + headers.append(base_compare) + s = scores[judge_abbr][dataset_abbr][base_compare].get(row_header, '') + if isinstance(s, float): + avg += s + s = f'{s:.2f}' + if isinstance(s, int): + s = str(s) + row.append(s) + avg = avg/len(self.base_models) + row.append(f'{avg:.2f}') + headers.append('Avg') + table.append(row) + + txt = tabulate(table, headers=headers) + print(txt) + + if idx == len(self.judge_models): + output_filename = osp.join(output_dir, 'summarized-by--' + judge_abbr + '-' + dataset_abbr + '-report.csv') + else: + output_filename = osp.join(output_dir, 'judged-by--' + judge_abbr + '-' + dataset_abbr + '-report.csv') + + with open(output_filename, 'w') as f: + f.write(','.join(headers) + '\n') + for line in table: + f.write(','.join(line) + '\n') + print(output_filename) diff --git a/opencompass/utils/prompt.py b/opencompass/utils/prompt.py index d65f6a03..cef6a31d 100644 --- a/opencompass/utils/prompt.py +++ b/opencompass/utils/prompt.py @@ -2,7 +2,6 @@ from __future__ import annotations import hashlib import json -import re from copy import deepcopy from typing import Dict, List, Union @@ -20,15 +19,20 @@ def safe_format(input_str: str, **kwargs) -> str: Returns: str: The formatted string. """ - segs = [input_str] + # import re + # segs = [input_str] + # for k, v in kwargs.items(): + # regex = re.compile(f'(?<={{{k}}})(?={{{k}}})|({{{k}}})') + # segs = [regex.split(seg) for seg in segs] + # segs = sum(segs, []) + # replace_dict = {f'{{{k}}}': str(v) for k, v in kwargs.items()} + # segs = [replace_dict.get(seg, seg) for seg in segs] + # output_str = ''.join(segs) + # return output_str + for k, v in kwargs.items(): - regex = re.compile(f'(?<={{{k}}})(?={{{k}}})|({{{k}}})') - segs = [regex.split(seg) for seg in segs] - segs = sum(segs, []) - replace_dict = {f'{{{k}}}': str(v) for k, v in kwargs.items()} - segs = [replace_dict.get(seg, seg) for seg in segs] - output_str = ''.join(segs) - return output_str + input_str = input_str.replace(f'{{{k}}}', str(v)) + return input_str def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str: