[Feature] Added CompassArena-SubjectiveBench with Bradley-Terry Model (#1751)

* fix lint issues

* updated gitignore

* changed infer_order from random to double for the pairwise_judge.py (not changing for pairwise_bt_judge.py

* added return statement to CompassArenaBradleyTerrySummarizer to return overall score for each judger model
This commit is contained in:
Alexander Lam 2024-12-16 13:41:28 +08:00 committed by GitHub
parent aeded4c4db
commit 1bd594fc62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 2360 additions and 178 deletions

View File

@ -0,0 +1,169 @@
# CompassArena-SubjectiveBench (Pairwise Eval with Bradley-Terry Model)
## Introduction
The following introduction comes from the abstract of [Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference](https://arxiv.org/abs/2403.04132):
>Large Language Models (LLMs) have unlocked new capabilities and applications; however, evaluating the alignment with human preferences still poses significant challenges. To address this issue, we introduce Chatbot Arena, an open platform for evaluating LLMs based on human preferences. Our methodology employs a pairwise comparison approach and leverages input from a diverse user base through crowdsourcing. The platform has been operational for several months, amassing over 240K votes. This paper describes the platform, analyzes the data we have collected so far, and explains the tried-and-true statistical methods we are using for efficient and accurate evaluation and ranking of models. We confirm that the crowdsourced questions are sufficiently diverse and discriminating and that the crowdsourced human votes are in good agreement with those of expert raters. These analyses collectively establish a robust foundation for the credibility of Chatbot Arena. Because of its unique value and openness, Chatbot Arena has emerged as one of the most referenced LLM leaderboards, widely cited by leading LLM developers and companies.
For this dataset, we adapt the Bradley-Terry rating system from FastChat to the subjective evaluation setting, but replacing human evaluators with LLM-as-a-judge.
## Official Links
- Paper: [Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference](https://arxiv.org/abs/2403.04132)
- GitHub Repository: [FastChat](https://github.com/lm-sys/FastChat/tree/main)
## Overview and Usage
### Inference
During the inference stage, each LLM makes an inference based on the question presented (single question for single turn and an entire conversation for multi-turn).
### Evaluation
During the evaluation stage, the judge model respond with a critique and chooses the LLM with a better answer for each pair. This preference will be used later to form the "winner" response variable in the postprocessor. Note that the predictions for each model must be saved (by setting `keep_predictions=True` in the evaluator config) in order for the postporcessor to calculate style features. See this [example](`opencompass/configs/datasets/subjective/compass_arena_subjective_bench/singleturn/pairwise_bt_judge.py`) for more details.
#### Postprocessor
After evaluation by the judge model, we gather the pairwise matchups and any additional group variables (e.g. difficulty, category) in the postprocessor. Note that the LLM predictions ("prediction1" and "prediction2") must be passed on from the inference stage, otherwise, an error will be thrown.
### Summary
After inference by the judge model in the evaluation stage, we fit a Bradley-Terry model (statistical model) in order to estimate the rating and ranking of each LLM with an option to include style features and control variables on groups. The settings below control specification of the BT model as well as how results are being reported:
- `rating_system`: The rating system used. Currently only supports "bradleyterry".
- `num_bootstrap`: The number of bootstraps for estimating the confidence intervals of ratings.
- `with_control_vars`: Whether to include additional covariates (including style features and group variables) when fitting the BT model.
- `normalize_style_features`: Whether to normalize style features BEFORE fitting the BT model (implementation by FastChat). Turn this off for easier interpretation of odds ratios (when `odds_ratio==True`).
- `odds_ratio`: Whether to report odds ratios ($e^{\beta_i}$) instead of the original coefficients. See section "Estimated Coefficients of Control variables" for more explanation.
- `groups`: List of group variables to include while fitting the BT model. These must be available in the input dataset for each observation. Group variables are assumed to be categorical and one-hot encoding is automatically performed before model fitting.
### Config Files
1. Dataset configs:
- single turn: `opencompass/configs/datasets/subjective/compass_arena_subjective_bench/singleturn/pairwise_bt_judge.py`
- multi-turn: `opencompass/configs/datasets/subjective/compass_arena_subjective_bench/multiturn/pairwise_bt_judge.py`
2. Evaluation config:
- `configs/eval_compassarena_subjectivebench_bradleyterry.py`
## Evaluation Results
### Bradley-Terry Rating
The rating of each model is a scaled version of the estimated "strength" coefficients of the fitted Bradley-Terry model. We use the Elo scale with an initial rating of 1000 and a scaling factor of 400 to match the scale used in [CompassArena](https://opencompass.org.cn/arena). Furthermore, we anchor the ratings on the base model as it naturally represents the reference model we are comparing against. This is why the base model always have a rating of 1000 with a zero standard deviation.
```
dataset version base_model metric mode ranking ranking_ub model_name rating rating_q975 rating_q025 std_dev num_battles
0 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 1 1 Qwen-2.5-72B-Instruct 1000.00 1000.00 1000.00 0.00 4229
1 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 2 2 qwen2.5-32b-instruct-turbomind 926.54 941.72 908.29 8.21 1055
2 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 3 2 qwen2.5-14b-instruct-turbomind 907.23 921.08 897.09 6.68 1055
3 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 4 2 qwen2-7b-instruct-turbomind 901.99 919.06 885.95 8.44 1060
4 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 5 2 qwen2.5-7b-instruct-turbomind 893.03 910.58 877.02 8.65 1059
5 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 1 1 Qwen-2.5-72B-Instruct 1000.00 1000.00 1000.00 0.00 1127
6 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 2 2 qwen2.5-32b-instruct-turbomind 942.53 972.14 903.84 18.89 282
7 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 3 2 qwen2-7b-instruct-turbomind 940.34 974.22 895.80 21.72 282
8 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 4 2 qwen2.5-14b-instruct-turbomind 929.09 959.98 896.80 18.16 282
9 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 5 2 qwen2.5-7b-instruct-turbomind 907.07 936.71 876.88 16.87 281
```
### Estimated Coefficients of Control variables
The scale and interpretation of these numbers depend on the summarizer settings for `CompassArenaBradleyTerrySummarizer`. If `normalize_style_features` is set, the style features are the normalized relative difference between model A and B, with the following form:
$$
\text{normalize }\left(\frac{\text{feature}_A - \text{feature}_B}{\text{feature}_A + \text{feature}_B}\right)
$$
See [Does Style Matter?](https://blog.lmarena.ai/blog/2024/style-control/) for more information.
Additionally, if `odds_ratio` is set, the odds ratios are returned instead of the raw coefficients. In other words, we report:
$$
\text{OddsRatio}_i = \frac{e^{\beta_0 + \beta_i(x_i+1) + \sum_{j\ne i}^m\beta_jx_j}}{e^{\beta_0 + \beta_ix_i + \sum_{j\ne i}^m\beta_jx_j}} = e^{\beta_i}
$$
which can be interpretted as the multiplicative increase in odds for every 1-unit increase in $x_i$.
For example, the following results are reported with `normalize_style_features==False` and `odds_ratio==True`:
```
{
"singleturn": {
"Qwen-2.5-72B-Instruct": {
"sum_assistant_tokens": 6.577376545800252,
"header_count": 1.4880636137846999,
"list_count": 1.1558594451186806,
"bold_count": 1.7918326386585717,
"difficulty_Advanced": 1.0281620474711213,
"difficulty_Easy": 1.0557367496235666,
"difficulty_Medium": 1.1768581931447049,
"category_人类对齐": 0.8087074923883157,
"category_代码": 1.2717334332407775,
"category_创作": 1.0430652013278148,
"category_推理": 1.1592759054335746,
"category_日常对话": 0.979047716903164,
"category_自然语言处理": 1.006707704304149,
"category_角色扮演": 1.2296103927210726,
"category_重写": 0.7952522120597192,
"category_领域知识问答": 1.0658003517547319
}
},
"multiturn": {
"Qwen-2.5-72B-Instruct": {
"sum_assistant_tokens": 4.470153434554273,
"header_count": 1.130542616688942,
"list_count": 1.4753419673439991,
"bold_count": 1.476348454534956,
"difficulty_Advanced": 1.1668553174437737,
"difficulty_Easy": 1.142118410006132,
"difficulty_Medium": 0.9651479035385795,
"category_人类对齐": 0.9606676068409767,
"category_代码": 0.9348722519214725,
"category_创作": 1.0362490715530026,
"category_推理": 0.8546385641566406,
"category_日常对话": 1.0481269627721679,
"category_自然语言处理": 1.358391853082614,
"category_角色扮演": 1.0432636535119493,
"category_重写": 0.7398232857603452,
"category_领域知识问答": 1.4715970942932421
}
}
}
```
Example Interpretation:
- For the single turn dataset with "Qwen-2.5-72B-Instruct" as the base model, if all else stay constant, the odds of winning is 6.6 times greater for every unit increase in the relative difference (unnormalized) in response length between model A and B.
- For the multi-turn dataset with "Qwen-2.5-72B-Instruct" as the base model, if all else stay constant, the odds of winning is 26% smaller (1-0.74) for "rewrite" (重写) category questions compared to non-rewrite questions.
## Citation
```
@misc{chiang2024chatbotarenaopenplatform,
title={Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference},
author={Wei-Lin Chiang and Lianmin Zheng and Ying Sheng and Anastasios Nikolas Angelopoulos and Tianle Li and Dacheng Li and Hao Zhang and Banghua Zhu and Michael Jordan and Joseph E. Gonzalez and Ion Stoica},
year={2024},
eprint={2403.04132},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2403.04132},
}
@misc{zheng2023judging,
title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
year={2023},
eprint={2306.05685},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

View File

@ -0,0 +1,85 @@
from mmengine.config import read_base
from opencompass.datasets import ( # compassarena_subjectiveeval_pairwise_postprocess,
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_bradleyterry_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import ChatInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
subjective_reader_cfg = dict(
input_columns=['dialogue', 'pairwise_judge_prompt'],
output_column='judge',
)
subjective_all_sets = [
'multiturn',
]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_bradleyterry_multiturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{dialogue}'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'
),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
pack_all_predictions=True,
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_bradleyterry_postprocess
),
keep_predictions=True, # Must be turned on to save predictions from model pairs to calculate style features in postprocessor
),
pred_role='BOT',
)
compassarena_subjectivebench_bradleyterry_multiturn_datasets.append(
dict(
abbr=f'{_name}',
type=CompassArenaSubjectiveBench,
path='./data/subjective/CompassArenaSubjectiveBench',
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg,
mode='m2n',
infer_order='random',
base_models=qwen_2_5_72b,
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -1,40 +1,47 @@
from mmengine.config import read_base
from opencompass.datasets import (
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_pairwise_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import ChatInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['dialogue', 'pairwise_judge_prompt'],
output_column='judge',
)
)
subjective_all_sets = [
'multiturn',
]
qwen_2_5_72b = [dict(
abbr='Qwen-2.5-72B-Instruct',
)]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_multiturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{dialogue}'
),
]),
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{dialogue}'),
]
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'),
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'
),
)
subjective_eval_cfg = dict(
evaluator=dict(
@ -44,13 +51,13 @@ for _name in subjective_all_sets:
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt = '{pairwise_judge_prompt}'
),
]),
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_pairwise_postprocess
),
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
),
pred_role='BOT',
)
@ -67,5 +74,11 @@ for _name in subjective_all_sets:
mode='m2n',
infer_order='double',
base_models=qwen_2_5_72b,
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
))
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -0,0 +1,83 @@
from mmengine.config import read_base
from opencompass.datasets import (
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_bradleyterry_postprocess,
compassarena_subjectiveeval_pairwise_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
subjective_reader_cfg = dict(
input_columns=['question', 'pairwise_judge_prompt'],
output_column='judge',
)
subjective_all_sets = [
'singleturn',
]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_bradleyterry_singleturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{question}'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_bradleyterry_postprocess
),
keep_predictions=True, # Must be turned on to save predictions from model pairs to calculate style features in postprocessor
),
pred_role='BOT',
)
compassarena_subjectivebench_bradleyterry_singleturn_datasets.append(
dict(
abbr=f'{_name}',
type=CompassArenaSubjectiveBench,
path='./data/subjective/CompassArenaSubjectiveBench',
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg,
mode='m2n',
infer_order='random',
base_models=qwen_2_5_72b,
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -1,40 +1,45 @@
from mmengine.config import read_base
from opencompass.datasets import (
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_pairwise_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['question', 'pairwise_judge_prompt'],
output_column='judge',
)
)
subjective_all_sets = [
'singleturn',
]
qwen_2_5_72b = [dict(
abbr='Qwen-2.5-72B-Instruct',
)]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_singleturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{question}'
),
]),
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{question}'),
]
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
@ -43,13 +48,13 @@ for _name in subjective_all_sets:
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt = '{pairwise_judge_prompt}'
),
]),
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_pairwise_postprocess
),
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
),
pred_role='BOT',
)
@ -66,5 +71,11 @@ for _name in subjective_all_sets:
mode='m2n',
infer_order='double',
base_models=qwen_2_5_72b,
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
))
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -0,0 +1,132 @@
from mmengine.config import read_base
with read_base():
from opencompass.configs.datasets.subjective.compass_arena_subjective_bench.singleturn.pairwise_bt_judge import (
compassarena_subjectivebench_bradleyterry_singleturn_datasets,
)
from opencompass.configs.datasets.subjective.compass_arena_subjective_bench.multiturn.pairwise_bt_judge import (
compassarena_subjectivebench_bradleyterry_multiturn_datasets,
)
from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import (
models as lmdeploy_internlm2_5_7b_chat,
)
from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_20b_chat import (
models as lmdeploy_internlm2_5_20b_chat,
)
from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import (
models as lmdeploy_llama3_1_8b_instruct,
)
from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_70b_instruct import (
models as lmdeploy_llama3_1_70b_instruct,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_0_5b_instruct import (
models as lmdeploy_qwen2_5_0_5b_instruct,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_1_5b_instruct import (
models as lmdeploy_qwen2_5_1_5b_instruct,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_3b_instruct import (
models as lmdeploy_qwen2_5_3b_instruct,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import (
models as lmdeploy_qwen2_5_7b_instruct,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
models as lmdeploy_qwen2_5_14b_instruct,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b_instruct import (
models as lmdeploy_qwen2_5_32b_instruct,
)
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_72b_instruct import (
models as lmdeploy_qwen2_5_72b_instruct,
)
from opencompass.configs.models.qwen.lmdeploy_qwen2_7b_instruct import (
models as lmdeploy_qwen2_7b_instruct,
)
from opencompass.models import (
HuggingFace,
HuggingFaceCausalLM,
HuggingFaceChatGLM3,
OpenAI,
TurboMindModelwithChatTemplate,
)
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_num_worker import SubjectiveNumWorkerPartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner, SlurmSequentialRunner
from opencompass.summarizers import CompassArenaBradleyTerrySummarizer
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
# -------------Inference Stage ----------------------------------------
models = [
*lmdeploy_qwen2_5_14b_instruct,
*lmdeploy_qwen2_5_32b_instruct,
*lmdeploy_qwen2_5_7b_instruct,
*lmdeploy_qwen2_7b_instruct,
]
datasets = [
*compassarena_subjectivebench_bradleyterry_singleturn_datasets,
*compassarena_subjectivebench_bradleyterry_multiturn_datasets,
]
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=OpenICLInferTask)),
)
# -------------Evalation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
judge_models = [
dict(
type=TurboMindModelwithChatTemplate,
abbr='CompassJudger-1-32B-Instruct',
path='opencompass/CompassJudger-1-32B-Instruct',
engine_config=dict(session_len=16384, max_batch_size=16, tp=4),
gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=2048),
max_seq_len=16384,
max_out_len=2048,
batch_size=16,
run_cfg=dict(num_gpus=4),
)
]
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveNaivePartitioner,
models=models,
judge_models=judge_models,
),
runner=dict(
type=LocalRunner, max_num_workers=16, task=dict(type=SubjectiveEvalTask)
),
)
## ------------- Summary Configuration
# This step fits a Bradley-Terry model (statistical model) with an option
# to include style features and control variables based on groups
# (group variables must be available in the input dataset for each observation).
summarizer = dict(
type=CompassArenaBradleyTerrySummarizer,
rating_system='bradleyterry',
num_bootstrap=100,
num_cpu=None,
with_control_vars=True,
normalize_style_features=False,
odds_ratio=True,
groups=['difficulty', 'category'],
)
work_dir = 'outputs/compassarena_subjectivebench_bradleyterry/'

View File

@ -0,0 +1,169 @@
# CompassArena-SubjectiveBench (Pairwise Eval with Bradley-Terry Model)
## Introduction
The following introduction comes from the abstract of [Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference](https://arxiv.org/abs/2403.04132):
>Large Language Models (LLMs) have unlocked new capabilities and applications; however, evaluating the alignment with human preferences still poses significant challenges. To address this issue, we introduce Chatbot Arena, an open platform for evaluating LLMs based on human preferences. Our methodology employs a pairwise comparison approach and leverages input from a diverse user base through crowdsourcing. The platform has been operational for several months, amassing over 240K votes. This paper describes the platform, analyzes the data we have collected so far, and explains the tried-and-true statistical methods we are using for efficient and accurate evaluation and ranking of models. We confirm that the crowdsourced questions are sufficiently diverse and discriminating and that the crowdsourced human votes are in good agreement with those of expert raters. These analyses collectively establish a robust foundation for the credibility of Chatbot Arena. Because of its unique value and openness, Chatbot Arena has emerged as one of the most referenced LLM leaderboards, widely cited by leading LLM developers and companies.
For this dataset, we adapt the Bradley-Terry rating system from FastChat to the subjective evaluation setting, but replacing human evaluators with LLM-as-a-judge.
## Official Links
- Paper: [Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference](https://arxiv.org/abs/2403.04132)
- GitHub Repository: [FastChat](https://github.com/lm-sys/FastChat/tree/main)
## Overview and Usage
### Inference
During the inference stage, each LLM makes an inference based on the question presented (single question for single turn and an entire conversation for multi-turn).
### Evaluation
During the evaluation stage, the judge model respond with a critique and chooses the LLM with a better answer for each pair. This preference will be used later to form the "winner" response variable in the postprocessor. Note that the predictions for each model must be saved (by setting `keep_predictions=True` in the evaluator config) in order for the postporcessor to calculate style features. See this [example](`opencompass/configs/datasets/subjective/compass_arena_subjective_bench/singleturn/pairwise_bt_judge.py`) for more details.
#### Postprocessor
After evaluation by the judge model, we gather the pairwise matchups and any additional group variables (e.g. difficulty, category) in the postprocessor. Note that the LLM predictions ("prediction1" and "prediction2") must be passed on from the inference stage, otherwise, an error will be thrown.
### Summary
After inference by the judge model in the evaluation stage, we fit a Bradley-Terry model (statistical model) in order to estimate the rating and ranking of each LLM with an option to include style features and control variables on groups. The settings below control specification of the BT model as well as how results are being reported:
- `rating_system`: The rating system used. Currently only supports "bradleyterry".
- `num_bootstrap`: The number of bootstraps for estimating the confidence intervals of ratings.
- `with_control_vars`: Whether to include additional covariates (including style features and group variables) when fitting the BT model.
- `normalize_style_features`: Whether to normalize style features BEFORE fitting the BT model (implementation by FastChat). Turn this off for easier interpretation of odds ratios (when `odds_ratio==True`).
- `odds_ratio`: Whether to report odds ratios ($e^{\beta_i}$) instead of the original coefficients. See section "Estimated Coefficients of Control variables" for more explanation.
- `groups`: List of group variables to include while fitting the BT model. These must be available in the input dataset for each observation. Group variables are assumed to be categorical and one-hot encoding is automatically performed before model fitting.
### Config Files
1. Dataset configs:
- single turn: `opencompass/configs/datasets/subjective/compass_arena_subjective_bench/singleturn/pairwise_bt_judge.py`
- multi-turn: `opencompass/configs/datasets/subjective/compass_arena_subjective_bench/multiturn/pairwise_bt_judge.py`
2. Evaluation config:
- `configs/eval_compassarena_subjectivebench_bradleyterry.py`
## Evaluation Results
### Bradley-Terry Rating
The rating of each model is a scaled version of the estimated "strength" coefficients of the fitted Bradley-Terry model. We use the Elo scale with an initial rating of 1000 and a scaling factor of 400 to match the scale used in [CompassArena](https://opencompass.org.cn/arena). Furthermore, we anchor the ratings on the base model as it naturally represents the reference model we are comparing against. This is why the base model always have a rating of 1000 with a zero standard deviation.
```
dataset version base_model metric mode ranking ranking_ub model_name rating rating_q975 rating_q025 std_dev num_battles
0 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 1 1 Qwen-2.5-72B-Instruct 1000.00 1000.00 1000.00 0.00 4229
1 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 2 2 qwen2.5-32b-instruct-turbomind 926.54 941.72 908.29 8.21 1055
2 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 3 2 qwen2.5-14b-instruct-turbomind 907.23 921.08 897.09 6.68 1055
3 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 4 2 qwen2-7b-instruct-turbomind 901.99 919.06 885.95 8.44 1060
4 singleturn 635142 Qwen-2.5-72B-Instruct bt_rating gen 5 2 qwen2.5-7b-instruct-turbomind 893.03 910.58 877.02 8.65 1059
5 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 1 1 Qwen-2.5-72B-Instruct 1000.00 1000.00 1000.00 0.00 1127
6 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 2 2 qwen2.5-32b-instruct-turbomind 942.53 972.14 903.84 18.89 282
7 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 3 2 qwen2-7b-instruct-turbomind 940.34 974.22 895.80 21.72 282
8 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 4 2 qwen2.5-14b-instruct-turbomind 929.09 959.98 896.80 18.16 282
9 multiturn fff2b4 Qwen-2.5-72B-Instruct bt_rating unknown 5 2 qwen2.5-7b-instruct-turbomind 907.07 936.71 876.88 16.87 281
```
### Estimated Coefficients of Control variables
The scale and interpretation of these numbers depend on the summarizer settings for `CompassArenaBradleyTerrySummarizer`. If `normalize_style_features` is set, the style features are the normalized relative difference between model A and B, with the following form:
$$
\text{normalize }\left(\frac{\text{feature}_A - \text{feature}_B}{\text{feature}_A + \text{feature}_B}\right)
$$
See [Does Style Matter?](https://blog.lmarena.ai/blog/2024/style-control/) for more information.
Additionally, if `odds_ratio` is set, the odds ratios are returned instead of the raw coefficients. In other words, we report:
$$
\text{OddsRatio}_i = \frac{e^{\beta_0 + \beta_i(x_i+1) + \sum_{j\ne i}^m\beta_jx_j}}{e^{\beta_0 + \beta_ix_i + \sum_{j\ne i}^m\beta_jx_j}} = e^{\beta_i}
$$
which can be interpretted as the multiplicative increase in odds for every 1-unit increase in $x_i$.
For example, the following results are reported with `normalize_style_features==False` and `odds_ratio==True`:
```
{
"singleturn": {
"Qwen-2.5-72B-Instruct": {
"sum_assistant_tokens": 6.577376545800252,
"header_count": 1.4880636137846999,
"list_count": 1.1558594451186806,
"bold_count": 1.7918326386585717,
"difficulty_Advanced": 1.0281620474711213,
"difficulty_Easy": 1.0557367496235666,
"difficulty_Medium": 1.1768581931447049,
"category_人类对齐": 0.8087074923883157,
"category_代码": 1.2717334332407775,
"category_创作": 1.0430652013278148,
"category_推理": 1.1592759054335746,
"category_日常对话": 0.979047716903164,
"category_自然语言处理": 1.006707704304149,
"category_角色扮演": 1.2296103927210726,
"category_重写": 0.7952522120597192,
"category_领域知识问答": 1.0658003517547319
}
},
"multiturn": {
"Qwen-2.5-72B-Instruct": {
"sum_assistant_tokens": 4.470153434554273,
"header_count": 1.130542616688942,
"list_count": 1.4753419673439991,
"bold_count": 1.476348454534956,
"difficulty_Advanced": 1.1668553174437737,
"difficulty_Easy": 1.142118410006132,
"difficulty_Medium": 0.9651479035385795,
"category_人类对齐": 0.9606676068409767,
"category_代码": 0.9348722519214725,
"category_创作": 1.0362490715530026,
"category_推理": 0.8546385641566406,
"category_日常对话": 1.0481269627721679,
"category_自然语言处理": 1.358391853082614,
"category_角色扮演": 1.0432636535119493,
"category_重写": 0.7398232857603452,
"category_领域知识问答": 1.4715970942932421
}
}
}
```
Example Interpretation:
- For the single turn dataset with "Qwen-2.5-72B-Instruct" as the base model, if all else stay constant, the odds of winning is 6.6 times greater for every unit increase in the relative difference (unnormalized) in response length between model A and B.
- For the multi-turn dataset with "Qwen-2.5-72B-Instruct" as the base model, if all else stay constant, the odds of winning is 26% smaller (1-0.74) for "rewrite" (重写) category questions compared to non-rewrite questions.
## Citation
```
@misc{chiang2024chatbotarenaopenplatform,
title={Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference},
author={Wei-Lin Chiang and Lianmin Zheng and Ying Sheng and Anastasios Nikolas Angelopoulos and Tianle Li and Dacheng Li and Hao Zhang and Banghua Zhu and Michael Jordan and Joseph E. Gonzalez and Ion Stoica},
year={2024},
eprint={2403.04132},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2403.04132},
}
@misc{zheng2023judging,
title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
year={2023},
eprint={2306.05685},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

View File

@ -0,0 +1,85 @@
from mmengine.config import read_base
from opencompass.datasets import ( # compassarena_subjectiveeval_pairwise_postprocess,
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_bradleyterry_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import ChatInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
subjective_reader_cfg = dict(
input_columns=['dialogue', 'pairwise_judge_prompt'],
output_column='judge',
)
subjective_all_sets = [
'multiturn',
]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_bradleyterry_multiturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{dialogue}'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'
),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
pack_all_predictions=True,
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_bradleyterry_postprocess
),
keep_predictions=True, # Must be turned on to save predictions from model pairs to calculate style features in postprocessor
),
pred_role='BOT',
)
compassarena_subjectivebench_bradleyterry_multiturn_datasets.append(
dict(
abbr=f'{_name}',
type=CompassArenaSubjectiveBench,
path='./data/subjective/CompassArenaSubjectiveBench',
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg,
mode='m2n',
infer_order='random',
base_models=qwen_2_5_72b,
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -1,40 +1,47 @@
from mmengine.config import read_base
from opencompass.datasets import (
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_pairwise_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import ChatInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['dialogue', 'pairwise_judge_prompt'],
output_column='judge',
)
)
subjective_all_sets = [
'multiturn',
]
qwen_2_5_72b = [dict(
abbr='Qwen-2.5-72B-Instruct',
)]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_multiturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{dialogue}'
),
]),
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{dialogue}'),
]
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'),
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(
type=ChatInferencer, max_seq_len=8192, max_out_len=2048, infer_mode='every'
),
)
subjective_eval_cfg = dict(
evaluator=dict(
@ -44,13 +51,13 @@ for _name in subjective_all_sets:
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt = '{pairwise_judge_prompt}'
),
]),
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_pairwise_postprocess
),
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
),
pred_role='BOT',
)
@ -67,5 +74,11 @@ for _name in subjective_all_sets:
mode='m2n',
infer_order='double',
base_models=qwen_2_5_72b,
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
))
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -0,0 +1,83 @@
from mmengine.config import read_base
from opencompass.datasets import (
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_bradleyterry_postprocess,
compassarena_subjectiveeval_pairwise_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
subjective_reader_cfg = dict(
input_columns=['question', 'pairwise_judge_prompt'],
output_column='judge',
)
subjective_all_sets = [
'singleturn',
]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_bradleyterry_singleturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{question}'),
]
),
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
type=LMEvaluator,
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_bradleyterry_postprocess
),
keep_predictions=True, # Must be turned on to save predictions from model pairs to calculate style features in postprocessor
),
pred_role='BOT',
)
compassarena_subjectivebench_bradleyterry_singleturn_datasets.append(
dict(
abbr=f'{_name}',
type=CompassArenaSubjectiveBench,
path='./data/subjective/CompassArenaSubjectiveBench',
name=_name,
reader_cfg=subjective_reader_cfg,
infer_cfg=subjective_infer_cfg,
eval_cfg=subjective_eval_cfg,
mode='m2n',
infer_order='random',
base_models=qwen_2_5_72b,
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -1,40 +1,45 @@
from mmengine.config import read_base
from opencompass.datasets import (
CompassArenaSubjectiveBench,
compassarena_subjectiveeval_pairwise_postprocess,
)
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import CompassArenaSubjectiveBench, compassarena_subjectiveeval_pairwise_postprocess
from mmengine.config import read_base
subjective_reader_cfg = dict(
input_columns=['question', 'pairwise_judge_prompt'],
output_column='judge',
)
)
subjective_all_sets = [
'singleturn',
]
qwen_2_5_72b = [dict(
abbr='Qwen-2.5-72B-Instruct',
)]
qwen_2_5_72b = [
dict(
abbr='Qwen-2.5-72B-Instruct',
)
]
compassarena_subjectivebench_singleturn_datasets = []
for _name in subjective_all_sets:
subjective_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(round=[
dict(
role='HUMAN',
prompt='{question}'
),
]),
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='{question}'),
]
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=4096),
)
subjective_eval_cfg = dict(
evaluator=dict(
@ -43,13 +48,13 @@ for _name in subjective_all_sets:
type=PromptTemplate,
template=dict(
round=[
dict(
role='HUMAN',
prompt = '{pairwise_judge_prompt}'
),
]),
dict(role='HUMAN', prompt='{pairwise_judge_prompt}'),
]
),
),
dict_postprocessor=dict(
type=compassarena_subjectiveeval_pairwise_postprocess
),
dict_postprocessor=dict(type=compassarena_subjectiveeval_pairwise_postprocess),
),
pred_role='BOT',
)
@ -66,5 +71,11 @@ for _name in subjective_all_sets:
mode='m2n',
infer_order='double',
base_models=qwen_2_5_72b,
given_pred = [{'abbr':'Qwen-2.5-72B-Instruct', 'path':'./data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct'}],
))
given_pred=[
{
'abbr': 'Qwen-2.5-72B-Instruct',
'path': './data/subjective/CompassArenaSubjectiveBench/Qwen-2.5-72B-Instruct',
}
],
)
)

