[Feature] support alpacaeval (#809)

* support alpacaeval_v1

* Update opencompass/summarizers/subjective/__init__.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* Update opencompass/summarizers/subjective/alpacaeval_v1.py

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>

* fix conflict

* support alpacaeval v2

* support alpacav2

---------

Co-authored-by: Songyang Zhang <tonysy@users.noreply.github.com>
bittersweet1999 2024-02-04 14:18:36 +08:00 committed by GitHub
parent 0919b08ec8
commit 7806cd0f64
5 changed files with 456 additions and 0 deletions

configs/datasets/subjective/alpaca_eval/alpacav1_judgeby_gpt4.py
@@ -0,0 +1,98 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import SubjectiveCmpDataset
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['question'],
output_column='judge',
)
subjective_all_sets = [
"alpaca_eval",
]
subjective_datasets = []
gpt4_prompt = """
I want you to create a leaderboard of different large-language models. To do so, I will give you the instructions (prompts) given to the models, and the responses of two models. Please rank the models based on which responses would be preferred by humans. All inputs and outputs should be Python dictionaries.
Here is the prompt:
{
"instruction": "{question}"
}
Here are the outputs of the models:
[
{
"model": "model_1",
"answer": "{prediction}"
},
{
"model": "model_2",
"answer": "{prediction2}"
}
]
Now please rank the models by the quality of their answers, so that the model with rank 1 has the best output. Then return a list of the model names and ranks, i.e., produce the following output:
[
{"model": <model-name>, "rank": <model-rank>},
{"model": <model-name>, "rank": <model-rank>}
]
Your response must be a valid Python list of dictionaries and should contain nothing else because we will directly execute it in Python. Please provide the ranking that the majority of humans would give.
"""
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
infer_order='random',
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a helpful assistant, that ranks models by the quality of their answers.")
],
round=[
dict(
role='HUMAN',
prompt = gpt4_prompt
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=SubjectiveCmpDataset,
path="./data/subjective/",
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))
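
The prompt above asks the judge for a Python-style list of model/rank dictionaries. As a rough illustration (a sketch, not one of the committed files), such a reply can be parsed the same way post_process_alpacav1 further down in this PR does it:

# Sketch: extracting model_1's rank from a well-formed v1 judge reply.
import ast
import re

reply = "[{'model': 'model_1', 'rank': 1}, {'model': 'model_2', 'rank': 2}]"
ranking = ast.literal_eval(re.findall(r'\[.*?\]', reply)[0])
rank_of_model_1 = [c for c in ranking if c['model'] == 'model_1'][0]['rank']
print(rank_of_model_1)  # -> 1, i.e. model_1 was preferred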

configs/datasets/subjective/alpaca_eval/alpacav2_judgeby_gpt4.py
@@ -0,0 +1,100 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import SubjectiveCmpDataset
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['question'],
output_column='judge',
)
subjective_all_sets = [
"alpaca_eval",
]
subjective_datasets = []
gpt4_prompt = """
I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.
## Instruction
{
"instruction": "{question}",
}
## Model Outputs
Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.
{
{
"model_identifier": "m",
"output": "{prediction}"
},
{
"model_identifier": "M",
"output": "{prediction2}"
}
}
## Task
Evaluate the models based on the quality and relevance of their outputs, and select the model that generated the best output. Answer by providing the model identifier of the best model. We will use your output as the name of the best model, so make sure your output only contains one of the following model identifiers and nothing else (no quotes, no spaces, no new lines, ...): m or M.
## Best Model Identifier
"""
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
infer_order='random',
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt="You are a highly efficient assistant, who evaluates and selects the best large language model (LLMs) based on the quality of their responses to a given instruction. This process will be used to create a leaderboard reflecting the most accurate and human-preferred answers.")
],
round=[
dict(
role='HUMAN',
prompt = gpt4_prompt
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=SubjectiveCmpDataset,
path="./data/subjective/",
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))

@@ -0,0 +1,82 @@
from mmengine.config import read_base
with read_base():
from .models.qwen.hf_qwen_7b_chat import models as hf_qwen_7b_chat
from .models.qwen.hf_qwen_14b_chat import models as hf_qwen_14b_chat
from .models.chatglm.hf_chatglm3_6b import models as hf_chatglm3_6b
from .models.baichuan.hf_baichuan2_7b_chat import models as hf_baichuan2_7b
from .models.hf_internlm.hf_internlm_chat_7b import models as hf_internlm_chat_7b
from .models.hf_internlm.hf_internlm_chat_20b import models as hf_internlm_chat_20b
from .datasets.subjective.alpaca_eval.alpacav1_judgeby_gpt4 import subjective_datasets as alpacav1
from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import subjective_datasets as alpacav2
datasets = [*alpacav2]
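# To judge with the AlpacaEval v1 ranking prompt instead, switch this to
# [*alpacav1] and set judge_type='v1' in the summarizer at the bottom of this file.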
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3
from opencompass.models.openai_api import OpenAI, OpenAIAllesAPIN
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import AlpacaSummarizer
models = [*hf_qwen_7b_chat, *hf_chatglm3_6b]
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True)
],
reserved_roles=[
dict(role='SYSTEM', api_role='SYSTEM'),
],
)
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='auto',
max_num_workers=256,
task=dict(type=OpenICLInferTask)),
)
judge_model = dict(
abbr='GPT4-Turbo',
type=OpenAI, path='gpt-4-1106-preview',
key='', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=1024,
max_seq_len=4096,
batch_size=2,
retry=20,
temperature = 0
)
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner,
max_task_size=1000,
mode='m2n',
base_models = [*hf_chatglm3_6b],
compare_models = [*hf_qwen_7b_chat]
),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='auto',
max_num_workers=256,
task=dict(
type=SubjectiveEvalTask,
judge_cfg=judge_model
)),
)
work_dir = 'outputs/alpaca/'
summarizer = dict(
type=AlpacaSummarizer, judge_type='v2'
)
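
The runners above target a Slurm cluster (partition='llmeval'). LocalRunner is already imported in this config, so a single-machine variant would look roughly like the following sketch (not part of the commit; the max_num_workers values are assumptions to tune to the available hardware, and the fragment reuses the imports and judge_model defined above):

# Sketch: local (non-Slurm) runners for the same inference and judging tasks.
infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,  # assumed; roughly the number of usable GPUs
        task=dict(type=OpenICLInferTask)),
)
eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner,
        max_task_size=1000,
        mode='m2n',
        base_models=[*hf_chatglm3_6b],
        compare_models=[*hf_qwen_7b_chat]),
    runner=dict(
        type=LocalRunner,
        max_num_workers=2,  # judging goes through the OpenAI API, keep this small
        task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
)

Either way, the evaluation is launched through OpenCompass's usual run.py entry point with this config.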

opencompass/summarizers/subjective/__init__.py
@@ -1,5 +1,6 @@
# flake8: noqa: F401, E501
from .alignmentbench import AlignmentBenchSummarizer
from .alpacaeval import AlpacaSummarizer
from .compass_arena import CompassArenaSummarizer
from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer

opencompass/summarizers/subjective/alpacaeval.py
@@ -0,0 +1,175 @@
# flake8: noqa: E501
import ast
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
from itertools import product
import mmengine
from mmengine import ConfigDict
from prettytable import from_csv
from opencompass.partitioners.sub_naive import remove_duplicate_pairs
from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
from .utils import get_judgeanswer_and_reference, get_outdir
def post_process_alpacav1(completion: str):
r"""Parse a completion that contains a list of dictionary and returns the rank of the model1.
Examples
--------
>>> ranking_parser("[{'model': 'model_1', 'rank': 1}, {'model': 'model_2', 'rank': 2}]")
1
>>> ranking_parser("[{'model': 'model_1', 'rank': 2}, {'model': 'model_2', 'rank': 1}]")
2
>>> ranking_parser("[{'model': 'model_1', 'rank': 3}, {'model': 'model_2', 'rank': 1}]")
None
"""
try:
if isinstance(completion, str):
completion = re.findall(r'\[.*?\]', completion)[0]
ordered_completions = ast.literal_eval(completion)
else:
ordered_completions = completion
rank = [c for c in ordered_completions
if c['model'] == 'model_1'][0]['rank']
if rank in [1, 2]:
return {'rank': rank}
else:
return None
    except Exception:
return None
def post_process_alpacav2(completion: str):
r"""Parse a completion that contains 'm' or 'M' and returns the rank of the model1.
Examples
--------
>>> ranking_parser("m")
1
>>> ranking_parser("M")
2
>>> ranking_parser("s")
None
"""
try:
if completion[0] == 'm':
return {'rank': 1}
elif completion[0] == 'M':
return {'rank': 2}
else:
return None
    except Exception:
return None
class AlpacaSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self, config: ConfigDict, judge_type='v2') -> None:
self.tasks = []
self.cfg = config
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
self.judge_type = judge_type
assert self.judge_type in ['v1', 'v2']
self.judge_map = {
'v1': post_process_alpacav1,
'v2': post_process_alpacav2
}
self.judge_function = self.judge_map[self.judge_type]
def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
        Returns:
            None: the pairwise win-rate report is written to a CSV file under
            the output directory and printed to stdout.
"""
dataset_cfgs = self.cfg['datasets']
        output_dir, results_folder = get_outdir(self.cfg, time_str)
        # All pairwise results judged by this judge model go into one CSV report.
        fout = osp.join(output_dir,
                        'judged-by--' + self.judge_abbr + '-report.csv')
model_combinations = list(
product(self.base_models, self.compare_models))
unique_combinations = remove_duplicate_pairs(
[combo for combo in model_combinations if combo[0] != combo[1]])
for model_pair in unique_combinations:
model1, model2, judge_model = model_pair[0]['abbr'], model_pair[1][
'abbr'], self.judge_abbr
subdir = model1 + '_' + model2 + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
            if os.path.isdir(subdir_path):
for dataset in dataset_cfgs:
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, self.judge_function)
win_model1, win_model2, categories = defaultdict(
float), defaultdict(float), defaultdict(float)
model1, model2 = references[0]['answer1'], references[0][
'answer2']
for prediction, reference in zip(judged_answers,
references):
categories['total'] += 1
categories[reference['capability']] += 1
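                        # rank 1 means the judge preferred the answer that was shown
                        # first ({prediction}); because infer_order='random' may swap
                        # the two answers, reference['answer1'] records which model
                        # that actually was.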
if prediction['rank'] == 1:
if reference['answer1'] == model1:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
else:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
else:
if reference['answer1'] == model1:
win_model2[reference['capability']] += 1
win_model2['total'] += 1
else:
win_model1[reference['capability']] += 1
win_model1['total'] += 1
for capability in categories:
if capability not in win_model1:
win_model1[capability] = 0.0
else:
win_model1[capability] = round(
(win_model1[capability] /
categories[capability]) * 100, 2)
if capability not in win_model2:
win_model2[capability] = 0.0
else:
win_model2[capability] = round(
(win_model2[capability] /
categories[capability]) * 100, 2)
scores = {
'win_' + model1: win_model1,
'win_' + model2: win_model2
}
rows = list(scores.keys())
columns = list(scores[rows[0]].keys())
columns.insert(0, columns.pop(columns.index('total')))
with open(fout, 'a+', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([model1 + '_vs_' + model2] + columns)
for row in rows:
writer.writerow(
[row] +
[scores[row][column] for column in columns])
            else:
                print(subdir_path + ' does not exist, please check!')
        # Print the aggregated report only if at least one model pair was judged.
        if osp.exists(fout):
            with open(fout, 'r') as f:
                x = from_csv(f)
            print(x)
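
As a quick usage sketch (not part of the commit), the two parsers above can be exercised directly with the inputs from their docstrings, assuming the module is importable as opencompass.summarizers.subjective.alpacaeval:

# Sketch: sanity-checking the judge-output parsers with the docstring examples.
from opencompass.summarizers.subjective.alpacaeval import (
    post_process_alpacav1, post_process_alpacav2)

assert post_process_alpacav1(
    "[{'model': 'model_1', 'rank': 2}, {'model': 'model_2', 'rank': 1}]") == {'rank': 2}
assert post_process_alpacav1(
    "[{'model': 'model_1', 'rank': 3}, {'model': 'model_2', 'rank': 1}]") is None
assert post_process_alpacav2('m') == {'rank': 1}
assert post_process_alpacav2('s') is None
print('judge-output parsers behave as documented')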