Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support AlpacaEval_V2 (#1006)

* support alpacaeval_v2
* support alpacaeval
* update docs

parent 0a6a03fe1a
commit 02e7eec911
@@ -90,7 +90,7 @@ for _name in subjective_all_sets:
         dict(
             abbr=f"{_name}",
             type=SubjectiveCmpDataset,
-            path="./data/subjective/",
+            path="./data/subjective/alpaca_eval",
             name=_name,
             reader_cfg=subjective_reader_cfg,
             infer_cfg=subjective_infer_cfg,

@@ -92,7 +92,7 @@ for _name in subjective_all_sets:
         dict(
             abbr=f"{_name}",
             type=SubjectiveCmpDataset,
-            path="./data/subjective/",
+            path="./data/subjective/alpaca_eval",
             name=_name,
             reader_cfg=subjective_reader_cfg,
             infer_cfg=subjective_infer_cfg,
@@ -1,7 +1,6 @@
 from mmengine.config import read_base

 with read_base():
     from .datasets.subjective.alpaca_eval.alpacav1_judgeby_gpt4 import subjective_datasets as alpacav1
     from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import subjective_datasets as alpacav2

 from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3
@@ -12,7 +11,7 @@ from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
 from opencompass.runners import LocalRunner
 from opencompass.runners import SlurmSequentialRunner
 from opencompass.tasks import OpenICLInferTask
-from opencompass.tasks.subjective_eval import SubjectiveEvalTask
+from opencompass.tasks.outer_eval.alpacaeval import AlpacaEvalTask
 from opencompass.summarizers import AlpacaSummarizer

 api_meta_template = dict(
@@ -29,7 +28,7 @@ api_meta_template = dict(
 models = [
     dict(
         type=HuggingFaceChatGLM3,
-        abbr='chatglm3-6b-hf',
+        abbr='chatglm3-6b',
         path='THUDM/chatglm3-6b',
         tokenizer_path='THUDM/chatglm3-6b',
         model_kwargs=dict(
@@ -54,52 +53,25 @@ models = [

 datasets = [*alpacav2]

-gpt4 = dict(
-    abbr='gpt4-turbo',
-    type=OpenAI,
-    path='gpt-4-1106-preview',
-    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
-    meta_template=api_meta_template,
-    query_per_second=1,
-    max_out_len=2048,
-    max_seq_len=4096,
-    batch_size=4,
-    retry=20,
-    temperature=1,
-)  # Re-infer gpt4's predictions, or use the pre-committed gpt4 predictions
-
-
 # -------------Evaluation Stage ----------------------------------------

 ## ------------- JudgeLLM Configuration
-judge_model = dict(
+gpt4_judge = dict(
     abbr='GPT4-Turbo',
-    type=OpenAI,
     path='gpt-4-1106-preview',
     key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
-    meta_template=api_meta_template,
-    query_per_second=1,
-    max_out_len=1024,
-    max_seq_len=4096,
-    batch_size=2,
-    retry=20,
-    temperature=0,
+    config='weighted_alpaca_eval_gpt4_turbo'
 )

 ## ------------- Evaluation Configuration
 eval = dict(
     partitioner=dict(
-        type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
+        type=NaivePartitioner
     ),
     runner=dict(
-        type=SlurmSequentialRunner,
-        partition='llmeval',
-        quotatype='auto',
+        type=LocalRunner,
         max_num_workers=256,
-        task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
+        task=dict(type=AlpacaEvalTask, judge_cfg=gpt4_judge),
     ),
 )
 work_dir = 'outputs/alpaca/'

 summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
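Pieced together, the '+' side of this file reduces to roughly the following evaluation-stage config. This is a sketch assembled from the added lines above, not the verbatim file; the inference-stage `models` list is unchanged and omitted here.

```python
# Sketch assembled from the '+' lines of the diff above.
from mmengine.config import read_base

with read_base():
    from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import \
        subjective_datasets as alpacav2

from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.summarizers import AlpacaSummarizer
from opencompass.tasks.outer_eval.alpacaeval import AlpacaEvalTask

datasets = [*alpacav2]

# The judge entry no longer needs a full OpenAI model config: AlpacaEvalTask
# only reads 'key' and the alpaca-eval annotators_config name from it.
gpt4_judge = dict(
    abbr='GPT4-Turbo',
    path='gpt-4-1106-preview',
    key='',  # taken from $OPENAI_API_KEY if left empty
    config='weighted_alpaca_eval_gpt4_turbo',
)

eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=256,
        task=dict(type=AlpacaEvalTask, judge_cfg=gpt4_judge),
    ),
)
work_dir = 'outputs/alpaca/'
summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
```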
configs/eval_subjective_alpacaeval_oc.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from mmengine.config import read_base

with read_base():
    from .datasets.subjective.alpaca_eval.alpacav1_judgeby_gpt4 import subjective_datasets as alpacav1
    from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import subjective_datasets as alpacav2

from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3
from opencompass.models.openai_api import OpenAI, OpenAIAllesAPIN
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import AlpacaSummarizer

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
)

# -------------Inference Stage ----------------------------------------

# For subjective evaluation, we often enable sampling (do_sample) for models
models = [
    dict(
        type=HuggingFaceChatGLM3,
        abbr='chatglm3-6b-hf',
        path='THUDM/chatglm3-6b',
        tokenizer_path='THUDM/chatglm3-6b',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        generation_kwargs=dict(
            do_sample=True,
        ),
        meta_template=api_meta_template,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=1,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]

datasets = [*alpacav2]

gpt4 = dict(
    abbr='gpt4-turbo',
    type=OpenAI,
    path='gpt-4-1106-preview',
    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
    meta_template=api_meta_template,
    query_per_second=1,
    max_out_len=2048,
    max_seq_len=4096,
    batch_size=4,
    retry=20,
    temperature=1,
)  # Re-infer gpt4's predictions, or use the pre-committed gpt4 predictions

# -------------Evaluation Stage ----------------------------------------

## ------------- JudgeLLM Configuration
judge_model = dict(
    abbr='GPT4-Turbo',
    type=OpenAI,
    path='gpt-4-1106-preview',
    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
    meta_template=api_meta_template,
    query_per_second=1,
    max_out_len=1024,
    max_seq_len=4096,
    batch_size=2,
    retry=20,
    temperature=0,
)

## ------------- Evaluation Configuration
eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
    ),
    runner=dict(
        type=SlurmSequentialRunner,
        partition='llmeval',
        quotatype='auto',
        max_num_workers=256,
        task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
    ),
)
work_dir = 'outputs/alpaca/'

summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
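Either config is launched through OpenCompass's standard entry point, e.g. `python run.py configs/eval_subjective_alpacaeval_oc.py`; the `run.py` workflow itself is untouched by this commit.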
@@ -7,14 +7,24 @@ with read_base():
 from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
 from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
 from opencompass.summarizers import AlpacaSummarizer
+from opencompass.tasks.outer_eval.alpacaeval import AlpacaEvalTask

 datasets = [*alpacav2]

+gpt4_judge = dict(
+    abbr='GPT4-Turbo',
+    path='gpt-4-1106-preview',
+    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
+    config='weighted_alpaca_eval_gpt4_turbo'
+)

 ## ------------- Evaluation Configuration
 eval = dict(
     partitioner=dict(
-        type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
+        type=NaivePartitioner
     ),
-    runner=runner,
-    given_pred=given_pred
+    runner=dict(
+        type=LocalRunner,
+        max_num_workers=256,
+        task=dict(type=AlpacaEvalTask, judge_cfg=gpt4_judge),
+    )
 )
 work_dir = 'outputs/alpaca/'

 summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
@@ -13,6 +13,13 @@ A popular evaluation method involves

 We support the use of GPT-4 (or another JudgeLLM) for the subjective evaluation of models based on the above methods.

+## Currently Supported Subjective Evaluation Datasets
+
+1. AlignBench (https://github.com/THUDM/AlignBench)
+2. MTBench (https://github.com/lm-sys/FastChat)
+3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
+4. CompassArena (internal dataset)
+
 ## Subjective Evaluation with Custom Dataset

 The specific process includes:
@@ -72,6 +72,19 @@

 </details>

+5. Install alpaca-eval (optional):
+
+   If you want to **evaluate alpaca-eval in the official way**, follow this step.
+
+   <details>
+   <summary><b>click to show the details</b></summary>
+
+   ```bash
+   pip install alpaca-eval
+   ```
+
+   </details>
+
 # Dataset Preparation

 The datasets supported by OpenCompass mainly include two parts:
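Since `AlpacaEvalTask` shells out to the `alpaca_eval` console script (see the task source later in this commit), a quick way to confirm the optional install worked is a sketch like the following; it is illustrative, not part of the docs:

```python
# Sanity check: the `alpaca_eval` CLI must be on PATH for AlpacaEvalTask.
import shutil

if shutil.which('alpaca_eval') is None:
    raise RuntimeError('alpaca-eval not found; run `pip install alpaca-eval`')
```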
@@ -13,6 +13,13 @@

 Based on the methods above, we support using a JudgeLLM for subjective evaluation of model capabilities (all models currently supported in the opencompass repository can be called directly as a JudgeLLM, and we also plan to support some dedicated JudgeLLMs).

+## Currently Supported Subjective Evaluation Datasets
+
+1. AlignBench (https://github.com/THUDM/AlignBench)
+2. MTBench (https://github.com/lm-sys/FastChat)
+3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
+4. CompassArena (internal dataset)
+
 ## Subjective Evaluation with Custom Datasets

 The specific process of subjective evaluation includes:
@@ -73,6 +73,19 @@

 </details>

+5. Install alpaca-eval (optional):
+
+   If you need to **evaluate the alpaca-eval dataset with the official alpaca-eval implementation**, follow this step; otherwise skip it.
+
+   <details>
+   <summary><b>click to show the details</b></summary>
+
+   ```bash
+   pip install alpaca-eval
+   ```
+
+   </details>
+
 # Dataset Preparation

 The datasets supported by OpenCompass mainly include two parts:
opencompass/models/openai_api.py

@@ -65,6 +65,8 @@ class OpenAI(BaseAPIModel):
                  meta_template: Optional[Dict] = None,
                  openai_api_base: str = OPENAI_API_BASE,
                  mode: str = 'none',
+                 logprobs: Optional[bool] = False,
+                 top_logprobs: Optional[int] = None,
                  temperature: Optional[float] = None):

         super().__init__(path=path,
@@ -78,6 +80,8 @@ class OpenAI(BaseAPIModel):
         self.temperature = temperature
         assert mode in ['none', 'front', 'mid', 'rear']
         self.mode = mode
+        self.logprobs = logprobs
+        self.top_logprobs = top_logprobs

         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -218,6 +222,8 @@ class OpenAI(BaseAPIModel):
                 messages=messages,
                 max_tokens=max_out_len,
                 n=1,
+                logprobs=self.logprobs,
+                top_logprobs=self.top_logprobs,
                 stop=None,
                 temperature=temperature,
             )
@@ -234,7 +240,10 @@ class OpenAI(BaseAPIModel):
                     str(raw_response.content))
                 continue
             try:
-                return response['choices'][0]['message']['content'].strip()
+                if self.logprobs:
+                    return response['choices']
+                else:
+                    return response['choices'][0]['message']['content'].strip()
             except KeyError:
                 if 'error' in response:
                     if response['error']['code'] == 'rate_limit_exceeded':
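For illustration, here is a hedged sketch of how the new switches might be used in a model config. The `logprobs` and `top_logprobs` argument names come from the diff above; the abbr and the dict-style layout follow the repo's usual configs and are otherwise assumptions:

```python
# Sketch only: exercising the new logprobs options of the OpenAI wrapper.
from opencompass.models.openai_api import OpenAI

model = dict(
    type=OpenAI,
    abbr='gpt4-turbo-logprobs',  # hypothetical abbr
    path='gpt-4-1106-preview',
    key='ENV',                   # read from $OPENAI_API_KEY
    logprobs=True,               # new: request token logprobs
    top_logprobs=5,              # new: top-k alternatives per token
)
# Note: with logprobs=True, generate() returns the raw `choices` list from
# the API response instead of the stripped message text (see the last hunk).
```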
opencompass/tasks/outer_eval/alpacaeval.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# flake8: noqa: E501
import copy
import json
import os.path as osp

import mmengine
from mmengine.config import Config, ConfigDict

from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path,
                               get_logger)


class PredictionMerger:
    """Merge sharded prediction files into a single file for alpaca_eval."""

    def __init__(self, cfg: ConfigDict) -> None:

        self.cfg = cfg
        self.model_cfg = copy.deepcopy(self.cfg['model'])
        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])

        self.work_dir = self.cfg.get('work_dir')

    def run(self):
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        root, ext = osp.splitext(filename)
        partial_filename = root + '_0' + ext

        # Already merged
        if osp.exists(osp.realpath(filename)):
            return

        if not osp.exists(osp.realpath(partial_filename)):
            print(f'{filename} not found')
            return

        # Load predictions
        partial_filenames = []
        if osp.exists(osp.realpath(filename)):
            preds = mmengine.load(filename)
        else:
            preds, offset = {}, 0
            i = 1
            while osp.exists(osp.realpath(partial_filename)):
                partial_filenames.append(osp.realpath(partial_filename))
                _preds = mmengine.load(partial_filename)
                partial_filename = root + f'_{i}' + ext
                i += 1
                for _o in range(len(_preds)):
                    preds[str(offset)] = _preds[str(_o)]
                    offset += 1

        dataset = build_dataset_from_cfg(self.dataset_cfg)
        if len(preds) != len(dataset.test):
            print('length mismatch')
            return

        # example.json carries the original records; fill in this model's
        # outputs so the merged file matches the alpaca_eval input format
        with open(
                osp.realpath(osp.join(self.dataset_cfg['path'],
                                      'example.json')), 'r') as f:
            data_format = json.load(f)

        for idx in range(len(preds)):
            data_format[idx]['output'] = preds[str(idx)]['prediction']
            data_format[idx]['generator'] = self.model_cfg['abbr']

        print(f'Merge {partial_filenames} to {filename}')
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data_format, f, indent=4, ensure_ascii=False)


class AlpacaEvalTask(BaseTask):
    """Subjective Evaluation Task.

    This task is used to evaluate the metric between predictions and
    references.

    Args:
        cfg (ConfigDict): The configuration of the entire evaluation task.
    """

    name_prefix = 'SubjectiveEval'
    log_subdir = 'logs/eval'
    output_subdir = 'results'

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)
        self.logger = get_logger()
        judge_cfg = cfg.eval.runner.task.get('judge_cfg', {})
        assert type(judge_cfg) == ConfigDict
        run_cfg = judge_cfg.get('run_cfg', {})
        self.num_gpus = run_cfg.get('num_gpus', 0)
        self.num_procs = run_cfg.get('num_procs', 1)
        self.judge_cfg = copy.deepcopy(judge_cfg)

    def get_command(self, cfg_path, template):
        """Get the command template for the task.

        Args:
            cfg_path (str): The path to the config file of the task.
            template (str): The template which has '{task_cmd}' to format
                the command.
        """
        # script_path = __file__
        alpaca_cfg = self.judge_cfg.get('config', None)
        api_key = self.judge_cfg.get('key', None)
        assert alpaca_cfg is not None
        all_cfg = Config.fromfile(cfg_path)
        model_cfg = all_cfg['models']
        dataset_cfg = all_cfg['datasets'][0][0]
        work_dir = osp.realpath(all_cfg['work_dir'])
        for m_cfg in model_cfg:
            # Merge sharded predictions before handing them to alpaca_eval
            PredictionMerger({
                'model': m_cfg,
                'dataset': dataset_cfg,
                'work_dir': work_dir
            }).run()
            filename = get_infer_output_path(m_cfg, dataset_cfg,
                                             osp.join(work_dir, 'predictions'))
            output_path = osp.join(work_dir, 'results', m_cfg['abbr'])
            command = f'export OPENAI_API_KEY={api_key}; alpaca_eval --model_outputs {filename} --annotators_config {alpaca_cfg} --output_path {output_path}'
            return template.format(task_cmd=command)

    def run(self):
        # model_cfg can be a list of model configs
        pass
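For orientation, this is roughly the record shape PredictionMerger writes. The values are illustrative; only the 'output' and 'generator' field names come from the code above, and the rest of each record is whatever the dataset's example.json already contains (alpaca_eval expects an 'instruction' field alongside these):

```python
# Illustrative only: one merged record as consumed by `alpaca_eval`.
record = {
    'instruction': '...',           # preserved from example.json
    'output': '...prediction...',   # filled from preds[str(idx)]['prediction']
    'generator': 'chatglm3-6b',     # set to the model's abbr
}
```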
opencompass/tasks/subjective_eval.py

@@ -132,30 +132,29 @@ class SubjectiveEvalTask(BaseTask):
         # Get partition name
         root, ext = osp.splitext(filename)
         partial_filename = root + '_0' + ext

-        # If no predictions get in predictions dir
-        if not osp.exists(osp.realpath(filename)) and not osp.exists(
-                osp.realpath(partial_filename)):
-            return {'error': 'No predictions found.'}
+        assert osp.exists(filename) or osp.exists(
+            osp.realpath(partial_filename)
+        ), 'No predictions found for {filename}.'.format(filename=filename)

         # If use Naive partition in infer stage
         if osp.exists(osp.realpath(filename)):
             preds = mmengine.load(filename)
             pred_strs = [
                 preds[str(i)]['prediction'] for i in range(len(preds))
             ]
         # If use Size partition in infer stage
         else:
             filename = partial_filename
             pred_strs = []
             i = 1
             while osp.exists(osp.realpath(filename)):
                 preds = mmengine.load(filename)
                 filename = root + f'_{i}' + ext
                 i += 1
                 pred_strs += [
                     preds[str(i)]['prediction'] for i in range(len(preds))
                 ]
+        # Get all predictions in pred_strs

         # If take SubjectSizePartition, get new pred_strs based on test_range
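The shard-walking loop above mirrors what PredictionMerger does on the alpaca-eval side. As a self-contained reference, here is a minimal sketch of the same pattern; the helper name is hypothetical and not part of the commit:

```python
# Hypothetical helper, not part of the commit: collect predictions that the
# size partitioner wrote as root_0.json, root_1.json, ... in shard order.
import os.path as osp

import mmengine


def load_sharded_predictions(filename: str) -> list:
    """Return the 'prediction' field of every record, shards concatenated."""
    root, ext = osp.splitext(filename)
    pred_strs, i = [], 0
    shard = f'{root}_{i}{ext}'
    while osp.exists(osp.realpath(shard)):
        preds = mmengine.load(shard)  # dict keyed by stringified indices
        pred_strs += [preds[str(j)]['prediction'] for j in range(len(preds))]
        i += 1
        shard = f'{root}_{i}{ext}'
    return pred_strs
```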
@@ -1 +1,2 @@
+alpaca-eval
 faiss_gpu==1.7.2