[Feature] support arenahard evaluation (#1096)

* support arenahard

* support arenahard

* support arenahard
bittersweet1999 2024-04-26 15:42:00 +08:00 committed by GitHub
parent 6ba1c4937d
commit e404b72c52
11 changed files with 572 additions and 4 deletions


@@ -70,6 +70,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2024.04.26\]** We supported the evaluation of [ArenaHard](configs/eval_subjective_arena_hard.py), welcome to try! 🔥🔥🔥
- **\[2024.04.22\]** We supported the evaluation of [LLaMA3](configs/models/hf_llama/hf_llama3_8b.py) and [LLaMA3-Instruct](configs/models/hf_llama/hf_llama3_8b_instruct.py), welcome to try! 🔥🔥🔥
- **\[2024.02.29\]** We supported MT-Bench, AlpacaEval and AlignBench; more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/subjective_evaluation.html)
- **\[2024.01.30\]** We release OpenCompass 2.0. Click [CompassKit](https://github.com/open-compass), [CompassHub](https://hub.opencompass.org.cn/home), and [CompassRank](https://rank.opencompass.org.cn/home) for more information!


@@ -69,6 +69,7 @@
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>
- **\[2024.04.26\]** We supported the evaluation of [ArenaHard](configs/eval_subjective_arena_hard.py), welcome to try! 🔥🔥🔥
- **\[2024.04.22\]** We supported the evaluation of [LLaMA3](configs/models/hf_llama/hf_llama3_8b.py) and [LLaMA3-Instruct](configs/models/hf_llama/hf_llama3_8b_instruct.py), welcome to try! 🔥🔥🔥
- **\[2024.02.29\]** We supported MT-Bench, AlpacaEval and AlignBench; more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/subjective_evaluation.html).
- **\[2024.01.30\]** We released OpenCompass 2.0. For more information, please visit [CompassKit](https://github.com/open-compass), [CompassHub](https://hub.opencompass.org.cn/home), and [CompassRank](https://rank.opencompass.org.cn/home).


@@ -0,0 +1,40 @@
# ArenaHard
## Introduction
The following introduction comes from the official repo:
Arena-Hard is an evaluation tool for instruction-tuned LLMs. It contains 500 challenging user queries, and GPT-4-Turbo is prompted as the judge to compare each model's responses against a baseline model (default: GPT-4-0314).
## Official link
https://github.com/lm-sys/arena-hard
### Paper
https://lmsys.org/blog/2024-04-19-arena-hard/
## Examples
Input example I:
```
Use ABC notation to write a melody in the style of a folk tune.
```
Output example I (from GPT-4):
```
X:1\nT:Untitled Folk Tune\nM:4/4\nL:1/8\nK:G\n|:G2A2|B2A2|G2E2|D4|E2F2|G2F2|E2C2|B,4|\nA2B2|c2B2|A2F2|E4|D2E2|F2E2|D2B,2|C4:|
```
## Evaluation results
```
LLaMa3-8b-instruct: 20.6 (Official Results)
LLaMa3-8b-instruct: 21.9 (OpenCompass Results)
```
## Reference
```
@misc{arenahard2024,
title = {From Live Data to High-Quality Benchmarks: The Arena-Hard Pipeline},
url = {https://lmsys.org/blog/2024-04-19-arena-hard/},
author = {Tianle Li*, Wei-Lin Chiang*, Evan Frick, Lisa Dunlap, Banghua Zhu, Joseph E. Gonzalez, Ion Stoica},
month = {April},
year = {2024}
}
```
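To make the judging protocol concrete, the sketch below shows how a judge reply is typically reduced to a battle record (the label-to-weight mapping mirrors `get_battles_from_judgment` in the ArenaHard summarizer added by this commit; the sample reply and helper names are purely illustrative):
```python
import re

# Verdict label -> (winner, battle weight); '>>' verdicts are counted three
# times, matching WEIGHT=3 in the ArenaHard summarizer in this commit.
VERDICT_MAP = {
    'A>>B': ('model_a', 3),
    'A>B':  ('model_a', 1),
    'A=B':  ('tie',     1),
    'B>A':  ('model_b', 1),
    'B>>A': ('model_b', 3),
}

def parse_verdict(judge_reply: str):
    """Pull the [[...]] verdict label out of a judge reply, if present."""
    found = re.findall(r'\[\[([AB<>=]+)\]\]', judge_reply)
    return found[0] if found else None

reply = 'My final verdict is that Assistant A is slightly better: [[A>B]]'
print(VERDICT_MAP.get(parse_verdict(reply)))  # ('model_a', 1)
```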


@@ -0,0 +1,72 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import ArenaHardDataset
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['question'],
output_column='judge',
)
subjective_all_sets = [
"question",
]
subjective_datasets = []
system_prompt = "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate which assistant's answer is better.\n\nBegin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any answers.\n\nWhen evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information.\n\nThen consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or excessive.\n\nThen consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt.\n\nAfter providing your explanation, you must output only one of the following choices as your final verdict with a label:\n\n1. Assistant A is significantly better: [[A>>B]]\n2. Assistant A is slightly better: [[A>B]]\n3. Tie, relatively the same: [[A=B]]\n4. Assistant B is slightly better: [[B>A]]\n5. Assistant B is significantly better: [[B>>A]]\n\nExample output: \"My final verdict is tie: [[A=B]]\"."
judge_prompt = "<|User Prompt|>\n{question}\n\n<|The Start of Assistant A's Answer|>\n{prediction}\n<|The End of Assistant A's Answer|>\n\n<|The Start of Assistant B's Answer|>\n{prediction2}\n<|The End of Assistant B's Answer|>"
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt="{question}"
),
]),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
begin=[
dict(
role='SYSTEM',
fallback_role='HUMAN',
prompt=system_prompt)
],
round=[
dict(
role='HUMAN',
prompt = judge_prompt
),
]),
),
),
pred_role="BOT",
)
subjective_datasets.append(
dict(
abbr=f"{_name}",
type=ArenaHardDataset,
path="./data/subjective/arena_hard",
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg
))
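For orientation (this is not part of the committed config), the judge input that `LMEvaluator` assembles from the template above looks roughly like the sketch below; the `str.format` call only approximates the substitution OpenCompass performs internally, the concrete answers are placeholders, and with `infer_order='double'` in the accompanying eval config each question is judged with the two answers in both A/B orders.
```python
# A minimal sketch of the filled judge turn: {question} is the Arena-Hard
# query, {prediction} and {prediction2} are the two answers being compared
# (one of them is the gpt4-0314 baseline supplied via `given_pred`).
judge_turn = (
    "<|User Prompt|>\n{question}\n\n"
    "<|The Start of Assistant A's Answer|>\n{prediction}\n<|The End of Assistant A's Answer|>\n\n"
    "<|The Start of Assistant B's Answer|>\n{prediction2}\n<|The End of Assistant B's Answer|>"
).format(
    question="Use ABC notation to write a melody in the style of a folk tune.",
    prediction="X:1 ...",   # answer of the model under test (placeholder)
    prediction2="X:1 ...",  # baseline answer (placeholder)
)
print(judge_turn)
```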


@@ -0,0 +1,104 @@
from opencompass.models import HuggingFaceCausalLM
from copy import deepcopy
from opencompass.models import TurboMindModel
from mmengine.config import read_base
from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import ArenaHardSummarizer
with read_base():
from .datasets.subjective.arena_hard.arena_hard_scoring import subjective_datasets
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
_meta_template = dict(
round=[
dict(role="HUMAN", begin="<|begin_of_text|>user<|end_header_id|>\n\n", end="<|eot_id|>"),
dict(role="BOT", begin="<|begin_of_text|>assistant<|end_header_id|>\n\n", end="<|eot_id|>", generate=True),
],
)
models = [
dict(
type=HuggingFaceCausalLM,
abbr="llama-3-8b-instruct-hf",
path="meta-llama/Meta-Llama-3-8B-Instruct",
model_kwargs=dict(device_map="auto"),
tokenizer_kwargs=dict(
padding_side="left",
truncation_side="left",
use_fast=False,
),
meta_template=_meta_template,
max_out_len=4096,
max_seq_len=2048,
batch_size=8,
run_cfg=dict(num_gpus=1, num_procs=1),
generation_kwargs={"eos_token_id": [128001, 128009]},
batch_padding=True,
)
]
datasets = [*subjective_datasets]
work_dir = 'outputs/arena_hard/'
# ------------- Inference Stage ----------------------------------------
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000000),
runner=dict(
type=LocalRunner,
max_num_workers=32,
task=dict(type=OpenICLInferTask)),
)
judge_models = [dict(
abbr='GPT4-Turbo',
type=OpenAI,
path='gpt-4-1106-preview',
key='',
meta_template=api_meta_template,
query_per_second=1,
max_out_len=1024,
max_seq_len=4096,
batch_size=10,
retry=10,
temperature = 0,
)]
## ------------- Evaluation Configuration
gpt4_0314 = dict(
abbr='gpt4-0314',
type=OpenAI,
)
eval = dict(
partitioner=dict(
type=SubjectiveSizePartitioner,
max_task_size=1000000,
mode='m2n',
infer_order='double',
base_models=[gpt4_0314],
compare_models=models,
judge_models=judge_models,
),
runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
given_pred = [{'abbr':'gpt4-0314', 'path':''}]
)
summarizer = dict(
type=ArenaHardSummarizer
)
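Two fields in this config must be filled in locally: the judge model's `key` and the `path` in `given_pred`, which is expected to point to pre-generated gpt4-0314 answers. As a rough sketch of the pairings the `m2n` partitioner with `infer_order='double'` produces (mirroring the `itertools.product` pairing used by the summarizer; model names are the abbreviations from this config, and this is an assumed illustration rather than the partitioner's actual code):
```python
from itertools import product

# Every base model is paired with every compare model; each pair is then
# judged twice, once in each A/B order, to reduce the judge's position bias.
base_models = ['gpt4-0314']
compare_models = ['llama-3-8b-instruct-hf']

for base, compare in product(base_models, compare_models):
    if base == compare:
        continue
    print(f'judgement 1: A={compare}, B={base}')
    print(f'judgement 2: A={base}, B={compare}  # second order from infer_order="double"')
```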


@@ -15,10 +15,11 @@ We support the use of GPT-4 (or other JudgeLLM) for the subjective evaluation of
## Currently Supported Subjective Evaluation Datasets
1. AlignBench (https://github.com/THUDM/AlignBench)
2. MTBench (https://github.com/lm-sys/FastChat)
3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
4. ArenaHard (https://github.com/lm-sys/arena-hard/tree/main)
5. CompassArena (Internal dataset)
## Subjective Evaluation with Custom Dataset


@@ -15,10 +15,11 @@
## Currently Supported Subjective Evaluation Datasets
1. AlignBench (https://github.com/THUDM/AlignBench)
2. MTBench (https://github.com/lm-sys/FastChat)
3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
4. ArenaHard (https://github.com/lm-sys/arena-hard/tree/main)
5. CompassArena (internal dataset)
## Subjective Evaluation with Custom Dataset


@@ -1,4 +1,5 @@
from .alignbench import AlignmentBenchDataset  # noqa: F401, F403
from .arena_hard import ArenaHardDataset  # noqa: F401, F403
from .compass_arena import CompassArenaDataset  # noqa: F401, F403
from .corev2 import Corev2Dataset  # noqa: F401, F403
from .creationbench import CreationBenchDataset  # noqa: F401, F403


@@ -0,0 +1,35 @@
import json
import os.path as osp
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class ArenaHardDataset(BaseDataset):
def load(self, path: str, name: str):
filename = osp.join(path, f'{name}.jsonl')
dataset = DatasetDict()
raw_data = []
with open(filename, 'r', encoding='utf-8') as file:
for line in file:
problem = json.loads(line)
question_id = problem['question_id']
cluster = problem['cluster']
question = problem['turns'][0][
'content'] # only one turn in arena_hard
raw_data.append({
'question': question,
'capability': cluster,
'judge': {
'capability': cluster,
'question': question,
'question_id': question_id
}
})
dataset = Dataset.from_list(raw_data)
return dataset
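A minimal sketch of the JSONL schema this loader expects (the field names come straight from `load` above; the concrete values are made up):
```python
import json

# One line of ./data/subjective/arena_hard/question.jsonl, as consumed by
# ArenaHardDataset.load. Arena-Hard prompts are single-turn, so only
# turns[0]['content'] is read.
line = json.dumps({
    'question_id': 'q_0001',        # illustrative id
    'cluster': 'ABC notation',      # illustrative cluster name
    'turns': [{'content': 'Use ABC notation to write a melody in the style of a folk tune.'}],
})

problem = json.loads(line)
row = {
    'question': problem['turns'][0]['content'],
    'capability': problem['cluster'],
    'judge': {
        'capability': problem['cluster'],
        'question': problem['turns'][0]['content'],
        'question_id': problem['question_id'],
    },
}
print(row['question'])
```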


@@ -2,6 +2,7 @@
from .alignmentbench import AlignmentBenchSummarizer
from .all_obj import AllObjSummarizer
from .alpacaeval import AlpacaSummarizer
from .arenahard import ArenaHardSummarizer
from .compass_arena import CompassArenaSummarizer
from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer


@@ -0,0 +1,311 @@
# flake8: noqa
# yapf: disable
import argparse
import datetime
import json
import math
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime
from glob import glob
from itertools import product
import mmengine
import numpy as np
#import plotly.express as px
import pandas as pd
import tiktoken
from mmengine import ConfigDict
from sklearn.linear_model import LogisticRegression
from tabulate import tabulate
from tqdm import tqdm
from opencompass.partitioners.sub_naive import remove_duplicate_pairs
from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg
from .utils import get_outdir
def compute_mle_elo(df, SCALE=400, BASE=10, INIT_RATING=1000):
models = pd.concat([df['model_a'], df['model_b']]).unique()
models = pd.Series(np.arange(len(models)), index=models)
# duplicate battles
df = pd.concat([df, df], ignore_index=True)
p = len(models.index)
n = df.shape[0]
X = np.zeros([n, p])
X[np.arange(n), models[df['model_a']]] = +math.log(BASE)
X[np.arange(n), models[df['model_b']]] = -math.log(BASE)
# one A win => two A win
Y = np.zeros(n)
Y[df['winner'] == 'model_a'] = 1.0
# one tie => one A win + one B win
# find tie + tie (both bad) index
tie_idx = (df['winner'] == 'tie') | (df['winner'] == 'tie (bothbad)')
tie_idx[len(tie_idx)//2:] = False
Y[tie_idx] = 1.0
lr = LogisticRegression(fit_intercept=False, penalty=None, tol=1e-8)
lr.fit(X,Y)
elo_scores = SCALE * lr.coef_[0] + INIT_RATING
# set anchor as gpt4-0314 = 1000
if 'gpt4-0314' in models.index:
elo_scores += 1000 - elo_scores[models['gpt4-0314']]
return pd.Series(elo_scores, index = models.index).sort_values(ascending=False)
def get_bootstrap_result(battles, func_compute_elo, num_round):
rows = []
for i in tqdm(range(num_round), desc='bootstrap'):
rows.append(func_compute_elo(battles.sample(frac=1.0, replace=True)))
df = pd.DataFrame(rows)
return df[df.median().sort_values(ascending=False).index]
def preety_print_two_ratings(ratings_1, ratings_2, column_names):
df = pd.DataFrame([
[n, ratings_1[n], ratings_2[n]] for n in ratings_1.keys()
], columns=['Model', column_names[0], column_names[1]]).sort_values(column_names[0], ascending=False).reset_index(drop=True)
df[column_names[0]] = (df[column_names[0]] + 0.5).astype(int)
df[column_names[1]] = (df[column_names[1]] + 0.5).astype(int)
df.index = df.index + 1
return df
def visualize_bootstrap_scores(df, title):
bars = pd.DataFrame(dict(
lower = df.quantile(.025),
rating = df.quantile(.5),
upper = df.quantile(.975))).reset_index(names='model').sort_values('rating', ascending=False)
bars['error_y'] = bars['upper'] - bars['rating']
bars['error_y_minus'] = bars['rating'] - bars['lower']
bars['rating_rounded'] = np.round(bars['rating'], 2)
fig = px.scatter(bars, x='model', y='rating', error_y='error_y',
error_y_minus='error_y_minus', text='rating_rounded',
title=title)
fig.update_layout(xaxis_title='Model', yaxis_title='Rating',
height=600)
return fig
def predict_win_rate(elo_ratings, SCALE=400, BASE=10, INIT_RATING=1000):
names = sorted(list(elo_ratings.keys()))
wins = defaultdict(lambda: defaultdict(lambda: 0))
for a in names:
for b in names:
ea = 1 / (1 + BASE ** ((elo_ratings[b] - elo_ratings[a]) / SCALE))
wins[a][b] = ea
wins[b][a] = 1 - ea
data = {
a: [wins[a][b] if a != b else np.NAN for b in names]
for a in names
}
df = pd.DataFrame(data, index=names)
df.index.name = 'model_a'
df.columns.name = 'model_b'
return df.T
def model_abbr_from_cfg_used_in_summarizer(model):
if model.get('summarizer_abbr', None):
return model['summarizer_abbr']
else:
return model_abbr_from_cfg(model)
def post_process_compass_arena(s):
if result := re.findall('\[\[([AB<>=]+)\]\]', s):
return result[0]
else:
return None
def get_win_rate_column(df, column, baseline='gpt4-0314'):
to_dict = df[['model', column]].set_index('model').to_dict()[column]
win_rate_table = predict_win_rate(to_dict)
return win_rate_table[baseline].fillna(0.5).apply(lambda x: round(x * 100, 2))
def get_battles_from_judgment(dataset, subdir_path, post_process, first_game_only=False, WEIGHT=3):
arena_hard_battles = pd.DataFrame()
print('Turning judgment results into battles...')
dataset_abbr = dataset_abbr_from_cfg(dataset)
filename = osp.join(subdir_path, dataset_abbr + '.json')
partial_filename = osp.join(subdir_path, dataset_abbr + '_0.json')
if osp.exists(osp.realpath(filename)):
result = mmengine.load(filename)
elif osp.exists(osp.realpath(partial_filename)):
filename = partial_filename
result = {}
i = 1
partial_dict_flag = 0
while osp.exists(osp.realpath(filename)):
res = mmengine.load(filename)
for k, v in res.items():
result[partial_dict_flag] = v
partial_dict_flag += 1
filename = osp.join(subdir_path,
dataset_abbr + '_' + str(i) + '.json')
i += 1
else:
result = {}
if len(result) == 0:
print('*' * 100)
print('There are no results for ' + filename + ' or ' +
partial_filename)
print('*' * 100)
assert len(result) > 0
judged_answers = []
references = []
for k, v in result.items():
output = {
'model_a': v['gold']['answer1'],
'model_b': v['gold']['answer2']}
processed_judge = post_process(v['prediction'])
if processed_judge is not None:
weight = 1
if processed_judge == 'A=B':
output['winner'] = 'tie'
elif processed_judge == 'A>B':
output['winner'] = 'model_a'
elif processed_judge == 'A>>B':
output['winner'] = 'model_a'
weight = WEIGHT
elif processed_judge == 'B>A':
output['winner'] = 'model_b'
elif processed_judge == 'B>>A':
output['winner'] = 'model_b'
weight = WEIGHT
else:
weight = 0
else:
weight = 0
if weight:
arena_hard_battles = pd.concat([arena_hard_battles, pd.DataFrame([output] * weight)])
arena_hard_battles.to_json(os.path.join(subdir_path,'arena_hard_battles.jsonl'), lines=True, orient='records')
return arena_hard_battles
class ArenaHardSummarizer:
"""Do the subjectivity analyze based on evaluation results.
Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""
def __init__(self,
config: ConfigDict,
judge_type='general',
check_pos_bias=True,
summary_type='single') -> None:
self.tasks = []
self.cfg = config
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner']['compare_models']
self.judge_models = self.cfg.get('judge_models', None)
self.meta_judge_model = self.cfg.eval.partitioner.get('meta_judge_model', None)
self.judge_type = judge_type
assert self.judge_type in ['general']
self.judge_map = {'general': post_process_compass_arena}
self.judge_function = self.judge_map[self.judge_type]
self.check_pos_bias = check_pos_bias
self.summary_type = summary_type
def get_score(self, time_str):
output_dir, results_folder = get_outdir(self.cfg, time_str)
model_combinations = list(product(self.base_models, self.compare_models))
unique_combinations = remove_duplicate_pairs([combo for combo in model_combinations if combo[0] != combo[1]])
if self.meta_judge_model is not None:
self.judge_models.append(self.meta_judge_model)
scores = {}
for idx, judge_model_cfg in enumerate(self.judge_models):
judge_model = model_abbr_from_cfg(judge_model_cfg)
for dataset in self.cfg['datasets']:
dataset_abbr = dataset_abbr_from_cfg(dataset)
for model_pair in unique_combinations:
model1 = model_pair[0]['abbr']
model2 = model_pair[1]['abbr']
if idx == len(self.judge_models):
subdir = model1 + '_' + model2 + '_summarized-by--' + judge_model
else:
subdir = model1 + '_' + model2 + '_judged-by--' + judge_model
subdir_path = os.path.join(results_folder, subdir)
if not os.path.isdir(subdir_path):
                        print(subdir_path + ' does not exist! Please check!')
continue
battles = get_battles_from_judgment(dataset, subdir_path, self.judge_function)
bootstrap_online_elo = compute_mle_elo(battles)
np.random.seed(42)
bootstrap_elo_lu = get_bootstrap_result(battles, compute_mle_elo, 100)
bootstrap_elo_lu.to_json(os.path.join(subdir_path,'bootstrapping_results.jsonl'), lines=True, orient='records')
stats = pd.DataFrame()
stats['results'] = None
stats['results'] = stats['results'].astype('object')
for i, model in enumerate(bootstrap_online_elo.index):
assert model in bootstrap_elo_lu.columns
stats.at[i, 'model'] = model
stats.at[i, 'score'] = bootstrap_online_elo[model]
stats.at[i, 'lower'] = np.percentile(bootstrap_elo_lu[model], 2.5)
stats.at[i, 'upper'] = np.percentile(bootstrap_elo_lu[model], 97.5)
if model == 'gpt4-0314':
stats.at[i, 'avg_tokens'] = 423
else:
with open(os.path.join(output_dir.split('summary')[0], 'predictions', model, dataset_abbr+'.json'), 'r') as f:
model_preds = json.load(f)
pred_length = 0
for k, v in model_preds.items():
pred_length += len(tiktoken.encoding_for_model('gpt-3.5-turbo').encode(v['prediction']))
pred_length /= len(model_preds)
stats.at[i, 'avg_tokens'] = pred_length
stats.at[i, 'results'] = bootstrap_elo_lu[model].tolist()
stats.sort_values(by='model', inplace=True)
stats['score'] = get_win_rate_column(stats, 'score', 'gpt4-0314').tolist()
stats['lower'] = get_win_rate_column(stats, 'lower', 'gpt4-0314').tolist()
stats['upper'] = get_win_rate_column(stats, 'upper', 'gpt4-0314').tolist()
decimal = 1
stats.sort_values(by='score', ascending=False, inplace=True)
for _, row in stats.iterrows():
interval = str((round(row['lower'] - row['score'], decimal), round(row['upper'] - row['score'], decimal)))
print(f"{row['model'] : <30} | score: {round(row['score'], decimal) : ^5} | 95% CI: {interval : ^12} | average #tokens: {int(row['avg_tokens'])}")
stats.to_json(os.path.join(output_dir,'arena_hard_leaderboard.json'), orient='records', indent=4)
def summarize(
self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S'),
):
"""Summarize the subjectivity analysis based on evaluation results.
Args:
time_str (str): Timestamp for file naming.
Returns:
pd.DataFrame: The summary results.
"""
self.get_score(time_str)
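To illustrate how the final leaderboard score relates to the fitted ratings, here is a worked sketch of the conversion performed by `predict_win_rate` and `get_win_rate_column` above: the reported score is the predicted win rate (in percent) of a model against the `gpt4-0314` anchor. The ratings below are illustrative, not real outputs.
```python
SCALE, BASE = 400, 10

# Illustrative Bradley-Terry (Elo-style) ratings; gpt4-0314 is anchored at 1000.
elo = {'gpt4-0314': 1000.0, 'some-model': 760.0}

def win_rate(model: str, baseline: str = 'gpt4-0314') -> float:
    """Predicted probability that `model` beats `baseline`, as in predict_win_rate."""
    return 1 / (1 + BASE ** ((elo[baseline] - elo[model]) / SCALE))

score = round(win_rate('some-model') * 100, 2)
print(score)  # 20.08 with these illustrative ratings
```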