[Feature] Add JudgeLLMs (#710)

* add judgellms

* add judgellms

* add sub_size_partition

* add docs

* add ref
bittersweet1999 2023-12-19 18:40:25 +08:00 committed by GitHub
parent eda72e756e
commit 97c2068bd9
14 changed files with 650 additions and 13 deletions


@@ -0,0 +1,84 @@
from mmengine.config import read_base
with read_base():
from .models.qwen.hf_qwen_7b_chat import models as hf_qwen_7b_chat
from .models.qwen.hf_qwen_14b_chat import models as hf_qwen_14b_chat
from .models.chatglm.hf_chatglm3_6b import models as hf_chatglm3_6b
from .models.baichuan.hf_baichuan2_7b_chat import models as hf_baichuan2_7b
from .models.hf_internlm.hf_internlm_chat_20b import models as hf_internlm_chat_20b
from .datasets.subjective_cmp.alignment_bench import subjective_datasets
datasets = [*subjective_datasets]
from opencompass.models import HuggingFaceCausalLM, HuggingFace, OpenAIAllesAPIN, HuggingFaceChatGLM3
from opencompass.partitioners import NaivePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import AlignmentBenchSummarizer
# -------------Inference Stage ----------------------------------------
models = [*hf_baichuan2_7b]  # , *hf_chatglm3_6b, *hf_internlm_chat_20b, *hf_qwen_7b_chat, *hf_qwen_14b_chat
infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=SlurmSequentialRunner,
partition='llmeval',
quotatype='auto',
max_num_workers=256,
task=dict(type=OpenICLInferTask)),
)
# -------------Evaluation Stage ----------------------------------------
## ------------- JudgeLLM Configuration
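# Meta template for API-based judge models; the local HuggingFace judge
# configured below does not reference it.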
api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
]
)
judge_model = dict(
type=HuggingFaceCausalLM,
abbr='pandalm-7b-v1-hf',
path="WeOpenML/PandaLM-7B-v1",
tokenizer_path='WeOpenML/PandaLM-7B-v1',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
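# Any judge config under configs/models/judge_llm (e.g. Auto-J, JudgeLM,
# PandaLM) can be substituted for judge_model above.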
## ------------- Evaluation Configuration
eval = dict(
partitioner=dict(
type=SubjectiveNaivePartitioner,
mode='singlescore',
        models=[*hf_baichuan2_7b]
),
runner=dict(
type=LocalRunner,
max_num_workers=2,
task=dict(
type=SubjectiveEvalTask,
judge_cfg=judge_model
)),
)
summarizer = dict(
type=AlignmentBenchSummarizer,
)
work_dir = 'outputs/pandalm'


@@ -0,0 +1,26 @@
from opencompass.models import HuggingFaceCausalLM
'''
This is a bilingual 6B version of Auto-J.
It is trained on both the original training data
and its Chinese translation, and is available at
https://huggingface.co/GAIR/autoj-bilingual-6b
'''
models = [
dict(
type=HuggingFaceCausalLM,
abbr='autoj-bilingual-6b',
path="GAIR/autoj-bilingual-6b",
tokenizer_path='GAIR/autoj-bilingual-6b',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,20 @@
from opencompass.models import HuggingFaceCausalLM
models = [
dict(
type=HuggingFaceCausalLM,
abbr='autoj-13b-GPTQ-4bits',
path="GAIR/autoj-13b-GPTQ-4bits",
tokenizer_path='GAIR/autoj-13b-GPTQ-4bits',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,25 @@
from opencompass.models import HuggingFaceCausalLM
'''
A 4-bit quantized version of Auto-J, produced with AutoGPTQ,
is also available on the Hugging Face Hub:
https://huggingface.co/GAIR/autoj-13b-GPTQ-4bits
'''
models = [
dict(
type=HuggingFaceCausalLM,
abbr='autoj-13b',
path="GAIR/autoj-13b",
tokenizer_path='GAIR/autoj-13b',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,20 @@
from opencompass.models import HuggingFaceCausalLM
models = [
dict(
type=HuggingFaceCausalLM,
abbr='autoj-scenario-classifier',
path="GAIR/autoj-scenario-classifier",
tokenizer_path='GAIR/autoj-scenario-classifier',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,20 @@
from opencompass.models import HuggingFaceCausalLM
models = [
dict(
type=HuggingFaceCausalLM,
abbr='judgelm-13b-v1-hf',
path="BAAI/JudgeLM-13b-v1.0",
tokenizer_path='BAAI/JudgeLM-13b-v1.0',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,20 @@
from opencompass.models import HuggingFaceCausalLM
models = [
dict(
type=HuggingFaceCausalLM,
abbr='judgelm-33b-v1-hf',
path="BAAI/JudgeLM-33b-v1.0",
tokenizer_path='BAAI/JudgeLM-33b-v1.0',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,20 @@
from opencompass.models import HuggingFaceCausalLM
models = [
dict(
type=HuggingFaceCausalLM,
abbr='judgelm-7b-v1-hf',
path="BAAI/JudgeLM-7B-v1.0",
tokenizer_path='BAAI/JudgeLM-7B-v1.0',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,20 @@
from opencompass.models import HuggingFaceCausalLM
models = [
dict(
type=HuggingFaceCausalLM,
abbr='alpaca-pandalm-7b-v1-hf',
path="WeOpenML/PandaLM-Alpaca-7B-v1",
tokenizer_path='WeOpenML/PandaLM-Alpaca-7B-v1',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -0,0 +1,20 @@
from opencompass.models import HuggingFaceCausalLM
models = [
dict(
type=HuggingFaceCausalLM,
abbr='pandalm-7b-v1-hf',
path="WeOpenML/PandaLM-7B-v1",
tokenizer_path='WeOpenML/PandaLM-7B-v1',
tokenizer_kwargs=dict(padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=False,),
max_out_len=512,
max_seq_len=2048,
batch_size=8,
model_kwargs=dict(device_map='auto', trust_remote_code=True),
run_cfg=dict(num_gpus=1, num_procs=1),
)
]


@@ -144,6 +144,64 @@ The `-r` parameter allows the reuse of model inference and GPT-4 evaluation results
The responses of the JudgeLLM are written to `output/.../results/timestamp/xxmodel/xxdataset/.json`,
and the evaluation report to `output/.../summary/timestamp/report.csv`.
OpenCompass supports many JudgeLLMs; in fact, any model available in the OpenCompass configs can serve as a JudgeLLM.
Popular open-source JudgeLLMs are listed below, followed by a minimal configuration sketch:
1. Auto-J, refer to `configs/models/judge_llm/auto_j`
Consider citing the following papers if you find them helpful:
```bibtex
@article{li2023generative,
title={Generative judge for evaluating alignment},
author={Li, Junlong and Sun, Shichao and Yuan, Weizhe and Fan, Run-Ze and Zhao, Hai and Liu, Pengfei},
journal={arXiv preprint arXiv:2310.05470},
year={2023}
}
@misc{2023opencompass,
title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
author={OpenCompass Contributors},
howpublished = {\url{https://github.com/open-compass/opencompass}},
year={2023}
}
```
2. JudgeLM, refer to `configs/models/judge_llm/judgelm`
Consider citing the following papers if you find them helpful:
```bibtex
@article{zhu2023judgelm,
title={JudgeLM: Fine-tuned Large Language Models are Scalable Judges},
author={Zhu, Lianghui and Wang, Xinggang and Wang, Xinlong},
journal={arXiv preprint arXiv:2310.17631},
year={2023}
}
@misc{2023opencompass,
title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
author={OpenCompass Contributors},
howpublished = {\url{https://github.com/open-compass/opencompass}},
year={2023}
}
```
3. PandaLM, refer to `configs/models/judge_llm/pandalm`
Consider citing the following papers if you find them helpful:
```bibtex
@article{wang2023pandalm,
title={PandaLM: An Automatic Evaluation Benchmark for LLM Instruction Tuning Optimization},
author={Wang, Yidong and Yu, Zhuohao and Zeng, Zhengran and Yang, Linyi and Wang, Cunxiang and Chen, Hao and Jiang, Chaoya and Xie, Rui and Wang, Jindong and Xie, Xing and others},
journal={arXiv preprint arXiv:2306.05087},
year={2023}
}
@misc{2023opencompass,
title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
author={OpenCompass Contributors},
howpublished = {\url{https://github.com/open-compass/opencompass}},
year={2023}
}
```
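As a minimal sketch of wiring one of these judges into the evaluation stage (the paths and worker counts below mirror the Auto-J config added in this commit, and a `models` list from the inference stage is assumed to exist; adjust both to your setup):
```python
from opencompass.models import HuggingFaceCausalLM
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks.subjective_eval import SubjectiveEvalTask

# The judge is an ordinary model config; tokenizer_kwargs / model_kwargs are
# trimmed here for brevity (see configs/models/judge_llm/auto_j).
judge_model = dict(
    type=HuggingFaceCausalLM,
    abbr='autoj-13b',
    path='GAIR/autoj-13b',
    tokenizer_path='GAIR/autoj-13b',
    max_out_len=512,
    max_seq_len=2048,
    batch_size=8,
    run_cfg=dict(num_gpus=1, num_procs=1),
)

# The evaluation stage receives the judge through `judge_cfg`.
eval = dict(
    partitioner=dict(
        type=SubjectiveNaivePartitioner,
        mode='singlescore',
        models=models),
    runner=dict(
        type=LocalRunner,
        max_num_workers=2,
        task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
    ),
)
```
Swapping in another judge only requires replacing `judge_model` with the entry from the corresponding config file.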
## Practice: AlignBench Evaluation
### Dataset


@@ -142,6 +142,66 @@ python run.py configs/eval_subjective_score.py -r
The responses of the JudgeLLM are saved to `output/.../results/timestamp/xxmodel/xxdataset/.json`,
and the evaluation report is written to `output/.../summary/timestamp/report.csv`.
OpenCompass already supports many JudgeLLMs; in fact, any model supported by OpenCompass can be used as a JudgeLLM.
The currently popular open-source JudgeLLMs are listed below, followed by a minimal configuration sketch:
1. Auto-J, see `configs/models/judge_llm/auto_j`
If you use this method, please add the citations:
```bibtex
@article{li2023generative,
title={Generative judge for evaluating alignment},
author={Li, Junlong and Sun, Shichao and Yuan, Weizhe and Fan, Run-Ze and Zhao, Hai and Liu, Pengfei},
journal={arXiv preprint arXiv:2310.05470},
year={2023}
}
@misc{2023opencompass,
title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
author={OpenCompass Contributors},
howpublished = {\url{https://github.com/open-compass/opencompass}},
year={2023}
}
```
2. JudgeLM, see `configs/models/judge_llm/judgelm`
If you use this method, please add the citations:
```bibtex
@article{zhu2023judgelm,
title={JudgeLM: Fine-tuned Large Language Models are Scalable Judges},
author={Zhu, Lianghui and Wang, Xinggang and Wang, Xinlong},
journal={arXiv preprint arXiv:2310.17631},
year={2023}
}
@misc{2023opencompass,
title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
author={OpenCompass Contributors},
howpublished = {\url{https://github.com/open-compass/opencompass}},
year={2023}
}
```
3. PandaLM, see `configs/models/judge_llm/pandalm`
If you use this method, please add the citations:
```bibtex
@article{wang2023pandalm,
title={PandaLM: An Automatic Evaluation Benchmark for LLM Instruction Tuning Optimization},
author={Wang, Yidong and Yu, Zhuohao and Zeng, Zhengran and Yang, Linyi and Wang, Cunxiang and Chen, Hao and Jiang, Chaoya and Xie, Rui and Wang, Jindong and Xie, Xing and others},
journal={arXiv preprint arXiv:2306.05087},
year={2023}
}
@misc{2023opencompass,
title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
author={OpenCompass Contributors},
howpublished = {\url{https://github.com/open-compass/opencompass}},
year={2023}
}
```
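As in the English documentation, a minimal sketch of wiring one of these judges into the evaluation stage (the paths and worker counts mirror the Auto-J config added in this commit, and a `models` list from the inference stage is assumed):
```python
from opencompass.models import HuggingFaceCausalLM
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks.subjective_eval import SubjectiveEvalTask

# Judge model config, trimmed for brevity (see configs/models/judge_llm/auto_j).
judge_model = dict(
    type=HuggingFaceCausalLM,
    abbr='autoj-13b',
    path='GAIR/autoj-13b',
    tokenizer_path='GAIR/autoj-13b',
    max_out_len=512,
    max_seq_len=2048,
    batch_size=8,
    run_cfg=dict(num_gpus=1, num_procs=1),
)

# The judge is handed to SubjectiveEvalTask via `judge_cfg`.
eval = dict(
    partitioner=dict(
        type=SubjectiveNaivePartitioner,
        mode='singlescore',
        models=models),
    runner=dict(
        type=LocalRunner,
        max_num_workers=2,
        task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
    ),
)
```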
## Practice: AlignBench Subjective Evaluation
### Dataset Preparation


@@ -8,18 +8,6 @@ from opencompass.registry import PARTITIONERS
from .naive import NaivePartitioner
def remove_duplicate_pairs(model_combinations):
combo_dict = {}
for i, combo in enumerate(model_combinations):
sorted_names = tuple(sorted((combo[0]['abbr'], combo[1]['abbr'])))
if sorted_names not in combo_dict:
combo_dict[sorted_names] = i
new_model_combinations = [
model_combinations[i] for i in combo_dict.values()
]
return new_model_combinations
@PARTITIONERS.register_module()
class SubjectiveNaivePartitioner(NaivePartitioner):
"""Naive task partitioner for subjective evaluation. Compared to
@@ -47,6 +35,17 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
self.compare_models = compare_models
self.model_pairs = model_pairs
def remove_duplicate_pairs(self, model_combinations):
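        """Deduplicate model pairs that differ only in order, keeping the
        first occurrence of each unordered (abbr_a, abbr_b) combination."""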
combo_dict = {}
for i, combo in enumerate(model_combinations):
sorted_names = tuple(sorted((combo[0]['abbr'], combo[1]['abbr'])))
if sorted_names not in combo_dict:
combo_dict[sorted_names] = i
new_model_combinations = [
model_combinations[i] for i in combo_dict.values()
]
return new_model_combinations
def get_model_combinations(
self,
models: List[ConfigDict],
@@ -58,7 +57,7 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
elif self.mode == 'm2n':
assert len(base_models) > 0 and len(compare_models) > 0
model_combinations = list(product(base_models, compare_models))
unique_combinations = remove_duplicate_pairs([
unique_combinations = self.remove_duplicate_pairs([
combo for combo in model_combinations if combo[0] != combo[1]
])
return unique_combinations


@@ -0,0 +1,245 @@
import copy
import math
import os.path as osp
from fnmatch import fnmatch
from typing import Dict, List, Optional, Tuple, Union
import mmengine
from mmengine.config import Config, ConfigDict
from opencompass.registry import PARTITIONERS
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_infer_output_path)
from .sub_naive import SubjectiveNaivePartitioner
@PARTITIONERS.register_module()
class SubjectiveSizePartitioner(SubjectiveNaivePartitioner):
"""Task partitioner based on the size of the dataset (with some rough
expansion as an estimation of computational cost).
Args:
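        mode (str): The subjective evaluation mode, e.g. 'singlescore' or
            'm2n'; forwarded to SubjectiveNaivePartitioner.
        models (list[ConfigDict]): The models to be evaluated. Defaults to [].
        base_models (list[ConfigDict]): The base (reference) models for
            pairwise comparison modes. Defaults to [].
        compare_models (list[ConfigDict]): The models compared against the
            base models in pairwise modes. Defaults to [].
        model_pairs (list[tuple], optional): Explicit model pairs to
            evaluate. Defaults to None.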
out_dir (str): The output directory of tasks.
max_task_size (int): The maximum size of a task.
gen_task_coef (int): The dataset cost measurement coefficient for
generation tasks.
strategy (str): The partition strategy. Supported strategies are:
'heuristic' and 'split'. Defaults to 'heuristic'.
heuristic: split large datasets into several tasks, merge small
datasets into one task.
split: split large datasets into several tasks only.
dataset_size_path (str): The path to the dataset size cache file.
keep_keys (list[str]): The keys to be kept from the experiment config
to the task config.
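
    Example:
        A minimal eval-stage config using this partitioner (values are
        illustrative; ``out_dir`` is injected by the framework at runtime)::

            eval = dict(
                partitioner=dict(
                    type=SubjectiveSizePartitioner,
                    mode='singlescore',
                    models=models,
                    max_task_size=40000,
                    strategy='split'),
                runner=runner)  # runner config as elsewhere in the repo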
"""
def __init__(self,
mode: str,
out_dir: str,
models: Optional[List[ConfigDict]] = [],
base_models: Optional[List[ConfigDict]] = [],
compare_models: Optional[List[ConfigDict]] = [],
model_pairs: Optional[List[Tuple]] = None,
max_task_size: int = 40000,
gen_task_coef: int = 20,
strategy: str = 'heuristic',
dataset_size_path: str = '.cache/dataset_size.json',
keep_keys: Optional[List[str]] = None):
super().__init__(out_dir=out_dir,
keep_keys=keep_keys,
mode=mode,
models=models,
base_models=base_models,
compare_models=compare_models,
model_pairs=model_pairs)
self.max_task_size = max_task_size
self.gen_task_coef = gen_task_coef
self.dataset_size_path = dataset_size_path
assert strategy in ('heuristic', 'split'), \
f'Unsupported partition strategy: {strategy}. '\
'Supported strategies are: `heuristic`, `split` .'
self.strategy = strategy
def partition(self,
models: List[ConfigDict],
datasets: List[ConfigDict],
work_dir: str,
out_dir: str,
add_cfg: Dict = {}) -> List[ConfigDict]:
"""Partition model-dataset pairs into tasks. Each task is defined as a
dict and will run independently as a unit. Its structure is as
follows:
.. code-block:: python
{
'models': [], # a list of model configs
                'datasets': [[]],  # a nested list of dataset configs, each
                                   # list corresponds to a model
'work_dir': '', # the work dir
**add_cfg # other keys to be kept in the config
}
Args:
models (List[ConfigDict]): A list of model configs.
datasets (List[ConfigDict]): A list of dataset configs.
work_dir (str): The work dir for the task.
out_dir (str): The full output path for the task, intended for
Partitioners to check whether the task is finished via the
                existence of the result file in this directory.
add_cfg (dict): Other common keys to be added in the task config,
used to share the same config among tasks. Defaults to {}.
Returns:
List[ConfigDict]: A list of tasks.
"""
models = self.models if self.models != [] else models
base_models, compare_models = self.base_models, self.compare_models
        if self.mode == 'singlescore':
            # Single-model scoring: each model is judged on its own.
            pass
        else:
            # Pairwise modes: expand the model list into combinations.
models = super().get_model_combinations(models, base_models,
compare_models)
model_dataset_combinations = [{'models': models, 'datasets': datasets}]
tasks = []
for comb in model_dataset_combinations:
comb['datasets'] = sorted(comb['datasets'],
key=lambda x: self.get_cost(x),
reverse=True)
for model in comb['models']:
chunks = [] # elements: tuple(size, dataset_chunk)
for dataset in comb['datasets']:
filename = get_infer_output_path(model, dataset, out_dir)
# skip the task if the task output exists
if osp.exists(filename):
continue
dataset_size = self.get_cost(dataset)
if dataset_size > self.max_task_size:
root, ext = osp.splitext(filename)
dataset_splits = self.split_dataset(dataset)
for i, dataset_split in enumerate(dataset_splits):
if not osp.exists(f'{root}_{i}{ext}'):
chunks.append(
(self.max_task_size, dataset_split))
else:
chunks.append((dataset_size, dataset))
if self.strategy == 'heuristic':
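                    # Greedy packing: with chunks sorted largest-first, keep
                    # appending chunks to the current task until adding the
                    # next one would exceed max_task_size.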
chunks = sorted(chunks, key=lambda x: x[0], reverse=True)
current_size, current_chunks = 0, []
for index in range(len(chunks)):
current_size += chunks[index][0]
current_chunks.append(chunks[index][1])
if index == len(chunks) - 1 or current_size + chunks[
index + 1][0] > self.max_task_size:
tasks.append(
Config({
'models': [model],
'datasets': [current_chunks],
'work_dir': work_dir,
**add_cfg
}))
current_size, current_chunks = 0, []
elif self.strategy == 'split':
for _, dataset in chunks:
tasks.append(
Config({
'models': [model],
'datasets': [[dataset]],
'work_dir': work_dir,
**add_cfg
}))
return tasks
@property
def dataset_size(self):
if not hasattr(self, '_dataset_size'):
if osp.exists(self.dataset_size_path):
self._dataset_size = mmengine.load(self.dataset_size_path)
else:
self._dataset_size = {}
return self._dataset_size
def split_dataset(self, dataset_cfg: ConfigDict) -> List[ConfigDict]:
"""Split dataset into several parts."""
dataset_size, num_repeats = self.get_cost(dataset_cfg,
get_raw_factors=True)
split_configs = []
abbr = dataset_abbr_from_cfg(dataset_cfg)
step = self.max_task_size // num_repeats
        # Re-balance the step so examples are spread evenly across splits
        # (no split exceeds max_task_size // num_repeats).
step = math.ceil(dataset_size / math.ceil(dataset_size / step))
for part, i in enumerate(range(0, dataset_size, step)):
cfg = copy.deepcopy(dataset_cfg)
cfg['abbr'] = abbr + f'_{part}'
test_range = cfg['reader_cfg'].get('test_range', '')
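            # Chain a new slice onto any existing test_range, e.g.
            # '' + '[0:500]' or '[0:1000]' + '[500:1000]'.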
cfg['reader_cfg']['test_range'] = f'{test_range}[{i}:{i+step}]'
split_configs.append(cfg)
return split_configs
def get_factor(self, dataset: ConfigDict) -> int:
infer_cfg = dataset.infer_cfg
template = (infer_cfg.prompt_template.template if 'prompt_template'
in infer_cfg else infer_cfg.ice_template.template)
        # For a Gen template, the dataset size is multiplied by
        # self.gen_task_coef.
        factor = self.gen_task_coef
        # For a PPL template, the dataset size is multiplied by the number of
        # labels.
if isinstance(template, dict):
ctr = sum(key in template for key in ('begin', 'round', 'end'))
if ctr != len(template.keys()):
factor = len(template.keys())
dataset_abbr = dataset_abbr_from_cfg(dataset)
if any(
fnmatch(dataset_abbr, pattern)
for pattern in ('bbh*', 'gsm8k*', 'math*', 'strategyqa*',
'agieval-jec*', 'agieval-gaokao-mathcloze',
'agieval-math', '*professional_law')):
factor *= 10
return factor
def get_cost(self,
dataset: ConfigDict,
get_raw_factors: bool = False) -> Union[int, Tuple[int, int]]:
"""Get the computational cost of inferring on the dataset.
Args:
dataset (ConfigDict): The dataset config.
get_raw_factors (bool): If True, the raw factors of computational
cost will be returned.
Returns:
int or Tuple[int, int]: The size of the dataset. If get_raw_factors
is True, the number of repeats will also be returned.
"""
dataset_abbr = dataset_abbr_from_cfg(dataset)
test_range = dataset.reader_cfg.get('test_range', '')
factor = self.get_factor(dataset)
if dataset_abbr in self.dataset_size:
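            # `test_range` is a slice expression such as '[0:100]'; applying
            # it to range(cached_size) via eval() yields the sliced length
            # without loading the dataset.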
actual_size = eval('len(range(self.dataset_size[dataset_abbr])'
f'{test_range})')
if get_raw_factors:
return actual_size, factor
return factor * actual_size
dataset = build_dataset_from_cfg(dataset)
self.dataset_size[dataset_abbr] = len(dataset.test)
mmengine.mkdir_or_exist('.cache/')
mmengine.dump(self.dataset_size,
self.dataset_size_path,
indent=4,
ensure_ascii=False)
actual_size = eval('len(range(self.dataset_size[dataset_abbr])'
f'{test_range})')
if get_raw_factors:
return actual_size, factor
return factor * actual_size