[Feature] Add multi_round dataset evaluation (#766)

* multi_round dataset

* add multi_round evaluation
bittersweet1999 2024-01-04 18:37:52 +08:00 committed by GitHub
parent 7cd65d49d8
commit be369c3e06
7 changed files with 372 additions and 8 deletions
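For reference, a minimal sketch of one entry in the multi-round JSON file the new dataset expects (data/subjective/FunctionalMT.json, per the config and loader below). The key names come from MultiroundDataset.load and the ChatInferencer changes; the role labels and the texts are illustrative assumptions, not copied from the real data.

# Hypothetical entry; only the key names are grounded in the code below.
example_problem = {
    'dialogue': [  # alternating user/assistant turns, two per round
        {'role': 'user', 'content': 'Please format this employee data as CSV ...'},
        {'role': 'assistant', 'content': 'reference reply (used as gold outside dialogue mode)'},
        {'role': 'user', 'content': 'Now add a salary column ...'},
        {'role': 'assistant', 'content': 'reference reply for round 2'},
    ],
    'capability': 'csv_en',        # one of the sub-categories used by the summarizer
    'others': {'language': 'en'},  # 'round' (= len(dialogue) // 2) is added by the loader
}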

View File

@@ -0,0 +1,55 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import MultiroundDataset
subjective_reader_cfg = dict(
input_columns=['dialogue', 'capability', 'gpt4_prefix', 'gpt4_suffix'],
output_column='judge',
)
subjective_all_sets = [
"FunctionalMT",
]
data_path = "data/subjective/"
subjective_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template="""{dialogue}""",
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512, infer_mode='every'),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt = "{gpt4_prefix}{prediction}{gpt4_suffix}"
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=MultiroundDataset,
path=data_path,
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))
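To show where this dataset config sits, a hedged sketch of the top-level evaluation config around it. Only the keys that MultiroundSummarizer reads later in this PR (datasets, eval.partitioner.models, judge_model) and the summarizer class are taken from the code; the models, judge config, partitioner and runner details are placeholders for whatever the local setup uses.

# Hypothetical wiring; partitioner/runner specifics are deliberately omitted.
from opencompass.summarizers import MultiroundSummarizer  # registered by this PR

models = [...]           # configs of the chat models under evaluation (placeholder)
judge_model = dict(...)  # config of the LLM judge used by LMEvaluator (placeholder)

datasets = [*subjective_datasets]

eval = dict(
    partitioner=dict(
        models=models,  # MultiroundSummarizer reads cfg['eval']['partitioner']['models']
        # partitioner type/mode depend on the local subjective-eval setup
    ),
)

summarizer = dict(type=MultiroundSummarizer)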

View File

@@ -84,6 +84,7 @@ from .subject_alignmentbench import AlignmentBenchDataset # noqa: F401, F403
from .subject_corev2 import Corev2Dataset # noqa: F401, F403
from .subject_creationbench import CreationBenchDataset # noqa: F401, F403
from .subject_creationv01 import Creationv01Dataset # noqa: F401, F403
from .subject_multiround import MultiroundDataset # noqa: F401, F403
from .subjective_cmp import SubjectiveCmpDataset # noqa: F401, F403
from .summedits import * # noqa: F401, F403
from .summscreen import * # noqa: F401, F403

View File

@@ -0,0 +1,118 @@
# flake8: noqa: E501
import json
import os.path as osp
import re
from typing import Optional
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
base_prefix_en = """
You are a helper who will help me to evaluate the quality of AI assistants.
You will be provided with a dialogue of N rounds between a user and an assistant (each round includes an utterance from the user and an utterance from the assistant).
In this dialogue, user will ask assistant to finish a series of tasks.
Your job is to help me detect whether the AI assistant successfully completed these tasks.
You should print a list of round indices (starting from 1), indicating the rounds in which the assistant failed to accomplish the task assigned by the user in that round. Besides, you also need to provide the reason.
If the assistant succeeds in the entire series of tasks, print an empty list '[]'.
Here I will show you some examples:
Example 1:
[{'round': 1, 'user': 'I have data about different car models that need to be formatted into a CSV. The data includes the car model, engine power in horsepower, and year of release. The cars are Mustang, 450hp, released in 2020; Corvette, 490hp, released in 2021; and Charger, 370hp, released in 2019. Could you create a CSV file for me?', 'assistant': 'Model,Engine Power (HP),Year of Release\nMustang,450,2020\nCorvette,490,2021\nCharger,370,2019'}, {'round': 2, 'user': 'Okay, now please add a new field for the body style of each car. Mustang is a Coupe, Corvette is a Coupe as well, and Charger is a Sedan.', 'assistant': 'Model,Engine Power (HP),Year of Release,Body Style\nMustang,450,2020,Coupe\nCorvette,490,2021,Coupe\nCharger,370,2019,Sedan'}]
Output: []
Reason: The assistant succeeded in all rounds.
Example 2:
[{'round': 1, 'user': 'I have data about different car models that need to be formatted into a CSV. The data includes the car model, engine power in horsepower, and year of release. The cars are Mustang, 450hp, released in 2020; Corvette, 490hp, released in 2021; and Charger, 370hp, released in 2019. Could you create a CSV file for me?', 'assistant': 'Model,Engine Power (HP),Year of Release\nMustang,450,2020\nCorvette,490,2021\nCharger,370,2019'}, {'round': 2, 'user': 'Okay, now please add a new field for the body style of each car. Mustang is a Coupe, Corvette is a Coupe as well, and Charger is a Sedan.', 'assistant': 'Model,Engine Power (HP),Year of Release,Body Style\nMustang,450,2020,Coupe\nCorvette,490,2021,Coupe\nCharger,370,2019,Coupe'}]
Output: [2]
Reason: In round 2, the body style of Charger should be Sedan, but the assistant returned Coupe.
Example 3:
[{'round': 1, 'user': 'I have data about different car models that need to be formatted into a CSV. The data includes the car model, engine power in horsepower, and year of release. The cars are Mustang, 450hp, released in 2020; Corvette, 490hp, released in 2021; and Charger, 370hp, released in 2019. Could you create a CSV file for me?', 'assistant': 'Model,Engine Power (HP),Year of Release\nMustang,400,2020\nCorvette,400,2021\nCharger,400,2019'}, {'round': 2, 'user': 'Okay, now please add a new field for the body style of each car. Mustang is a Coupe, Corvette is a Coupe as well, and Charger is a Sedan.', 'assistant': 'Model,Engine Power (HP),Year of Release,Body Style\nMustang,450,2020,Coupe\nCorvette,490,2021,Coupe\nCharger,370,2019,Coupe'}]
Output: [1, 2]
Reason: In round 1, the horsepower of all models is wrong. In round 2, the body style of Charger should be Sedan, but the assistant returned Coupe.
Now it's your turn:
Dialogue:
"""
base_suffix_en = """
Based on the dialogue, give your Output and Reason in the above format.
"""
base_prefix_ch = """
你是一个评估AI assistant对话内容质量的助手。
你将获得用户与assistant的N轮对话(每轮包括用户的问题和assistant的回答)。
在这个对话中,用户会要求assistant完成一系列任务。
你需要检测assistant是否成功完成了这些任务。
你需要输出一个轮次索引的列表,列出assistant未能完成任务的轮次,此外,你还需要提供原因。
如果助手成功完成了整个系列任务,请打印一个空列表'[]'。
下面是一些例子:
例1:
[{'round': 1, 'user': '我有一些关于员工的数据,需要整理成 CSV 格式。数据包括员工的姓名、年龄和职位。员工是张伟30岁工程师李娜28岁设计师王浩32岁经理。请帮我制作一个 CSV 文件。', 'assistant': '姓名,年龄,职位\n张伟,30,工程师\n李娜,28,设计师\n王浩,32,经理'}, {'round': 2, 'user': '好的现在请在每个员工信息后面增加一个薪水字段。张伟的薪水是8000李娜是9000王浩是12000。', 'assistant': '姓名,年龄,职位,薪水\n张伟,30,工程师,8000\n李娜,28,设计师,9000\n王浩,32,经理,12000'}]
输出:[]
原因:assistant成功完成了所有的任务。
例2:
[{'round': 1, 'user': '我有一些关于员工的数据,需要整理成 CSV 格式。数据包括员工的姓名、年龄和职位。员工是张伟30岁工程师李娜28岁设计师王浩32岁经理。请帮我制作一个 CSV 文件。', 'assistant': '姓名,年龄,职位\n张伟,30,工程师\n李娜,28,设计师\n王浩,32,经理'}, {'round': 2, 'user': '好的现在请在每个员工信息后面增加一个薪水字段。张伟的薪水是8000李娜是9000王浩是12000。', 'assistant': '姓名,年龄,职位,薪水\n张伟,30,工程师,8000\n李娜,28,设计师,8000\n王浩,40,经理,12000'}]
输出:[2]
原因:在第二轮对话中,assistant错误地填写了李娜的薪水和王浩的年龄。
例3:
[{'round': 1, 'user': '我有一些关于员工的数据,需要整理成 CSV 格式。数据包括员工的姓名、年龄和职位。员工是张伟30岁工程师李娜28岁设计师王浩32岁经理。请帮我制作一个 CSV 文件。', 'assistant': '姓名,年龄,职位\n张伟,30,工程师\n张伟,28,设计师\n张伟,32,经理'}, {'round': 2, 'user': '好的现在请在每个员工信息后面增加一个薪水字段。张伟的薪水是8000李娜是9000王浩是12000。', 'assistant': '姓名,年龄,职位,薪水\n张伟,30,工程师,8000\n李娜,28,设计师,8000\n王浩,40,经理,12000'}]
输出:[1, 2]
原因:在第一和第二轮对话中,assistant都出现了错误。
现在请评估以下对话:
对话:
"""
base_suffix_ch = """
基于以上对话,请按照上面的格式给出你的"输出"和"原因"。
"""
def prompt_construct(sample):
if sample['others']['language'] == 'zh':
return base_prefix_ch, base_suffix_ch
elif sample['others']['language'] == 'en':
return base_prefix_en, base_suffix_en
# ref = str(sample['dialogue'])
# base_suffix.format(ref=ref)
@LOAD_DATASET.register_module()
class MultiroundDataset(BaseDataset):
def load(
self,
path: str,
name: str,
):
filename = osp.join(path, f'{name}.json')
dataset = DatasetDict()
raw_data = []
with open(filename, 'r', encoding='utf-8') as f:
json_data = json.load(f)
for problem in json_data:
gpt4_prefix, gpt4_suffix = prompt_construct(problem)
dialogue = problem['dialogue']
capability = str(problem['capability'])
others = problem['others']
others['round'] = int(len(dialogue) / 2)
raw_data.append({
'dialogue': dialogue,
'capability': capability,
'gpt4_prefix': gpt4_prefix,
'gpt4_suffix': gpt4_suffix,
'others': others,
'judge': {
'capability': capability,
'others': others
}
})
dataset = Dataset.from_list(raw_data)
return dataset
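For concreteness, a sketch of a single row the loader emits (structure read directly from the raw_data.append call above; the capability value is just an example, the dialogue is whatever the JSON file contains):

# One element of raw_data, i.e. one row of the resulting Dataset:
row = {
    'dialogue': [...],              # raw turn list from the JSON file, untouched
    'capability': 'csv_en',
    'gpt4_prefix': base_prefix_en,  # judge prompt pieces, picked by others['language']
    'gpt4_suffix': base_suffix_en,
    'others': {'language': 'en', 'round': 2},
    'judge': {'capability': 'csv_en',
              'others': {'language': 'en', 'round': 2}},
}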

View File

@@ -188,6 +188,7 @@ class ChatInferencer(BaseInferencer):
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
self.dialogue_mode = False
def _set_meta_template(self, model):
origin = model.template_parser
@@ -314,6 +315,9 @@ class ChatInferencer(BaseInferencer):
item[input_columns[0]], dict):
# Single input column and it's already a chat.
chat = item[input_columns[0]]
elif 'dialogue' in input_columns:
chat = item['dialogue']
self.dialogue_mode = True
else:
raise ValueError('Cannot construct chat from the dataset.')
@@ -339,19 +343,39 @@ class ChatInferencer(BaseInferencer):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
index_copy = index
for i in assistant_indices:
history = chat[:i]
output = self.model.generate_from_template([history],
max_out_len=512)[0]
output_handler.save_multiround_results(
origin_prompt=history[-1]['content'],
prediction=output,
idx=index,
gold=chat[i]['content'],
)
chat[i]['content'] = output
index += 1
if not self.dialogue_mode:
output_handler.save_multiround_results(
origin_prompt=history[-1]['content'],
prediction=output,
idx=index,
gold=chat[i]['content'],
)
index += 1
if self.dialogue_mode:
# dialogue mode for subjective evaluation
assert len(chat) % 2 == 0
round_num = int(len(chat) / 2)
preds_list = []
for i in range(round_num):
temp_dict = {
'round': i + 1,
'user': chat[i * 2]['content'],
'assistant': chat[i * 2 + 1]['content']
}
preds_list.append(temp_dict)
output_handler.save_results(
origin_prompt=None,
prediction=str(preds_list),
idx=index_copy,
gold=None,
)
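# Sketch of what dialogue mode saves for a 2-round chat (texts are placeholders):
# after the loop above has regenerated every assistant turn,
#   preds_list = [
#       {'round': 1, 'user': 'first user turn', 'assistant': '<model output, round 1>'},
#       {'round': 2, 'user': 'second user turn', 'assistant': '<model output, round 2>'},
#   ]
# and str(preds_list) is stored as the prediction, i.e. exactly the string that
# '{prediction}' expands to in the judge template '{gpt4_prefix}{prediction}{gpt4_suffix}'.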
def infer_every_with_gt(self, chat: List[dict], index: int,
output_handler):

View File

@@ -5,4 +5,5 @@ from .corev2 import Corev2Summarizer # noqa: F401
from .creationbench import CreationBenchSummarizer
from .creationv01 import Creationv01Summarizer # noqa: F401
from .default import DefaultSummarizer # noqa: F401
from .multiround import MultiroundSummarizer # noqa: F401
from .subjective import SubjectiveSummarizer # noqa: F401

View File

@@ -153,7 +153,9 @@ def get_capability_results(judged_answers,
capability] = total_score / capability_counts[capability]
temp_list = []
total_column_num = 2
for category, sub_categories in categories.items():
total_column_num += 1 + len(sub_categories)
capability_avg_ratings[category + '总分'] = np.mean([
np.mean(capability_avg_ratings[cat])
for cat in categories[category]
@@ -168,7 +170,7 @@ def get_capability_results(judged_answers,
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
if fout_flag == 0:
num_header = [str(i) for i in range(12)]
num_header = [str(i) for i in range(total_column_num)]
writer.writerow(num_header)
header = ['模型', '总分']

View File

@@ -0,0 +1,163 @@
# flake8: noqa: E501
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
import numpy as np
from mmengine import ConfigDict
try:
from prettytable import from_csv
except ImportError:
from_csv = None
from opencompass.utils import model_abbr_from_cfg
from .utils import get_judgeanswer_and_reference, get_outdir
CATEGORIES = {
'中文': ['json_zh', 'csv_zh', 'email_zh', 'markdown_zh', 'article_zh'],
'英文': ['json_en', 'csv_en', 'email_en', 'markdown_en', 'article_en'],
}
def post_process_multiround(judgement: str):
"""Input a string like below:
xxx输出:[1, 2, 3, 4, 5, 6]xxx,
xxxOutput: [1, 2, 3, 4, 5, 6]xxx,
and extract the list
"""
pattern = r'\[([^]]*)\]'
match = re.search(pattern, judgement)
if match:
temp = match.group(1)
if temp == '':
return 0
numbers = temp.split(', ')
try:
if all(num.isdigit() for num in numbers):
return len([int(num) for num in numbers])
else:
return None
except ValueError:
return None
else:
return None
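# Behaviour sketch (return values follow directly from the regex and checks above):
#   post_process_multiround('输出:[]')         -> 0     (judge reported no failed rounds)
#   post_process_multiround('Output: [1, 2]')  -> 2     (number of failed rounds)
#   post_process_multiround('no list at all')  -> None  (no bracketed list found)
#   post_process_multiround('Output: [1,2]')   -> None  (split on ', ' leaves '1,2', which fails isdigit)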
def get_capability_results(judged_answers,
references,
fout,
fout_flag,
model,
categories=CATEGORIES):
capability_ratings = defaultdict(float)
capability_counts = defaultdict(int)
for ans, ref in zip(judged_answers, references):
lan = ref['others']['language']
capability_ratings[ref['capability'] + '_' +
lan] += (ref['others']['round'] -
ans) / ref['others']['round']
capability_counts[ref['capability'] + '_' + lan] += 1
capability_avg_ratings = defaultdict(float)
for capability, total_score in capability_ratings.items():
capability_avg_ratings[
capability] = total_score / capability_counts[capability]
temp_list = []
total_column_num = 2
for category, sub_categories in categories.items():
total_column_num += 1 + len(sub_categories)
capability_avg_ratings[category + '总分'] = np.mean([
np.mean(capability_avg_ratings[cat])
for cat in categories[category]
])
temp_list.append(category + '总分')
capability_avg_ratings['总分'] = 0
for temp in temp_list:
capability_avg_ratings['总分'] += capability_avg_ratings[temp]
capability_avg_ratings['总分'] /= len(temp_list)
scores = {model: capability_avg_ratings}
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
if fout_flag == 0:
num_header = [str(i) for i in range(total_column_num)]
writer.writerow(num_header)
header = ['模型', '总分']
for category, sub_categories in categories.items():
header.append(category)
header.extend([None for _ in range(len(sub_categories))])
writer.writerow(header)
sub_header = ['模型', '总分']
for category, sub_categories in categories.items():
sub_header.extend([category + '总分'])
sub_header.extend(sub_categories)
writer.writerow(sub_header)
fout_flag += 1
row = [model]
row.append(scores[model]['总分'])
for category, sub_categories in categories.items():
row.append(scores[model][category + '总分'])
for sub_category in sub_categories:
row.append(scores[model][sub_category])
writer.writerow(row)
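# Worked example of the rating above: a 3-round dialogue whose judgement lists
# round [2] as failed gives ans = 1, so the sample contributes
#   (3 - 1) / 3 ≈ 0.667
# to its '<capability>_<language>' bucket; bucket averages are rolled up into
# '中文总分' / '英文总分', and their mean becomes the '总分' column written here.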
class MultiroundSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict) -> None:
self.tasks = []
self.cfg = config
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
self.eval_model_abbrs = [
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
None: capability scores are written to a CSV file and printed as a table.
"""
dataset_cfgs = self.cfg['datasets']
output_dir, results_folder = get_outdir(self.cfg, time_str)
fout_flag = 0
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model, judge_model = eval_model_abbr, self.judge_abbr
fout = osp.join(
output_dir,
'judged-by--' + judge_model + '-capability.csv')
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, post_process_multiround)
get_capability_results(judged_answers, references, fout,
fout_flag, model)
else:
print(subdir_path + ' does not exist! Please check!')
with open(fout, 'r') as f:
x = from_csv(f)
print(x)