Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[fix] add different temp for different question in mtbench (#954)

* add temp for mtbench
* add document for mtbench

This commit is contained in:
parent 7c1a819bb4
commit 848e7c8a76
@@ -0,0 +1,64 @@

```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import MTBenchDataset

subjective_reader_cfg = dict(
    input_columns=['dialogue', 'capability', 'system_prompt', 'prompt_template'],
    output_column='judge',
)

subjective_all_sets = [
    "mtbench_0.0", "mtbench_0.1", "mtbench_0.7"
]
data_path = "data/subjective/mtbench"

subjective_datasets = []

for _name in subjective_all_sets:
    temperature = float(_name.split('_')[1])
    do_sample = False if temperature == 0.0 else True

    subjective_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template="""{dialogue}""",
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=ChatInferencer, max_seq_len=4096, max_out_len=512,
                        temperature=temperature, do_sample=do_sample, infer_mode='every'),
    )

    subjective_eval_cfg = dict(
        evaluator=dict(
            type=LMEvaluator,
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(
                    begin=[
                        dict(
                            role='SYSTEM',
                            fallback_role='HUMAN',
                            prompt="{system_prompt}")
                    ],
                    round=[
                        dict(
                            role='HUMAN',
                            prompt="{prompt_template}"
                        ),
                    ]),
            ),
        ),
        pred_role="BOT",
    )

    subjective_datasets.append(
        dict(
            abbr=f"{_name}",
            type=MTBenchDataset,
            path=data_path,
            name=_name,
            reader_cfg=subjective_reader_cfg,
            infer_cfg=subjective_infer_cfg,
            eval_cfg=subjective_eval_cfg
        ))
```
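The new config above (imported later in this commit as `.datasets.subjective.multiround.mtbench_single_judge_diff_temp`) derives the sampling settings for each subset from its name. A minimal sketch, not part of the commit, of what that loop computes:

```python
# Illustration only: the temperature / do_sample values the loop above
# derives from the three subset names.
subjective_all_sets = ["mtbench_0.0", "mtbench_0.1", "mtbench_0.7"]

for _name in subjective_all_sets:
    temperature = float(_name.split('_')[1])  # "mtbench_0.7" -> 0.7
    do_sample = temperature != 0.0            # greedy decoding only for the 0.0 subset
    print(f"{_name}: temperature={temperature}, do_sample={do_sample}")

# mtbench_0.0: temperature=0.0, do_sample=False
# mtbench_0.1: temperature=0.1, do_sample=True
# mtbench_0.7: temperature=0.7, do_sample=True
```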
```diff
@@ -1,7 +1,7 @@
 from mmengine.config import read_base

 with read_base():
-    from .datasets.subjective.multiround.mtbench_single_judge import subjective_datasets
+    from .datasets.subjective.multiround.mtbench_single_judge_diff_temp import subjective_datasets
     # from .datasets.subjective.multiround.mtbench_pair_judge import subjective_datasets

 from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3
```
```diff
@@ -23,38 +23,44 @@ api_meta_template = dict(
     ]
 )

+_meta_template = dict(
+    round=[
+        dict(role="HUMAN", begin='\n<|im_start|>user\n', end='<|im_end|>'),
+        dict(role="BOT", begin="\n<|im_start|>assistant\n", end='<|im_end|>', generate=True),
+    ],
+)
 # -------------Inference Stage ----------------------------------------
 # For subjective evaluation, we often set do sample for models
 models = [
     dict(
-        type=HuggingFaceChatGLM3,
-        abbr='chatglm3-6b-hf',
-        path='THUDM/chatglm3-6b',
-        tokenizer_path='THUDM/chatglm3-6b',
+        type=HuggingFaceCausalLM,
+        abbr='qwen-7b-chat-hf',
+        path="Qwen/Qwen-7B-Chat",
+        tokenizer_path='Qwen/Qwen-7B-Chat',
         model_kwargs=dict(
             device_map='auto',
-            trust_remote_code=True,
+            trust_remote_code=True
         ),
         tokenizer_kwargs=dict(
             padding_side='left',
             truncation_side='left',
             trust_remote_code=True,
+            use_fast=False,
         ),
-        generation_kwargs=dict(
-            do_sample=True,
-        ),
-        meta_template=api_meta_template,
-        max_out_len=2048,
-        max_seq_len=4096,
-        batch_size=1,
+        pad_token_id=151643,
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=8,
+        meta_template=_meta_template,
         run_cfg=dict(num_gpus=1, num_procs=1),
+        end_str='<|im_end|>',
     )
 ]

 datasets = [*subjective_datasets]

 infer = dict(
-    partitioner=dict(type=SizePartitioner, max_task_size=10000),
+    partitioner=dict(type=SizePartitioner, strategy='split', max_task_size=10000),
     runner=dict(
         type=SlurmSequentialRunner,
         partition='llm_dev2',
```
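The hunk above switches the evaluated model to Qwen-7B-Chat and gives it a ChatML-style `_meta_template` plus `end_str='<|im_end|>'`. A minimal sketch of the prompt string such a template yields for one user turn; this is a simplified assumption for illustration, not OpenCompass's actual template-rendering code:

```python
# Simplified illustration (assumption: real OpenCompass rendering may differ):
# how the ChatML-style _meta_template above wraps a single user turn.
human = dict(role="HUMAN", begin='\n<|im_start|>user\n', end='<|im_end|>')
bot = dict(role="BOT", begin="\n<|im_start|>assistant\n", end='<|im_end|>', generate=True)

def render(user_message: str) -> str:
    # The generating role contributes only its begin marker; generation is
    # expected to stop at end_str ('<|im_end|>').
    return human['begin'] + user_message + human['end'] + bot['begin']

print(repr(render("What is 2 + 2?")))
# '\n<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n<|im_start|>assistant\n'
```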
```diff
@@ -80,7 +86,6 @@ judge_model = dict(
     batch_size=8,
     temperature=0,
 )
-
 ## ------------- Evaluation Configuration
 # ## pair evaluation
 # eval = dict(
```
```diff
@@ -95,7 +100,7 @@ judge_model = dict(

 ## single evaluation
 eval = dict(
-    partitioner=dict(type=SubjectiveSizePartitioner, max_task_size=10000, mode='singlescore', models=models),
+    partitioner=dict(type=SubjectiveSizePartitioner, strategy='split', max_task_size=10000, mode='singlescore', models=models),
     runner=dict(type=LocalRunner, max_num_workers=32, task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model)),
 )

```
````diff
@@ -202,6 +202,54 @@ Consider cite the following paper if you find it helpful:
 }
 ```

+## Multi-round Subjective Evaluation in OpenCompass
+
+In OpenCompass, we also support subjective multi-turn dialogue evaluation; for instance, the MT-Bench evaluation can be found in `configs/eval_subjective_mtbench.py`.
+
+For multi-turn dialogue evaluation, you need to organize the data into the following dialogue structure:
+
+```
+"dialogue": [
+    {
+        "role": "user",
+        "content": "Imagine you are participating in a race with a group of people. If you have just overtaken the second person, what's your current position? Where is the person you just overtook?"
+    },
+    {
+        "role": "assistant",
+        "content": ""
+    },
+    {
+        "role": "user",
+        "content": "If the \"second person\" is changed to \"last person\" in the above question, what would the answer be?"
+    },
+    {
+        "role": "assistant",
+        "content": ""
+    }
+],
+```
+
+Note that the different question types in MT-Bench use different temperature settings, so the original data file is split into three subsets by temperature and inference is run on each subset separately, each with its own temperature. For the specific settings, please refer to `configs/datasets/subjective/multiround/mtbench_single_judge_diff_temp.py`.
+
+Consider citing the following paper if you find it helpful:
+
+```bibtex
+@misc{zheng2023judging,
+    title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
+    author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
+    year={2023},
+    eprint={2306.05685},
+    archivePrefix={arXiv},
+    primaryClass={cs.CL}
+}
+@misc{2023opencompass,
+    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
+    author={OpenCompass Contributors},
+    howpublished = {\url{https://github.com/open-compass/opencompass}},
+    year={2023}
+}
+```
+
 ## Practice: AlignBench Evaluation

 ### Dataset
````
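As context for the note above about splitting the data by temperature, here is a minimal sketch of how such a split could be produced. The per-category temperature map follows the convention of the original MT-Bench (FastChat) implementation, and the file names, layout, and field names are assumptions for illustration, not part of this commit:

```python
import json

# Assumed per-category temperatures (the original MT-Bench convention);
# adjust if your categories differ.
TEMPERATURE_BY_CATEGORY = {
    'writing': 0.7, 'roleplay': 0.7,
    'stem': 0.1, 'humanities': 0.1,
    'reasoning': 0.0, 'math': 0.0, 'coding': 0.0, 'extraction': 0.0,
}

def split_by_temperature(questions):
    """Group questions into one bucket per temperature value."""
    buckets = {}
    for q in questions:
        # 'capability' follows the reader config above; the actual file
        # schema may differ.
        temp = TEMPERATURE_BY_CATEGORY.get(q['capability'], 0.0)
        buckets.setdefault(temp, []).append(q)
    return buckets

# Hypothetical usage: one JSON file per subset, named like the subsets in the
# config above (mtbench_0.0 / mtbench_0.1 / mtbench_0.7).
with open('mtbench.json') as f:            # assumed source file
    questions = json.load(f)
for temp, subset in split_by_temperature(questions).items():
    with open(f'mtbench_{temp}.json', 'w') as f:
        json.dump(subset, f, ensure_ascii=False, indent=2)
```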
````diff
@@ -202,6 +202,54 @@ Opencompass 已经支持了很多的JudgeLLM,实际上,你可以将Opencompa
 }
 ```

+## Multi-round Subjective Dialogue Evaluation
+
+OpenCompass also supports subjective multi-turn dialogue evaluation. Taking MT-Bench as an example, the evaluation can be found in `configs/eval_subjective_mtbench.py`.
+
+For multi-turn dialogue evaluation, you need to organize the data into the following dialogue format:
+
+```
+"dialogue": [
+    {
+        "role": "user",
+        "content": "Imagine you are participating in a race with a group of people. If you have just overtaken the second person, what's your current position? Where is the person you just overtook?"
+    },
+    {
+        "role": "assistant",
+        "content": ""
+    },
+    {
+        "role": "user",
+        "content": "If the \"second person\" is changed to \"last person\" in the above question, what would the answer be?"
+    },
+    {
+        "role": "assistant",
+        "content": ""
+    }
+],
+```
+
+Note that because the different question types in MT-Bench use different temperatures, the original data file has to be split into three subsets by temperature so that each subset can be inferred separately with its own temperature. For the specific settings, see `configs/datasets/subjective/multiround/mtbench_single_judge_diff_temp.py`.
+
+If you use this method, please add the following citations:
+
+```bibtex
+@misc{zheng2023judging,
+    title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
+    author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
+    year={2023},
+    eprint={2306.05685},
+    archivePrefix={arXiv},
+    primaryClass={cs.CL}
+}
+@misc{2023opencompass,
+    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
+    author={OpenCompass Contributors},
+    howpublished = {\url{https://github.com/open-compass/opencompass}},
+    year={2023}
+}
+```
+
 ## Practice: AlignBench Subjective Evaluation

 ### Dataset Preparation
````
```diff
@@ -172,6 +172,8 @@ class ChatInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = 1,
+            temperature: Optional[float] = 0.0,
+            do_sample: Optional[bool] = False,
             infer_mode: str = 'last',
             **kwargs) -> None:
         super().__init__(
```
```diff
@@ -182,6 +184,8 @@ class ChatInferencer(BaseInferencer):
         )
         assert infer_mode in ['last', 'every', 'every_with_gt']
         self.infer_mode = infer_mode
+        self.temperature = temperature
+        self.do_sample = do_sample
         self.model: BaseModel
         self._set_meta_template(self.model)

```
```diff
@@ -347,8 +351,16 @@ class ChatInferencer(BaseInferencer):

         for i in assistant_indices:
             history = chat[:i]
-            output = self.model.generate_from_template([history],
-                                                        max_out_len=512)[0]
+            if self.do_sample:
+                output = self.model.generate_from_template(
+                    [history],
+                    do_sample=self.do_sample,
+                    temperature=self.temperature,
+                    max_out_len=512)[0]
+            else:
+                output = self.model.generate_from_template([history],
+                                                            do_sample=False,
+                                                            max_out_len=512)[0]
             chat[i]['content'] = output
             if not self.dialogue_mode:
                 output_handler.save_multiround_results(
```
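An illustrative sketch, not OpenCompass code, of the decoding behaviour the branch above produces for each MT-Bench subset: greedy decoding when `do_sample` is False, temperature sampling otherwise.

```python
# Which keyword arguments end up being passed for each subset.
def kwargs_for(temperature: float, do_sample: bool) -> dict:
    if do_sample:
        return dict(do_sample=True, temperature=temperature, max_out_len=512)
    # temperature is deliberately omitted: greedy decoding for the 0.0 subset.
    return dict(do_sample=False, max_out_len=512)

for name, temp in [("mtbench_0.0", 0.0), ("mtbench_0.1", 0.1), ("mtbench_0.7", 0.7)]:
    print(name, kwargs_for(temp, do_sample=temp != 0.0))
# mtbench_0.0 {'do_sample': False, 'max_out_len': 512}
# mtbench_0.1 {'do_sample': True, 'temperature': 0.1, 'max_out_len': 512}
# mtbench_0.7 {'do_sample': True, 'temperature': 0.7, 'max_out_len': 512}
```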
```diff
@@ -127,12 +127,16 @@ class MTBenchSummarizer(CompassArenaSummarizer):
                 fout = osp.join(
                     output_dir,
                     'judged-by--' + judge_model + '-capability.csv')
+                overall_judged_answers, overall_references = [], []
                 for dataset in dataset_cfgs:
                     judged_answers, references = get_judgeanswer_and_reference(
                         dataset, subdir_path, self.judge_function)
-                    get_capability_results(judged_answers, references,
-                                           fout, fout_flag, model)
-                    fout_flag += 1
+                    overall_judged_answers += judged_answers
+                    overall_references += references
+                get_capability_results(overall_judged_answers,
+                                       overall_references, fout, fout_flag,
+                                       model)
+                fout_flag += 1
             else:
                 print(subdir_path + ' is not exist! please check!')
         with open(fout, 'r') as f:
```
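The summarizer change above pools judged answers across all datasets (here, the three temperature subsets) before computing capability results once, instead of writing per-dataset results. A minimal sketch, not OpenCompass code, of why pooling first matters: each subset covers only some capabilities, so per-subset tables would be incomplete.

```python
# Hypothetical judged scores per temperature subset, for illustration only.
subset_results = {
    'mtbench_0.0': [('math', 9), ('coding', 7)],
    'mtbench_0.1': [('stem', 8)],
    'mtbench_0.7': [('writing', 6), ('roleplay', 8)],
}

overall = []
for answers in subset_results.values():
    overall += answers  # mirrors overall_judged_answers += judged_answers

by_capability = {}
for capability, score in overall:
    by_capability.setdefault(capability, []).append(score)
print({k: sum(v) / len(v) for k, v in by_capability.items()})
# One complete per-capability table across all three subsets.
```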