[Feature] add mtbench (#829)

* add mtbench

* add mtbench

* Update configs/datasets/subjective/multiround/mtbench_judgeby_gpt4.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update configs/datasets/subjective/multiround/mtbench_judgeby_gpt4.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update opencompass/datasets/subjective/__init__.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update opencompass/datasets/subjective/mtbench.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* fix mtbench

---------

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
Author: bittersweet1999
Date: 2024-01-24 12:11:47 +08:00 (committed by GitHub)
Parent: e059a5c2bf
Commit: 2ee8e8a1a1
10 changed files with 609 additions and 11 deletions

File: configs/datasets/subjective/multiround/mtbench_pair_judge.py (new)
@@ -0,0 +1,64 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import MTBenchDataset
subjective_reader_cfg = dict(
input_columns=['dialogue', 'capability', 'system_prompt', 'prompt_template'],
output_column='judge',
)
subjective_all_sets = [
"mtbench",
]
data_path = "data/subjective/"
subjective_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="""{dialogue}""",
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512, infer_mode='every'),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
infer_order='double',
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="{system_prompt}")
],
round=[
dict(
role='HUMAN',
prompt = "{prompt_template}"
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=MTBenchDataset,
path=data_path,
name=_name,
judge_type='pair',
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))
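For orientation, this dataset config (like the single-judge variant below) is meant to be pulled into a top-level evaluation config through read_base(); a minimal sketch, mirroring the import used in the evaluation config later in this PR:

from mmengine.config import read_base

with read_base():
    # same import style as the evaluation config further down
    from .datasets.subjective.multiround.mtbench_pair_judge import subjective_datasets

datasets = [*subjective_datasets]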

File: configs/datasets/subjective/multiround/mtbench_single_judge.py (new)
@@ -0,0 +1,62 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import MTBenchDataset
subjective_reader_cfg = dict(
input_columns=['dialogue', 'capability', 'system_prompt', 'prompt_template'],
output_column='judge',
)
subjective_all_sets = [
"mtbench",
]
data_path = "data/subjective/"
subjective_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="""{dialogue}""",
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512, infer_mode='every'),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="{system_prompt}")
],
round=[
dict(
role='HUMAN',
prompt = "{prompt_template}"
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=MTBenchDataset,
path=data_path,
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))

File: top-level MT-Bench evaluation config (new)
@@ -0,0 +1,116 @@
from mmengine.config import read_base
with read_base():
from .models.qwen.hf_qwen_7b_chat import models as hf_qwen_7b_chat
from .models.qwen.hf_qwen_14b_chat import models as hf_qwen_14b_chat
from .models.chatglm.hf_chatglm3_6b import models as hf_chatglm3_6b
from .models.baichuan.hf_baichuan2_7b_chat import models as hf_baichuan2_7b
from .models.hf_internlm.hf_internlm_chat_20b import models as hf_internlm_chat_20b
from .models.judge_llm.auto_j.hf_autoj_eng_13b import models as hf_autoj
from .models.judge_llm.judgelm.hf_judgelm_33b_v1 import models as hf_judgelm
from .models.judge_llm.pandalm.hf_pandalm_7b_v1 import models as hf_pandalm
from .datasets.subjective.multiround.mtbench_single_judge import subjective_datasets
#from .datasets.subjective.multiround.mtbench_pair_judge import subjective_datasets
datasets = [*subjective_datasets]
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3
from opencompass.models.openai_api import OpenAIAllesAPIN
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import MTBenchSummarizer
# -------------Inference Stage ----------------------------------------
models = [*hf_chatglm3_6b, *hf_qwen_7b_chat]
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=100),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='auto',
max_num_workers=256,
task=dict(type=OpenICLInferTask)),
)
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
judge_model = dict(
abbr='GPT4-Turbo',
type=OpenAIAllesAPIN, path='gpt-4-1106-preview',
key='xxxx',  # the key is taken from $OPENAI_API_KEY; you can also write it here directly
url='xxxx',
meta_template=api_meta_template,
query_per_second=16,
max_out_len=2048,
max_seq_len=2048,
batch_size=8,
temperature = 0
)
## ------------- Evaluation Configuration
'''
## pair evaluation
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner,
max_task_size=100,
mode='m2n',
base_models = [*hf_chatglm3_6b, ],
compare_models = models
),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='auto',
max_num_workers=32,
task=dict(
type=SubjectiveEvalTask,
judge_cfg=judge_model
)),
)
summarizer = dict(
type=MTBenchSummarizer, judge_type='pair'
)
'''
## single evaluation
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner,
max_task_size=100,
mode='singlescore',
models = models
),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='auto',
max_num_workers=32,
task=dict(
type=SubjectiveEvalTask,
judge_cfg=judge_model
)),
)
summarizer = dict(
type=MTBenchSummarizer, judge_type='single'
)
work_dir = 'outputs/mtbench/'

File: opencompass/datasets/subjective/__init__.py
@@ -3,5 +3,6 @@ from .compass_arena import CompassArenaDataset # noqa: F401, F403
from .corev2 import Corev2Dataset # noqa: F401, F403
from .creationbench import CreationBenchDataset # noqa: F401, F403
from .information_retrival import IRDataset # noqa: F401, F403
from .mtbench import MTBenchDataset # noqa: F401, F403
from .multiround import MultiroundDataset # noqa: F401, F403
from .subjective_cmp import SubjectiveCmpDataset # noqa: F401, F403

File: opencompass/datasets/subjective/mtbench.py (new)
@@ -0,0 +1,201 @@
# flake8: noqa: E501
import json
import os.path as osp
import re
from typing import Optional
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
NEED_REF_CATS = ['math', 'reasoning', 'coding', 'arena-hard-200']
pair_v2 = {
'type': 'pairwise',
'system_prompt':
"Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
'prompt_template':
"[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{prediction_r1}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{prediction1_r1}\n[The End of Assistant B's Answer]",
'description': 'Prompt for general questions',
'category': 'general',
'output_format': '[[A]]'
}
pair_v2_multi_turn = {
'type': 'pairwise',
'system_prompt':
"Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. You should choose the assistant that follows the user's instructions and answers the user's questions better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. You should focus on who provides a better answer to the second user question. Begin your evaluation by comparing the responses of the two assistants and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
'prompt_template':
"<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{prediction_r1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{prediction_r2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{prediction1_r1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{prediction1_r2}\n\n<|The End of Assistant B's Conversation with User|>",
'description': 'Prompt for multi-turn general questions',
'category': 'general',
'output_format': '[[A]]'
}
pair_math_v1 = {
'type': 'pairwise',
'system_prompt':
"Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
'prompt_template':
"[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{prediction_r1}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{prediction1_r1}\n[The End of Assistant B's Answer]",
'description': 'Prompt for math questions',
'category': 'math',
'output_format': '[[A]]'
}
pair_math_v1_multi_turn = {
'type': 'pairwise',
'system_prompt':
"Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. Your evaluation should consider correctness and helpfulness. You will be given reference answers, the assistant A's answers, the assistant B's answers. Your job is to determine which assistant provides correct and helpful answers to the second user question. Begin your evaluation by comparing both assistants' answers with the reference answers. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
'prompt_template':
"<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{prediction_r1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{prediction_r2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{prediction1_r1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{prediction1_r2}\n\n<|The End of Assistant B's Conversation with User|>",
'description': 'Prompt for multi-turn general questions',
'category': 'general',
'output_format': '[[A]]'
}
single_v1 = {
'type': 'single',
'system_prompt': 'You are a helpful assistant.',
'prompt_template':
"[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Assistant's Answer]\n{prediction_r1}\n[The End of Assistant's Answer]",
'description': 'Prompt for general questions',
'category': 'general',
'output_format': '[[rating]]'
}
single_math_v1 = {
'type': 'single',
'system_prompt': 'You are a helpful assistant.',
'prompt_template':
"[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{prediction_r1}\n[The End of Assistant's Answer]",
'description': 'Prompt for general questions',
'category': 'math',
'output_format': '[[rating]]'
}
single_v1_multi_turn = {
'type': 'single',
'system_prompt':
"Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant's answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n",
'prompt_template':
"<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{prediction_r1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{prediction_r2}\n\n<|The End of Assistant A's Conversation with User|>",
'description': 'Prompt for general questions',
'category': 'general',
'output_format': '[[rating]]'
}
single_math_v1_multi_turn = {
'type': 'single',
'system_prompt':
"Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n",
'prompt_template':
"<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{prediction_r1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{prediction_r2}\n\n<|The End of Assistant A's Conversation with User|>",
'description': 'Prompt for general questions',
'category': 'math',
'output_format': '[[rating]]'
}
def prompt_construct(problem, multi_turn=False, judge_type='single'):
"""Return the correct pairwise judge."""
question_1 = problem['dialogue'][0]['content']
if multi_turn:
question_2 = problem['dialogue'][2]['content']
if problem['capability'] in NEED_REF_CATS:
ref_answer_1 = problem['others']['reference'][0]
ref_answer_2 = problem['others']['reference'][1]
if judge_type == 'pair':
return pair_math_v1_multi_turn[
'system_prompt'], pair_math_v1_multi_turn[
'prompt_template'].format(
question_1=question_1,
question_2=question_2,
ref_answer_1=ref_answer_1,
ref_answer_2=ref_answer_2,
prediction_r1='{prediction_r1}',
prediction_r2='{prediction_r2}',
prediction1_r1='{prediction1_r1}',
prediction1_r2='{prediction1_r2}')
elif judge_type == 'single':
return single_math_v1_multi_turn[
'system_prompt'], single_math_v1_multi_turn[
'prompt_template'].format(
question_1=question_1,
question_2=question_2,
ref_answer_1=ref_answer_1,
ref_answer_2=ref_answer_2,
prediction_r1='{prediction_r1}',
prediction_r2='{prediction_r2}')
if judge_type == 'pair':
return pair_v2_multi_turn['system_prompt'], pair_v2_multi_turn[
'prompt_template'].format(question_1=question_1,
question_2=question_2,
prediction_r1='{prediction_r1}',
prediction_r2='{prediction_r2}',
prediction1_r1='{prediction1_r1}',
prediction1_r2='{prediction1_r2}')
elif judge_type == 'single':
return single_v1_multi_turn['system_prompt'], single_v1_multi_turn[
'prompt_template'].format(question_1=question_1,
question_2=question_2,
answer_1='{answer_1}',
prediction_r1='{prediction_r1}',
prediction_r2='{prediction_r2}')
if problem['capability'] in NEED_REF_CATS:
ref_answer_1 = problem['others']['reference'][0]
if judge_type == 'pair':
return pair_math_v1['system_prompt'], pair_math_v1[
'prompt_template'].format(question=question_1,
ref_answer_1=ref_answer_1,
prediction_r1='{prediction_r1}',
prediction1_r1='{prediction1_r1}')
elif judge_type == 'single':
return single_math_v1['system_prompt'], single_math_v1[
'prompt_template'].format(question=question_1,
ref_answer_1=ref_answer_1,
prediction_r1='{prediction_r1}')
else:
if judge_type == 'pair':
return pair_v2['system_prompt'], pair_v2['prompt_template'].format(
question=question_1,
prediction_r1='{prediction_r1}',
prediction1_r1='{prediction1_r1}')
elif judge_type == 'single':
return single_v1['system_prompt'], single_v1[
'prompt_template'].format(question=question_1,
prediction_r1='{prediction_r1}')
@LOAD_DATASET.register_module()
class MTBenchDataset(BaseDataset):
def load(self, path: str, name: str, multi_turn=True, judge_type='single'):
filename = osp.join(path, f'{name}.json')
dataset = DatasetDict()
raw_data = []
with open(filename, 'r', encoding='utf-8') as f:
json_data = json.load(f)
for problem in json_data:
if 'dialogue' in problem:
system_prompt, prompt_template = prompt_construct(
problem, multi_turn, judge_type)
dialogue = problem['dialogue']
capability = problem['capability']
others = problem['others']
others['round'] = int(len(dialogue) / 2)
user_contents = [
item['content'] for item in dialogue
if item['role'] == 'user'
]
question = ' '.join(user_contents)
others['question'] = question
raw_data.append({
'dialogue': dialogue,
'capability': capability,
'system_prompt': system_prompt,
'prompt_template': prompt_template,
'others': others,
'judge': {
'capability': capability,
'others': others,
}
})
dataset = Dataset.from_list(raw_data)
return dataset
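The loader above expects data/subjective/mtbench.json to hold a list of problems; below is a hypothetical entry, shaped purely to match the fields that load() and prompt_construct() read (the values are illustrative, not taken from the real dataset):

example_problem = {  # illustrative only
    'dialogue': [
        {'role': 'user', 'content': 'first-turn question'},
        {'role': 'assistant', 'content': ''},
        {'role': 'user', 'content': 'second-turn question'},
        {'role': 'assistant', 'content': ''},
    ],
    'capability': 'math',  # categories in NEED_REF_CATS also require references
    'others': {'reference': ['reference for turn 1', 'reference for turn 2']},
}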

File: opencompass/openicl/icl_evaluator/lm_evaluator.py
@@ -16,6 +16,17 @@ from opencompass.utils.text_postprocessors import first_number_postprocess
from opencompass.utils.types import get_type_from_cfg
def extract_dicts(data):
max_round_num = max(len(sublist) for sublist in data)
predictions = [[] for _ in range(max_round_num)]
for sublist in data:
for i, d in enumerate(sublist):
predictions[i].append(d.get('assistant'))
for j in range(i + 1, max_round_num):
predictions[j].append(None)
return predictions
def order_preds_and_record_references(predictions,
references,
infer_order,
@@ -101,7 +112,6 @@ class LMEvaluator:
def score(self, predictions, references: Optional[List] = None) -> Dict:
dup_indices = []
if type(predictions) == list:
"""Apply to multi-model comparison."""
references = [{} for _ in range(len(predictions[0]['model_preds']))
@@ -112,10 +122,12 @@ class LMEvaluator:
# calculate duplicated prediction numbers
total_predictions_num = len(predictions[0])
for i in range(len(predictions[0])):
check = [sub[i] for sub in predictions]
if len(set(check)) == 1:
dup_indices.append(i)
# two models are very unlikely to produce identical multi-round dialogues, so duplicates are only checked for single-turn chat
if isinstance(predictions[0][0], str):
for i in range(len(predictions[0])):
check = [sub[i] for sub in predictions]
if len(set(check)) == 1:
dup_indices.append(i)
elif type(predictions) == dict:
"""Apply to single-model scoring."""
@@ -131,9 +143,21 @@ class LMEvaluator:
del references[index]
pred_dict = {}
for i in range(len(predictions)):
key = 'prediction' if i == 0 else f'prediction{i + 1}'
pred_dict[key] = predictions[i]
if isinstance(
predictions[0][0], str
): #single chat for format like [['xxx', 'xxxx'], ['xxx', 'xxxx']]
for i in range(len(predictions)):
key = 'prediction' if i == 0 else f'prediction{i + 1}'
pred_dict[key] = predictions[i]
elif isinstance(
predictions[0][0], list
): #multi round for format like [[[{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}], [{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}]]]
for i in range(len(predictions)):
multiround_predictions = extract_dicts(predictions[i])
for j in range(len(multiround_predictions)):
key = 'prediction' if i == 0 else f'prediction{i}'
key += '_r' + str(j + 1)
pred_dict[key] = multiround_predictions[j]
if self.dataset_cfg:
dataset = build_dataset_from_cfg(self.dataset_cfg)
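To illustrate the new extract_dicts helper added above: it regroups per-sample multi-round predictions into per-round lists, padding samples with fewer rounds with None. A small sketch with made-up values (assuming, as in the source, that the padding loop runs once per sample after its rounds are collected):

data = [
    [{'round': 1, 'user': 'q1', 'assistant': 'a1'},
     {'round': 2, 'user': 'q2', 'assistant': 'a2'}],
    [{'round': 1, 'user': 'q1', 'assistant': 'b1'}],  # only one round answered
]
extract_dicts(data)  # -> [['a1', 'b1'], ['a2', None]]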

File: ChatInferencer (opencompass/openicl/icl_inferencer)
@@ -372,7 +372,7 @@ class ChatInferencer(BaseInferencer):
preds_list.append(temp_dict)
output_handler.save_results(
origin_prompt=None,
prediction=str(preds_list),
prediction=preds_list,
idx=index_copy,
gold=None,
)
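The one-line change above stops stringifying multi-round predictions, so downstream code (the LMEvaluator branch added in this PR) can inspect them as structured data; a small sketch of the intended effect:

preds_list = [
    {'round': 1, 'user': 'q1', 'assistant': 'a1'},
    {'round': 2, 'user': 'q2', 'assistant': 'a2'},
]
# before: prediction=str(preds_list)  -> a single opaque string
# after:  prediction=preds_list       -> a list the evaluator can branch on
assert isinstance(preds_list, list) and isinstance(preds_list[0], dict)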

File: summarizers package __init__.py
@@ -4,4 +4,5 @@ from .compass_arena import CompassArenaSummarizer
from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer
from .information_retrival import IRSummarizer
from .mtbench import MTBenchSummarizer
from .multiround import MultiroundSummarizer

File: summarizers mtbench.py (new, MTBenchSummarizer)
@@ -0,0 +1,129 @@
# flake8: noqa: E501
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
import numpy as np
from mmengine import ConfigDict
try:
from prettytable import from_csv
except ImportError:
from_csv = None
from opencompass.utils import model_abbr_from_cfg
from .compass_arena import CompassArenaSummarizer
from .subjective_post_process import post_process_autoj
from .utils import get_judgeanswer_and_reference, get_outdir
def post_process_mtbench(judgement: str):
"""Input a string like below:
xxx[[A]]xxx, and extract the judge
"""
pattern = r'\[([A-C]+)\]'
matched_result = re.findall(pattern, judgement)
if matched_result:
return matched_result[0]
else:
return None
def get_capability_results(
judged_answers,
references,
fout,
fout_flag,
model,
):
capability_ratings = defaultdict(int)
capability_counts = defaultdict(int)
for ans, ref in zip(judged_answers, references):
capability_ratings['total'] += ans['score']
capability_counts['total'] += 1
capability_ratings[ref['capability']] += ans['score']
capability_counts[ref['capability']] += 1
capability_avg_ratings = defaultdict(float)
for capability, total_score in capability_ratings.items():
capability_avg_ratings[
capability] = total_score / capability_counts[capability]
columns = list(capability_avg_ratings.keys())
columns.insert(0, columns.pop(columns.index('total')))
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
if fout_flag == 0:
writer.writerow(['model'] + columns)
writer.writerow([model] +
[capability_avg_ratings[column] for column in columns])
class MTBenchSummarizer(CompassArenaSummarizer):
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, judge_type='single') -> None:
self.judge_type = judge_type
self.tasks = []
self.cfg = config
if self.judge_type == 'single':
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
self.eval_model_abbrs = [
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
elif self.judge_type == 'pair':
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner'][
'compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_map = {
'single': post_process_autoj,
'pair': post_process_mtbench
}
self.judge_function = self.judge_map[self.judge_type]
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
if self.judge_type == 'single':
dataset_cfgs = self.cfg['datasets']
output_dir, results_folder = get_outdir(self.cfg, time_str)
fout_flag = 0
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model, judge_model = eval_model_abbr, self.judge_abbr
fout = osp.join(
output_dir,
'judged-by--' + judge_model + '-capability.csv')
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, self.judge_function)
get_capability_results(judged_answers, references,
fout, fout_flag, model)
fout_flag += 1
else:
print(subdir_path + ' does not exist, please check!')
with open(fout, 'r') as f:
x = from_csv(f)
print(x)
elif self.judge_type == 'pair':
super().summarize()
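For reference, post_process_mtbench above pulls the pairwise verdict letter out of the judge output; a quick sketch of its behaviour:

post_process_mtbench('... final verdict: [[A]] ...')  # -> 'A'
post_process_mtbench('no bracketed verdict here')     # -> None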

File: opencompass/tasks/subjective_eval.py
@@ -156,9 +156,9 @@ class SubjectiveEvalTask(BaseTask):
# If take SubjectNaivePartition, get all pred_strs
else:
pred_strs = pred_strs
if ('pred_role' in eval_cfg and 'meta_template' in model_cfg
and not MODELS.get(model_cfg['type']).is_api):
and not MODELS.get(model_cfg['type']).is_api
and isinstance(pred_strs[0], str)):
# Create a prompt template for role config parsing
from opencompass.models.base import LMTemplateParser
parser = LMTemplateParser(model_cfg['meta_template'])
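The extra isinstance(pred_strs[0], str) condition above reflects the ChatInferencer change earlier in this PR: multi-round predictions are now lists of per-round dicts, for which role/meta-template post-processing does not apply. A small sketch of the intent, with a made-up prediction:

pred_strs = [[{'round': 1, 'user': 'q', 'assistant': 'a'}]]  # multi-round output
isinstance(pred_strs[0], str)  # False -> skip meta-template role parsing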