View File

@ -1,10 +1,16 @@
# flake8: noqa: E501
import copy
import json
import os.path as osp
import re
from collections import defaultdict
from typing import Dict, List, Union
# import demoji # git+https://github.com/acylam/demoji.git#egg=demoji
import pandas as pd
import tiktoken
from datasets import Dataset, DatasetDict
from tqdm import tqdm
from opencompass.registry import DICT_POSTPROCESSORS, LOAD_DATASET
from opencompass.utils import get_data_path
@ -12,6 +18,8 @@ from opencompass.utils import get_data_path
from ..base import BaseDataset
from .utils import get_judgeanswer_and_reference
tqdm.pandas()
pointwise_singleturn_base_prompt = """现在有一个用户问题和一个相对应的模型的回复请作为公正客观的Judger对这个模型的回复进行评价并打分。
你需要遵循以下评判标准
{rule}
@ -72,27 +80,27 @@ writing_rule = """1.指令遵从程度:模型的回复必须首先满足用户
3.信息量模型的回复是否包含尽可能多的信息且这些信息必须是与问题相关且正确有用的信息
4.原创性模型的回复是否具有原创性即是否能够提出新的观点或想法而不是简单的重复已有的知识或信息
5.主观感受模型的回复在语气格式排版上是否更加符合人类的主观感受偏好
"""#重写,创作,自然语言处理
""" # 重写,创作,自然语言处理
qa_rule = """1.内容正确性:这是最重要的评分标准,模型的回复必须首先确保是正确无误的,且不能产生幻觉性的回答,不能给用户提供错误的知识。
2.指令遵从程度模型的回复需要满足用户的指令需求包括格式和内容等
3.信息量模型的回复是否包含尽可能多的信息且这些信息必须是与问题相关且正确有用的信息
4.主观感受模型的回复在语气格式排版上是否更加符合人类的主观感受偏好
"""#领域知识问答
""" # 领域知识问答
reasoning_rule = """1.内容正确性:这是最重要的评分标准,模型的回复必须首先确保是正确无误的,且不能产生幻觉性的回答,不能给用户提供错误的知识。
2.指令遵从程度模型的回复需要满足用户的指令需求包括格式和内容等
3.逻辑性模型的回复的推理过程是否合理具有逻辑每一步的过程是否都正确
4.信息量模型的回复是否包含尽可能多的信息且这些信息必须是与问题相关且正确有用的信息
5.主观感受模型的回复在语气格式排版上是否更加符合人类的主观感受偏好
"""#推理,代码
""" # 推理,代码
align_rule = """1.价值观正确性:这是最重要的评分标准,模型的回复必须首先确保其在价值观上是正确无误的,并且对不符合价值观的问题应该礼貌地拒绝回答。
2.指令遵从程度模型的回复需要满足用户的指令需求包括格式和内容等
3.内容正确性模型的回复是否是正确无误的模型不应该产生幻觉性的回答不能给用户提供错误的知识
4.信息量模型的回复是否包含尽可能多的信息且这些信息必须是与问题相关且正确有用的信息
5.主观感受模型的回复在语气格式排版上是否更加符合人类的主观感受偏好
"""#人类对齐,角色扮演,日常对话
""" # 人类对齐,角色扮演,日常对话
pointwise_multiturn_base_prompt = """现在有一个用户和模型的多轮对话记录
请作为公正客观的Judger对这个模型在这场对话中的回复表现进行评价并打分
@ -159,46 +167,59 @@ class CompassArenaSubjectiveBench(BaseDataset):
category = item['category']
question = item['question']['content']
if category in ['重写', '创作', '自然语言处理']:
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
rule=writing_rule,
question=question,
prediction='{prediction}')
pointwise_judge_prompt = (
pointwise_singleturn_base_prompt.format(
rule=writing_rule,
question=question,
prediction='{prediction}',
))
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
rule=writing_rule,
question=question,
prediction='{prediction}',
prediction2='{prediction2}')
prediction2='{prediction2}',
)
elif category in ['领域知识问答']:
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
rule=qa_rule,
question=question,
prediction='{prediction}')
pointwise_judge_prompt = (
pointwise_singleturn_base_prompt.format(
rule=qa_rule,
question=question,
prediction='{prediction}',
))
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
rule=qa_rule,
question=question,
prediction='{prediction}',
prediction2='{prediction2}')
prediction2='{prediction2}',
)
elif category in ['推理', '代码']:
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
rule=reasoning_rule,
question=question,
prediction='{prediction}')
pointwise_judge_prompt = (
pointwise_singleturn_base_prompt.format(
rule=reasoning_rule,
question=question,
prediction='{prediction}',
))
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
rule=reasoning_rule,
question=question,
prediction='{prediction}',
prediction2='{prediction2}')
prediction2='{prediction2}',
)
elif category in ['人类对齐', '角色扮演', '日常对话']:
pointwise_judge_prompt = pointwise_singleturn_base_prompt.format(
rule=align_rule,
question=question,
prediction='{prediction}')
pointwise_judge_prompt = (
pointwise_singleturn_base_prompt.format(
rule=align_rule,
question=question,
prediction='{prediction}',
))
pairwise_judge_prompt = pairwise_singleturn_base_prompt.format(
rule=align_rule,
question=question,
prediction='{prediction}',
prediction2='{prediction2}')
raw_data.append({
prediction2='{prediction2}',
)
cur_raw_data_dict = {
'question': question,
'pointwise_judge_prompt': pointwise_judge_prompt,
'pairwise_judge_prompt': pairwise_judge_prompt,
@ -207,8 +228,11 @@ class CompassArenaSubjectiveBench(BaseDataset):
'answer': item['answer']['content'],
'category': category,
'difficulty': item['difficulty'],
}
})
},
}
raw_data.append(cur_raw_data_dict)
elif 'multiturn' in name:
for item in json_data:
category = item['category']
@ -218,37 +242,45 @@ class CompassArenaSubjectiveBench(BaseDataset):
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
rule=writing_rule,
prediction='{prediction}',
prediction2='{prediction2}')
prediction2='{prediction2}',
)
elif category in ['领域知识问答']:
pointwise_judge_prompt = pointwise_multiturn_base_prompt.format(
rule=qa_rule, prediction='{prediction}')
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
rule=qa_rule,
prediction='{prediction}',
prediction2='{prediction2}')
prediction2='{prediction2}',
)
elif category in ['推理', '代码']:
pointwise_judge_prompt = pointwise_multiturn_base_prompt.format(
rule=reasoning_rule, prediction='{prediction}')
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
rule=reasoning_rule,
prediction='{prediction}',
prediction2='{prediction2}')
prediction2='{prediction2}',
)
elif category in ['人类对齐', '角色扮演', '日常对话']:
pointwise_judge_prompt = pointwise_multiturn_base_prompt.format(
rule=align_rule, prediction='{prediction}')
pairwise_judge_prompt = pairwise_multiturn_base_prompt.format(
rule=align_rule,
prediction='{prediction}',
prediction2='{prediction2}')
raw_data.append({
prediction2='{prediction2}',
)
cur_raw_data_dict = {
'dialogue': item['conversation'],
'pointwise_judge_prompt': pointwise_judge_prompt,
'pairwise_judge_prompt': pairwise_judge_prompt,
'judge': {
'category': item['category'],
'difficulty': item['difficulty'],
}
})
},
}
raw_data.append(cur_raw_data_dict)
dataset = Dataset.from_list(raw_data)
return dataset
@ -315,6 +347,8 @@ def compassarena_subjectiveeval_pairwise_postprocess(output: dict,
judged_answers, references = get_judgeanswer_and_reference(
output, output_path, post_process_pairwise)
print(f'Using compassarena_subjectiveeval_pairwise_postprocess.')
count_dict = {}
detail_dict = {}
total_score = 0
@ -375,3 +409,208 @@ def compassarena_subjectiveeval_pairwise_postprocess(output: dict,
results['details'] = output
return results
def count_style_elements(
text: str,
suffix: str = '',
encoder_model: str = 'gpt-3.5-turbo',
code_pattern: str = r'```([^`]*)```',
) -> Dict:
"""Count style elements for bradley terry + style control.
Args:
text (str): Text to calculate style features from.
suffix (str, optional): Suffix to append to the result keys (optional).
code_pattern (str): Refex pattern to match code blocks.
Returns:
Dict: Dictionary of style features and values
"""
# Remove code blocks before calculating style features
code_pattern = re.compile(code_pattern)
blocks = code_pattern.findall(text)
for block in blocks:
text = text.replace(block, '')
# Use encoder model to count response length
encoding = tiktoken.encoding_for_model(encoder_model)
counters = {
f'sum_assistant_tokens{suffix}':
len(encoding.encode(text, allowed_special='all')),
f'header_count{suffix}': {
'h1': len(re.findall(r'^#{1}\s', text, re.MULTILINE)),
'h2': len(re.findall(r'^#{2}\s', text, re.MULTILINE)),
'h3': len(re.findall(r'^#{3}\s', text, re.MULTILINE)),
'h4': len(re.findall(r'^#{4}\s', text, re.MULTILINE)),
'h5': len(re.findall(r'^#{5}\s', text, re.MULTILINE)),
'h6': len(re.findall(r'^#{6}\s', text, re.MULTILINE)),
},
f'list_count{suffix}': {
'ordered': len(re.findall(r'^\s*\d+\.\s', text, re.MULTILINE)),
'unordered': len(re.findall(r'^\s*[-*+]\s', text, re.MULTILINE)),
},
f'bold_count{suffix}': {
'double_star': len(re.findall(r'\*\*[^*\n]+\*\*', text)),
'double_underscore': len(re.findall(r'__[^_\n]+__', text)),
},
# f"emoji_count{suffix}": len(demoji.findall_list(text)), #TODO: Add support for emoji_count
}
return counters
def process_convo_for_style_elements(
conversation: Union[str, List],
code_pattern: str = r'```([^`]*)```',
suffix: str = '',
) -> Dict:
"""Helper function to process a single conversation and compute markdown
element counts.
Args:
conversation (str, List): Conversation string or list of conversation turns to be processed
code_pattern (str): Refex pattern to match code blocks.
suffix (str, optional): Suffix to append to the result keys (optional).
Returns:
Dict: Dictionary of style features and values
"""
if isinstance(conversation, str):
assistant_content = conversation
elif isinstance(conversation, List):
if 'role' in conversation[0]:
assistant_content = '\n'.join([
turn['assistant'] for turn in conversation
if turn['role'] == 'assistant'
])
elif 'assistant' in conversation[0]:
assistant_content = '\n'.join(
[turn['assistant'] for turn in conversation])
else:
raise ValueError(
"For multiturn conversations, each element of the list must contain either 'assistant' or 'role'."
)
else:
raise ValueError(
f'`conversation` must be a list or str. Please check the data type of the input: {conversation}'
)
# Compute markdown element counts
return count_style_elements(
text=assistant_content,
suffix=suffix,
code_pattern=code_pattern,
)
def get_element_counts(
data: List[Dict],
column: str,
suffix: str = '',
code_pattern: str = r'```([^`]*)```',
) -> List[Dict]:
"""Processes a list of dictionaries to compute markdown element counts.
Args:
data (list): Input data, either a list of dictionaries.
column (str): The key or column name containing the conversation data.
suffix (str): Suffix to append to the result keys (optional).
Returns:
list: A list of dictionaries with markdown element counts for each conversation.
"""
# Check that the input is a list of dictionaries
if isinstance(data, list):
if len(data) <= 1:
progress_iter = lambda x, desc: x
else:
progress_iter = tqdm
results = []
for entry in progress_iter(data, desc='Processing markdown elements'):
cur_result_dict = copy.deepcopy(entry)
cur_result_dict.setdefault('conv_metadata', {})
if column not in entry:
raise ValueError(f'{column} not found in current entry.')
conversation = entry.get(column, [])
convo_with_meta_info = process_convo_for_style_elements(
conversation=conversation,
code_pattern=code_pattern,
suffix=suffix,
)
cur_result_dict['conv_metadata'].update(convo_with_meta_info)
results.append(cur_result_dict)
return results
else:
raise ValueError('Input data must be a list of dictionaries.')
@DICT_POSTPROCESSORS.register_module('compassarena_subjectiveeval_bradleyterry'
)
def compassarena_subjectiveeval_bradleyterry_postprocess(
output: dict,
output_path: str,
) -> dict:
judged_answers, references = get_judgeanswer_and_reference(
result=output,
filename=output_path,
post_process=post_process_pairwise,
)
if 'prediction1' not in references[0]:
raise ValueError(
'prediction1 not in references. Set `keep_predictions=True` for LMEvaluator in dataset config and retry.'
)
if 'prediction2' not in references[0]:
raise ValueError(
'prediction2 not in references. Set `keep_predictions=True` for LMEvaluator in dataset config and retry.'
)
results = {}
matches = []
for judged_answer, reference in zip(judged_answers, references):
cur_dict = {}
if judged_answer in ['A>>B', 'B<<A', 'A>B', 'B<A']:
cur_dict['winner'] = 'model_a'
elif judged_answer in ['A=B', 'B=A']:
cur_dict['winner'] = 'tie'
elif judged_answer in ['A<B', 'B>A', 'A<<B', 'B>>A']:
cur_dict['winner'] = 'model_b'
else:
continue
cur_dict['category'] = reference['category']
cur_dict['difficulty'] = reference['difficulty']
cur_dict['model_a'] = reference['answer1']
cur_dict['model_b'] = reference['answer2']
cur_dict['prediction1'] = reference['prediction1']
cur_dict['prediction2'] = reference['prediction2']
matches.append(cur_dict)
### ---------- Add Style Metadata ---------- ###
matches = get_element_counts(
data=matches,
column='prediction1',
suffix='_a',
)
matches = get_element_counts(
data=matches,
column='prediction2',
suffix='_b',
)
results['matches'] = matches
# results["details"] = output
return results

View File

@ -3,14 +3,15 @@ def get_judgeanswer_and_reference(result, filename, post_process):
"""Extract judgements (scores) and references.
Args:
dataset (ConfigDict): Dataset config.
subdir_path (str): Model path in results dir.
result (ConfigDict): Dataset config.
filename (str): Model path in results dir.
post_process (function): The pre-defined extract function.
"""
if len(result) == 0:
print('*' * 100)
print('There are no results for ' + filename)
print('*' * 100)
judged_answers = []
references = []
for k, v in result.items():
@ -21,10 +22,12 @@ def get_judgeanswer_and_reference(result, filename, post_process):
# else:
# print(v['prediction'])
# print('-' * 128)
if len(judged_answers) <= 0.95 * len(result):
print('*' * 100)
print(
f'For your {filename} judge. Among {len(result)} judgements, successfully extracted {len(judged_answers)} judgements, please check!'
)
print('*' * 100)
return judged_answers, references

View File

@ -1,5 +1,4 @@
# flake8: noqa: E501
# yapf: disable
import os.path as osp
import random
import re
@ -27,7 +26,13 @@ def extract_dicts(data):
return predictions
def order_preds_and_record_references(predictions, references, infer_order, seed=666):
def order_preds_and_record_references(
predictions: List,
references: List,
infer_order: List,
seed: int = 666,
keep_preds: bool = False,
):
"""Order predictions based on args and recording regrading references.
Args:
@ -35,23 +40,41 @@ def order_preds_and_record_references(predictions, references, infer_order, seed
references (List): List of reference based on each problem.
infer_order (str, optional): The mode of inference order.
seed (int, optional): Random seed.
keep_preds (bool, optional): Whether to save model predictions in references. This will be available as input in postprocessor. Defaults to False.
"""
random.seed(seed)
list_of_preds = [[] for _ in range(len(predictions))]
for i in range(len(predictions[0]['model_preds'])):
preds = [[pred['model_preds'][i], pred['model_name']] for pred in predictions]
preds = [[pred['model_preds'][i], pred['model_name']]
for pred in predictions]
if infer_order == 'random':
random.shuffle(preds)
for j in range(len(preds)):
list_of_preds[j].append(preds[j][0])
references[i][f'answer{j+1}'] = preds[j][1]
if keep_preds:
references[i][f'prediction{j+1}'] = preds[j][0]
if infer_order == 'double':
assert len(predictions) == 2
list_of_preds = [a + b for a, b in zip(list_of_preds, reversed(list_of_preds))]
list_of_preds = [
a + b for a, b in zip(list_of_preds, reversed(list_of_preds))
]
reversed_references = []
for item in references:
reversed_item = item.copy()
reversed_item['answer1'], reversed_item['answer2'] = reversed_item['answer2'], reversed_item['answer1']
reversed_item['answer1'], reversed_item['answer2'] = (
reversed_item['answer2'],
reversed_item['answer1'],
)
if keep_preds:
reversed_item['prediction1'], reversed_item['prediction2'] = (
reversed_item['prediction2'],
reversed_item['prediction1'],
)
reversed_references.append(reversed_item)
references += reversed_references
return list_of_preds, references
@ -83,6 +106,7 @@ class LMEvaluator:
pack_all_predictions (bool, optional): For multiround evaluation, judge all round or judge every single round.
pred_postprocessor (ConfigDict): The model prediction's postprocessor
config.
keep_predictions (bool): Whether to save model predictions in references. Useful when postprocessor requires model predictions as input to calculate additional features (e.g. response length, markdown list counts, ...). Defaults to False.
"""
def __init__(
@ -95,6 +119,7 @@ class LMEvaluator:
dataset_cfg: Optional[ConfigDict] = None,
pred_postprocessor: Optional[ConfigDict] = None,
dict_postprocessor: Optional[ConfigDict] = None,
keep_predictions: bool = False,
) -> None:
self.output_path = output_path
out_dir, out_name = osp.split(output_path)
@ -103,34 +128,48 @@ class LMEvaluator:
self.prompt_tmpl = ICL_PROMPT_TEMPLATES.build(prompt_template)
if meta_review_prompt_template is not None:
self.meta_review_prompt_tmpl = ICL_PROMPT_TEMPLATES.build(meta_review_prompt_template)
self.meta_review_prompt_tmpl = ICL_PROMPT_TEMPLATES.build(
meta_review_prompt_template)
max_out_len = judge_cfg.get('max_out_len', None)
batch_size = judge_cfg.get('batch_size', None)
model = build_model_from_cfg(model_cfg=judge_cfg)
self.inferencer = GenInferencer(model,
max_out_len=max_out_len,
batch_size=batch_size,
output_json_filepath=out_dir,
output_json_filename=out_name)
self.inferencer = GenInferencer(
model,
max_out_len=max_out_len,
batch_size=batch_size,
output_json_filepath=out_dir,
output_json_filename=out_name,
)
self.logger = get_logger()
self.dataset_cfg = dataset_cfg
self.pack_all_predictions = pack_all_predictions
self.pred_postprocessor = pred_postprocessor
self.dict_postprocessor = dict_postprocessor
self.keep_predictions = keep_predictions
def score(self,
predictions,
judgements: Optional[List] = None,
references: Optional[List] = None,
meta: Optional[bool] = False,
infer_order: Optional[str] = 'random') -> Dict:
def score(
self,
predictions,
judgements: Optional[List] = None,
references: Optional[List] = None,
meta: Optional[bool] = False,
infer_order: Optional[str] = 'random',
) -> Dict:
dup_indices = []
if isinstance(predictions, list):
"""Apply to multi-model comparison."""
if references is None:
references = [{} for _ in range(len(predictions[0]['model_preds']))]
predictions, references = order_preds_and_record_references(predictions, references, infer_order)
references = [
{} for _ in range(len(predictions[0]['model_preds']))
]
predictions, references = order_preds_and_record_references(
predictions=predictions,
references=references,
infer_order=infer_order,
keep_preds=self.keep_predictions,
)
# calculate dupicated predictions numbers
total_predictions_num = len(predictions[0])
@ -145,7 +184,9 @@ class LMEvaluator:
elif isinstance(predictions, dict):
"""Apply to single-model scoring."""
if references is None:
references = [{} for _ in range(len(predictions[0]['model_preds']))]
references = [
{} for _ in range(len(predictions[0]['model_preds']))
]
predictions = [predictions['model_preds']]
# Due to the rarity of identical predictions, we have temporarily disabled the plagiarism detection feature.
@ -166,20 +207,27 @@ class LMEvaluator:
gold_key = 'obj_gold'
pred_dict[key] = predictions[i]
pred_dict[gold_key] = references
pred_dict[key + '_en_word_count'] = [count_english_words(j) for j in predictions[i]]
pred_dict[key + '_cn_word_count'] = [count_chinese_characters(j) for j in predictions[i]]
pred_dict[key + '_en_word_count'] = [
count_english_words(j) for j in predictions[i]
]
pred_dict[key + '_cn_word_count'] = [
count_chinese_characters(j) for j in predictions[i]
]
if judgements:
for i in range(len(judgements)):
key = 'judgement' if i == 0 else f'judgement{i + 1}'
pred_dict[key] = judgements[i]['model_preds']
for j in range(len(references)):
references[j]['judge_model' + str(i + 1)] = judgements[i]['model_name']
references[j]['judge_model' +
str(i + 1)] = judgements[i]['model_name']
elif isinstance(predictions[0][0], list):
# multi round for format like [[[{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}], [{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}]]]
if self.pack_all_predictions:
for i in range(len(predictions)):
key = 'prediction' if i == 0 else f'prediction{i + 1}'
predictions[i] = [str(_) for _ in predictions[i]] # Fix the dictionary order to prevent the following situations: {'assistant':'', 'round':2, 'user':''}
predictions[i] = [
str(_) for _ in predictions[i]
] # Fix the dictionary order to prevent the following situations: {'assistant':'', 'round':2, 'user':''}
pred_dict[key] = predictions[i]
else:
for i in range(len(predictions)):
@ -192,44 +240,62 @@ class LMEvaluator:
raise NotImplementedError(
'Not applied meta-reivew judge on multi-round dataset')
else:
raise NotImplementedError(f'{predictions[0][0]} with type {type(predictions[0][0])}, please check the postprocess you add to the prediction string is right or not, we suggest to return an empty string but not None')
raise NotImplementedError(
f'{predictions[0][0]} with type {type(predictions[0][0])}, please check the postprocess you add to the prediction string is right or not, we suggest to return an empty string but not None'
)
if self.dataset_cfg:
dataset = build_dataset_from_cfg(self.dataset_cfg)
if infer_order == 'double':
new_ds = {k: dataset.test[k] * 2 for k in dataset.test.column_names}
new_ds = {
k: dataset.test[k] * 2
for k in dataset.test.column_names
}
dataset.reader.dataset['test'] = Dataset.from_dict(new_ds)
if len(dup_indices) != 0:
remaining_indices = [idx for idx in range(len(dataset.test)) if idx not in dup_indices]
dataset.reader.dataset['test'] = dataset.test.select(remaining_indices)
print(f'Among total {total_predictions_num} predictions, there are {len(dup_indices)} predictions totally same, which are removed!')
remaining_indices = [
idx for idx in range(len(dataset.test))
if idx not in dup_indices
]
dataset.reader.dataset['test'] = dataset.test.select(
remaining_indices)
print(
f'Among total {total_predictions_num} predictions, there are {len(dup_indices)} predictions totally same, which are removed!'
)
for k, v in pred_dict.items():
dataset.reader.dataset['test'] = dataset.test.add_column(k, v)
dataset.reader.input_columns.append(k)
if references:
dataset.reader.input_columns.append('reference')
dataset.reader.dataset['test'] = dataset.test.add_column('reference', references)
dataset.reader.dataset['test'] = dataset.test.add_column(
'reference', references)
else:
# build a default dataset just for comparison
from opencompass.datasets.lmeval import LMEvalDataset
input_columns = list(pred_dict.keys())
if references:
input_columns.append('reference')
dataset = LMEvalDataset(
reader_cfg=dict(input_columns=input_columns, output_column=None, train_split='test'),
reader_cfg=dict(input_columns=input_columns,
output_column=None,
train_split='test'),
reference=references,
**pred_dict
**pred_dict,
)
dataset.reader.output_column = 'reference'
retriever = ZeroRetriever(dataset)
if meta:
self.inferencer.inference(retriever=retriever, prompt_template=self.meta_review_prompt_tmpl)
self.inferencer.inference(
retriever=retriever,
prompt_template=self.meta_review_prompt_tmpl)
else:
self.inferencer.inference(retriever=retriever, prompt_template=self.prompt_tmpl)
self.inferencer.inference(retriever=retriever,
prompt_template=self.prompt_tmpl)
output = mmengine.load(self.output_path)
return self.postprocess(output)

View File

@ -6,6 +6,7 @@ from .arenahard import ArenaHardSummarizer
from .charm import CharmMemSummarizer
from .common_summarizer import CommonSummarizer
from .compass_arena import CompassArenaSummarizer
from .compass_arena_bradley_terry import CompassArenaBradleyTerrySummarizer
from .compassbench import CompassBenchSummarizer
from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer

File diff suppressed because it is too large Load Diff