Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Sync] update model configs (#574)

commit d3de5c41fb (parent 689ffe5b63)
.gitignore (vendored, 1 line changed)
@@ -89,3 +89,4 @@ docs/zh_cn/_build/

# sft config ignore list
configs/sft_cfg/*B_*
configs/cky/
@@ -22,8 +22,8 @@ MRPC_infer_cfg = dict(
        },
        ice_token='</E>',
    ),
    retriever=dict(type=FixKRetriever),
    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]))
    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
    inferencer=dict(type=PPLInferencer))

MRPC_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
configs/datasets/humaneval/humaneval_gen_4a6eef.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HFDataset, HumanEvaluator, humaneval_postprocess

humaneval_reader_cfg = dict(
    input_columns=['prompt'], output_column='task_id', train_split='test')

# TODO: allow empty output-column
humaneval_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(
                role='HUMAN',
                prompt='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a Python script for this problem:\n{prompt}\n\n### Response:\n'),
        ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512))

humaneval_eval_cfg = dict(
    evaluator=dict(type=HumanEvaluator),
    pred_role='BOT',
    k=[1, 10, 100], # the parameter only for humaneval
    pred_postprocessor=dict(type=humaneval_postprocess),
)

humaneval_datasets = [
    dict(
        type=HFDataset,
        path='openai_humaneval',
        reader_cfg=humaneval_reader_cfg,
        infer_cfg=humaneval_infer_cfg,
        eval_cfg=humaneval_eval_cfg)
]
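Note: a dataset config like the one above is normally pulled into a runnable evaluation through OpenCompass's `read_base` config mechanism. The snippet below is an illustrative sketch only and is not part of this commit; the exact relative import paths depend on where the eval config lives under `configs/`.

from mmengine.config import read_base

with read_base():
    # dataset config added in this commit
    from .datasets.humaneval.humaneval_gen_4a6eef import humaneval_datasets
    # one of the model configs added in this commit (hypothetical choice)
    from .models.aquila.hf_aquila2_7b import models as hf_aquila2_7b_models

datasets = [*humaneval_datasets]
models = [*hf_aquila2_7b_models]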
configs/models/aquila/hf_aquila2_34b.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFaceCausalLM

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='aquila2-34b-hf',
        path="BAAI/Aquila2-34B",
        tokenizer_path='BAAI/Aquila2-34B',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=2, num_procs=1),
    )
]
configs/models/aquila/hf_aquila2_7b.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFaceCausalLM

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='aquila2-7b-hf',
        path="BAAI/Aquila2-7B",
        tokenizer_path='BAAI/Aquila2-7B',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/aquila/hf_aquilachat2_34b.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from opencompass.models import HuggingFaceCausalLM

_meta_template = dict(
    round=[
        dict(role='HUMAN', begin='### Human: ', end='\n'),
        dict(role='BOT', begin='### Assistant: ', end='</s>', generate=True),
    ],
    eos_token_id=100007,
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='aquilachat2-34b-hf',
        path="BAAI/AquilaChat2-34B",
        tokenizer_path='BAAI/AquilaChat2-34B',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        meta_template=_meta_template,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=2, num_procs=1),
    )
]
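For reference, a `_meta_template` like the one above tells the HuggingFace wrapper how each conversation turn is wrapped before tokenization: `begin` and `end` bracket every message of that role, and generation starts after the `begin` of the role marked `generate=True`. The line below is an illustrative approximation only, not part of this commit.

# Roughly how one user turn is rendered before the model completes the BOT turn:
rendered = '### Human: {user_prompt}\n### Assistant: '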
configs/models/aquila/hf_aquilachat2_34b_16k.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from opencompass.models import HuggingFaceCausalLM

_meta_template = dict(
    begin='###',
    round=[
        dict(role='HUMAN', begin='Human: ', end='###'),
        dict(role='BOT', begin='Assistant: ', end='</s>', generate=True),
    ],
    eos_token_id=100007,
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='aquilachat2-34b-16k-hf',
        path="BAAI/AquilaChat2-34B-16K",
        tokenizer_path='BAAI/AquilaChat2-34B-16K',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        meta_template=_meta_template,
        max_out_len=100,
        max_seq_len=4096,
        batch_size=8,
        run_cfg=dict(num_gpus=2, num_procs=1),
    )
]
configs/models/aquila/hf_aquilachat2_7b.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from opencompass.models import HuggingFaceCausalLM

_meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|startofpiece|>', end=''),
        dict(role='BOT', begin='<|endofpiece|>', end='</s>', generate=True),
    ],
    eos_token_id=2,
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='aquilachat2-7b-hf',
        path="BAAI/AquilaChat2-7B",
        tokenizer_path='BAAI/AquilaChat2-7B',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        meta_template=_meta_template,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/aquila/hf_aquilachat2_7b_16k.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from opencompass.models import HuggingFaceCausalLM

_meta_template = dict(
    begin='###',
    round=[
        dict(role='HUMAN', begin='Human: ', end='###'),
        dict(role='BOT', begin='Assistant: ', end='</s>', generate=True),
    ],
    eos_token_id=100007,
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='aquilachat2-7b-16k-hf',
        path="BAAI/AquilaChat2-7B-16K",
        tokenizer_path='BAAI/AquilaChat2-7B-16K',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        meta_template=_meta_template,
        max_out_len=100,
        max_seq_len=4096,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
@@ -7,15 +7,18 @@ models = [
        abbr='chatglm2-6b-hf',
        path='THUDM/chatglm2-6b',
        tokenizer_path='THUDM/chatglm2-6b',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        max_seq_len=4096,
        batch_size=8,
        model_kwargs=dict(trust_remote_code=True, device_map='auto', revision='a6d54fac46dff2db65d53416c207a4485ca6bd40'),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/chatglm/hf_chatglm3_6b.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from opencompass.models import HuggingFaceChatGLM3

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ]
)

models = [
    dict(
        type=HuggingFaceChatGLM3,
        abbr='chatglm3-6b-hf',
        path='THUDM/chatglm3-6b',
        tokenizer_path='THUDM/chatglm3-6b',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        meta_template=api_meta_template,
        max_out_len=100,
        max_seq_len=4096,
        batch_size=1,
        run_cfg=dict(num_gpus=1, num_procs=1)
    )
]
configs/models/chatglm/hf_chatglm3_6b_base.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFace


models = [
    dict(
        type=HuggingFace,
        abbr='chatglm3-6b-base-hf',
        path='THUDM/chatglm3-6b-base',
        tokenizer_path='THUDM/chatglm3-6b-base',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=4096,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
@@ -7,6 +7,10 @@ models = [
        abbr='chatglm-6b-hf',
        path='THUDM/chatglm-6b',
        tokenizer_path='THUDM/chatglm-6b',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
@@ -15,7 +19,6 @@ models = [
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        model_kwargs=dict(trust_remote_code=True, device_map='auto', revision='1d240ba371910e9282298d4592532d7f0f3e9f3e'),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/hf_internlm/hf_internlm_20b.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from opencompass.models import HuggingFaceCausalLM


models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='internlm-20b-hf',
        path="internlm/internlm-20b",
        tokenizer_path='internlm/internlm-20b',
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            use_fast=False,
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        model_kwargs=dict(trust_remote_code=True, device_map='auto'),
        run_cfg=dict(num_gpus=2, num_procs=1),
    )
]
@@ -7,6 +7,10 @@ models = [
        abbr='internlm-7b-hf',
        path="internlm/internlm-7b",
        tokenizer_path='internlm/internlm-7b',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
@@ -16,7 +20,6 @@ models = [
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        model_kwargs=dict(trust_remote_code=True, device_map='auto'),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/hf_internlm/hf_internlm_chat_20b.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from opencompass.models import HuggingFaceCausalLM


_meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='internlm-chat-20b-hf',
        path="internlm/internlm-chat-20b",
        tokenizer_path='internlm/internlm-chat-20b',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            use_fast=False,
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,
        run_cfg=dict(num_gpus=2, num_procs=1),
    )
]
@@ -14,21 +14,20 @@ models = [
        abbr='internlm-chat-7b-hf',
        path="internlm/internlm-chat-7b",
        tokenizer_path='internlm/internlm-chat-7b',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            use_fast=False,
            trust_remote_code=True,
            revision="1a6328795c6e207904e1eb58177e03ad24ae06f3"
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
            revision="1a6328795c6e207904e1eb58177e03ad24ae06f3"),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
@@ -14,6 +14,10 @@ models = [
        abbr='internlm-chat-7b-8k-hf',
        path="internlm/internlm-chat-7b-8k",
        tokenizer_path='internlm/internlm-chat-7b-8k',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
@@ -24,7 +28,6 @@ models = [
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,
        model_kwargs=dict(trust_remote_code=True, device_map='auto'),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/lingowhale/hf_lingowhale_8b.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from opencompass.models import HuggingFace


models = [
    dict(
        type=HuggingFace,
        abbr='lingowhale-8b-hf',
        path='deeplang-ai/LingoWhale-8B',
        tokenizer_path='deeplang-ai/LingoWhale-8B',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
            torch_dtype='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/mistral/hf_mistral_7b.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFaceCausalLM


models = [
    dict(
        abbr='mistral-7b-v0.1-hf',
        type=HuggingFaceCausalLM,
        path='mistralai/Mistral-7B-v0.1',
        tokenizer_path='mistralai/Mistral-7B-v0.1',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/mistral/hf_mistral_7b_instruct.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from opencompass.models import HuggingFaceCausalLM


_meta_template = dict(
    begin="<s>",
    round=[
        dict(role="HUMAN", begin='[INST]', end='[/INST]'),
        dict(role="BOT", begin="", end='</s>', generate=True),
    ],
    eos_token_id=2
)

models = [
    dict(
        abbr='mistral-7b-instruct-v0.1-hf',
        type=HuggingFaceCausalLM,
        path='mistralai/Mistral-7B-Instruct-v0.1',
        tokenizer_path='mistralai/Mistral-7B-Instruct-v0.1',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        meta_template=_meta_template,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/qwen/hf_qwen_14b.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from opencompass.models import HuggingFaceCausalLM

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='qwen-14b-hf',
        path="Qwen/Qwen-14B",
        tokenizer_path='Qwen/Qwen-14B',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        pad_token_id=151643,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/qwen/hf_qwen_14b_chat.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from opencompass.models import HuggingFaceCausalLM


_meta_template = dict(
    round=[
        dict(role="HUMAN", begin='\n<|im_start|>user\n', end='<|im_end|>'),
        dict(role="BOT", begin="\n<|im_start|>assistant\n", end='<|im_end|>', generate=True),
    ],
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='qwen-14b-chat-hf',
        path="Qwen/Qwen-14B-Chat",
        tokenizer_path='Qwen/Qwen-14B-Chat',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,),
        pad_token_id=151643,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
@@ -1,33 +1,25 @@
from opencompass.models import HuggingFaceCausalLM

# Please note that we have specified the revision here. Recently (on 20230827),
# during our evaluations, we found that the newer revision models have a drop
# of more than 5 points on datasets like GaokaoBench / mbpp.
# We are not yet sure whether this drop is due to incorrect logic in OpenCompass
# calling qwen or some other reasons. We would like to highlight this.

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='qwen-7b-hf',
        path="Qwen/Qwen-7B",
        tokenizer_path='Qwen/Qwen-7B',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
            revision='39fc5fdcb95c8c367bbdb3bfc0db71d96266de09'
        ),
        pad_token_id=151643,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
            revision='39fc5fdcb95c8c367bbdb3bfc0db71d96266de09'
        ),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
@@ -14,6 +14,10 @@ models = [
        abbr='qwen-7b-chat-hf',
        path="Qwen/Qwen-7B-Chat",
        tokenizer_path='Qwen/Qwen-7B-Chat',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
@@ -24,7 +28,6 @@ models = [
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,
        model_kwargs=dict(device_map='auto', trust_remote_code=True),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/skywork/hf_skywork_13b.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFaceCausalLM

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='skywork-13b-hf',
        path="Skywork/Skywork-13B-base",
        tokenizer_path='Skywork/Skywork-13B-base',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
configs/models/tigerbot/hf_tigerbot_70b_base.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFaceCausalLM


models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='tigerbot-70b-base-v1-hf',
        path='TigerResearch/tigerbot-70b-base',
        tokenizer_path='TigerResearch/tigerbot-70b-base',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=4, num_procs=1),
    ),
]
configs/models/tigerbot/hf_tigerbot_70b_chat_v2.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from opencompass.models import HuggingFaceCausalLM


_meta_template = dict(
    round=[
        dict(role='HUMAN', begin='\n\n### Instruction:\n'),
        dict(role='BOT', begin='\n\n### Response:\n', generate=True),
    ],
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='tigerbot-70b-chat-v2-hf',
        path="TigerResearch/tigerbot-70b-chat-v2",
        tokenizer_path='TigerResearch/tigerbot-70b-chat-v2',
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        meta_template=_meta_template,
        model_kwargs=dict(trust_remote_code=True, device_map='auto'),
        run_cfg=dict(num_gpus=4, num_procs=1),
    )
]
configs/models/tigerbot/hf_tigerbot_70b_chat_v3.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from opencompass.models import HuggingFaceCausalLM


_meta_template = dict(
    round=[
        dict(role='HUMAN', begin='\n\n### Instruction:\n'),
        dict(role='BOT', begin='\n\n### Response:\n', generate=True),
    ],
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='tigerbot-70b-chat-v3-hf',
        path="TigerResearch/tigerbot-70b-chat-v3",
        tokenizer_path='TigerResearch/tigerbot-70b-chat-v3',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        meta_template=_meta_template,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=4, num_procs=1),
    )
]
@@ -17,6 +17,6 @@
        batch_size=8,
        model_kwargs=dict(device_map='auto'),
        batch_padding=False, # if false, inference with for-loop without batch padding
        run_cfg=dict(num_gpus=2, num_procs=1)
        run_cfg=dict(num_gpus=1, num_procs=1)
    )
]
configs/models/yi/hf_yi_34b.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFace


models = [
    dict(
        type=HuggingFace,
        abbr='yi-34b-hf',
        path='01-ai/Yi-34B',
        tokenizer_path='01-ai/Yi-34B',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=4, num_procs=1),
    )
]
configs/models/yi/hf_yi_6b.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from opencompass.models import HuggingFace


models = [
    dict(
        type=HuggingFace,
        abbr='yi-6b-hf',
        path='01-ai/Yi-6B',
        tokenizer_path='01-ai/Yi-6B',
        model_kwargs=dict(
            trust_remote_code=True,
            device_map='auto',
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
@@ -66,7 +66,7 @@ class Gsm8kEvaluator(BaseEvaluator):
        count = 0
        details = []
        for i, j in zip(predictions, references):
            detail = {'pred': i, 'answers': j, 'correct': False}
            detail = {'pred': i, 'answer': j, 'correct': False}
            count += 1
            if i == j:
                correct += 1
@@ -4,6 +4,7 @@ from .claude_api import Claude  # noqa: F401
from .glm import GLM130B  # noqa: F401, F403
from .huggingface import HuggingFace  # noqa: F401, F403
from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
from .huggingface import HuggingFaceChatGLM3  # noqa: F401, F403
from .intern_model import InternLM  # noqa: F401, F403
from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
from .minimax_api import MiniMax  # noqa: F401
@@ -5,6 +5,7 @@ import numpy as np
import torch

from opencompass.models.base import BaseModel
from opencompass.models.base_api import APITemplateParser
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList
@@ -442,3 +443,85 @@ class HuggingFaceCausalLM(HuggingFace):
                is_trainable=False)
            self.model.eval()
            self.model.generation_config.do_sample = False


class HuggingFaceChatGLM3(HuggingFace):
    """Model wrapper around HuggingFace's ChatGLM3. Details available in
    `https://huggingface.co/THUDM/chatglm3-6b`.

    model.chat() is used for inference.
    """

    def __init__(self,
                 path: str,
                 hf_cache_dir: Optional[str] = None,
                 max_seq_len: int = 2048,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 peft_path: Optional[str] = None,
                 tokenizer_only: bool = False,
                 model_kwargs: dict = dict(device_map='auto'),
                 meta_template: Optional[Dict] = None,
                 extract_pred_after_decode: bool = False,
                 batch_padding: bool = False,
                 pad_token_id: Optional[int] = None,
                 mode: str = 'none',
                 num_extra_tokens: int = 50):
        super().__init__(path=path,
                         hf_cache_dir=hf_cache_dir,
                         max_seq_len=max_seq_len,
                         tokenizer_path=tokenizer_path,
                         tokenizer_kwargs=tokenizer_kwargs,
                         peft_path=peft_path,
                         tokenizer_only=tokenizer_only,
                         model_kwargs=model_kwargs,
                         meta_template=meta_template,
                         extract_pred_after_decode=extract_pred_after_decode,
                         batch_padding=batch_padding,
                         pad_token_id=pad_token_id,
                         mode=mode)
        self.template_parser = APITemplateParser(meta_template)
        # used to compensate for #tokens occupied by sth like system prompt
        self.num_extra_tokens = num_extra_tokens

    def generate(self,
                 inputs: List[str or PromptList],
                 max_out_len: int = 512,
                 temperature: float = 0.6) -> str:
        """Generate response from input prompt.

        Args:
            inputs (list): input prompt
            max_out_len (int): max output length
            temperature (float): temperature for sampling
        """
        responses = []
        for _input in inputs:
            assert isinstance(_input, (str, PromptList))
            if isinstance(_input, str):
                history = [{'role': 'user', 'content': _input}]
            else:
                history = []
                for item in _input:
                    msg = {
                        'content': item['prompt'],
                        'role': {
                            'HUMAN': 'user',
                            'BOT': 'assistant',
                            'SYSTEM': 'system'
                        }[item['role']]
                    }
                    history.append(msg)
            user_content = history[-1]['content']
            history = history[:-1]
            try:
                response, history = self.model.chat(self.tokenizer,
                                                    user_content,
                                                    history=history)
                responses.append(response)
            except Exception:
                responses.append('')
        return responses

    def get_token_len(self, prompt: str) -> int:
        return len(self.tokenizer.encode(prompt)) + self.num_extra_tokens
@@ -266,7 +266,13 @@ class EDAccEvaluator(AccEvaluator):

        for i in range(len(predictions)):
            pred, ref = predictions[i], references[i]
            dists = [self.dist(pred, cand) for cand in ref['candidates']]
            dists = []
            for cands in ref['candidates']:
                if isinstance(cands, str):
                    d = self.dist(pred, cands)
                else:
                    d = np.min([self.dist(pred, cand) for cand in cands])
                dists.append(d)
            preds.append(np.argmin(dists))
            golds.append(ref['label'])
@@ -190,7 +190,7 @@ class PPLInferencer(BaseInferencer):
                        label, prompt.replace(ice_str, ''), prompt, res, index)
                    output_handler.results_dict[str(
                        index)][f'label: {str(label)}'][
                            'BPB'] = res * token_num_list[idx] / len(
                            'BPB'] = res * token_num_list[index] / len(
                                prompt.replace(ice_str, '').encode())
                    index = index + 1
                ppl.append(sub_ppl_list)
@@ -1,3 +1,4 @@
from .dlc import *  # noqa: F401, F403
from .local import *  # noqa: F401, F403
from .slurm import *  # noqa: F401, F403
from .slurm_sequential import *  # noqa: F401, F403
opencompass/runners/slurm_sequential.py (new file, 242 lines)
@@ -0,0 +1,242 @@
import os
import os.path as osp
import re
import subprocess
import time
import traceback
from functools import partial
from multiprocessing import Pipe, Pool
from typing import Any, Dict, List, Tuple

import mmengine
from mmengine.config import ConfigDict
from tqdm import tqdm

from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger

from .base import BaseRunner


@RUNNERS.register_module()
class SlurmSequentialRunner(BaseRunner):
    """Distributed runner based on Slurm. It will launch tasks in parallel
    using `srun` command.

    This runner launches tasks one by one for execution. A new task will only
    be launched when and only when max_num_workers is not met, and the previous
    task has been successfully allocated to a machine. Therefore, unlike the
    `SlurmRunner`, at most only one task will be in the PENDING status at the
    same time during a run, making the random_sleep strategy no longer
    necessary. In addition, this runner also includes a feature to
    automatically kill all jobs by the job_id on exit.

    The runner will obtain the job_id by reading the srun output similar to
    `srun: Job 123456 scheduled successfully!`. If the output of srun does not
    match this pattern, the runner will not work properly.

    Args:
        task (ConfigDict): Task type config.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 32.
        retry (int): Number of retries if the job failed. Defaults to 2.
        partition (str): Slurm partition name. Defaults to None.
        quotatype (str): Slurm quota type. Defaults to None.
        qos (str): Slurm quality of service. Defaults to None.
        debug (bool): Whether to run in debug mode. Defaults to False.
        lark_bot_url (str): Lark bot url. Defaults to None.
    """

    def __init__(self,
                 task: ConfigDict,
                 task_prefix: str = '',
                 max_num_workers: int = 32,
                 retry: int = 2,
                 partition: str = None,
                 quotatype: str = None,
                 qos: str = None,
                 debug: bool = False,
                 lark_bot_url: str = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.retry = retry
        self.partition = partition
        self.quotatype = quotatype
        self.qos = qos
        self.task_prefix = task_prefix

        logger = get_logger()
        if self.quotatype in ['spot', 'auto']:
            logger.warning(
                'Quotatype spot or auto may cause stability issues, '
                'reserved is recommended.')

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        if not self.debug:
            return self._launch_wo_debug(tasks)
        else:
            return [self._launch(task) for task in tasks]

    def _launch_wo_debug(self,
                         tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        launched_bar = tqdm(total=len(tasks), desc='Launched')
        finished_bar = tqdm(total=len(tasks), desc='Finished')
        job_ids = []
        status = []

        def _update(result):
            finished_bar.update()
            status.append(result)
            return result

        def _err_update(err):
            finished_bar.update()
            traceback.print_exc()
            status.append(('', -1))

        try:
            parent_conns = []
            num_workers = min(self.max_num_workers, len(tasks))
            with Pool(processes=num_workers) as pool:
                for task in tasks:
                    parent_conn, child_conn = Pipe()
                    _ = pool.apply_async(self._launch,
                                         kwds={
                                             'cfg': task,
                                             'child_conn': child_conn
                                         },
                                         callback=_update,
                                         error_callback=_err_update)
                    time.sleep(0.5)

                    job_id = parent_conn.recv()
                    launched_bar.update()
                    parent_conns.append(parent_conn)
                    job_ids.append(job_id)

                pool.close()
                pool.join()
            return status
        except KeyboardInterrupt:
            raise
        finally:
            launched_bar.close()
            finished_bar.close()
            for parent_conn in parent_conns:
                while parent_conn.poll():
                    try:
                        job_id = parent_conn.recv()
                        job_ids.append(job_id)
                    except EOFError:
                        break
                parent_conn.close()

            for job_id in tqdm(job_ids, desc='clear sruns'):
                if job_id is None:
                    continue
                cmd = f'scancel {job_id}'
                p = subprocess.Popen(cmd,
                                     shell=True,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.STDOUT)
                p.wait()

    def _launch(self, cfg: ConfigDict, child_conn: Pipe = None):
        logger = get_logger()

        task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
        num_gpus = task.num_gpus
        task_name = task.name
        task_name = self.task_prefix + task_name

        # Dump task config to file
        mmengine.mkdir_or_exist('tmp/')
        param_file = f'tmp/{os.getpid()}_params.py'
        process = None
        try:
            cfg.dump(param_file)

            # Build up slurm command
            tmpl = 'srun'
            if self.partition:
                tmpl += f' -p {self.partition}'
            if self.quotatype:
                tmpl += f' --quotatype={self.quotatype}'
            if self.qos:
                tmpl += f' --qos={self.qos}'
            if num_gpus > 0:
                tmpl += f' --gres=gpu:{num_gpus}'
            tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
            get_cmd = partial(task.get_command,
                              cfg_path=param_file,
                              template=tmpl)
            cmd = get_cmd()

            logger.debug(f'Running command: {cmd}')

            retry = self.retry
            output_paths = task.get_output_paths()

            if self.debug:
                while True:
                    process = subprocess.Popen(cmd, shell=True, text=True)
                    process.communicate()
                    process.wait()
                    if self._job_failed(process.returncode, output_paths):
                        if retry > 0:
                            logger.warning(
                                f'task {task_name} failed, retrying...')
                            retry -= 1
                            cmd = get_cmd()
                        else:
                            break
                    else:
                        break
            else:
                out_path = task.get_log_path(file_extension='out')
                mmengine.mkdir_or_exist(osp.split(out_path)[0])
                stdout = open(out_path, 'w', encoding='utf-8')
                stderr = subprocess.PIPE
                while True:
                    process = subprocess.Popen(cmd,
                                               shell=True,
                                               text=True,
                                               stdout=stdout,
                                               stderr=stderr)
                    job_id = None
                    while True:
                        line = process.stderr.readline()
                        if not line:
                            break
                        match = re.search(
                            r'srun: Job (\d+) scheduled successfully!', line)
                        if match and job_id is None:
                            job_id = match.group(1)
                            child_conn.send(job_id)
                        stdout.write(line)
                    process.wait()
                    if self._job_failed(process.returncode, output_paths):
                        if retry > 0:
                            retry -= 1
                            cmd = get_cmd()
                        else:
                            logger.warning(
                                f'task {task_name} fail, see\n{out_path}')
                            break
                    else:
                        break
        except KeyboardInterrupt:
            raise
        finally:
            # Clean up
            if child_conn is not None:
                child_conn.send(None)
                child_conn.close()
            if process is not None:
                process.kill()
            os.remove(param_file)
        return task_name, process.returncode

    def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
        return return_code != 0 or not all(
            osp.exists(output_path) for output_path in output_paths)
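With the new runner exported from `opencompass.runners` (see the `__init__.py` change above), it can be selected in an evaluation config in the same way as the existing Slurm runner. The following is a minimal sketch, not part of this commit: it assumes the usual `infer` section layout with a partitioner and an inference task, and the partition name and worker counts are placeholders.

# Hypothetical config snippet; values below are placeholders, not from the diff.
from opencompass.partitioners import SizePartitioner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask

infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=2000),
    runner=dict(
        type=SlurmSequentialRunner,
        partition='my-partition',   # placeholder Slurm partition
        quotatype='reserved',       # 'spot'/'auto' trigger the stability warning above
        max_num_workers=32,
        retry=2,
        task=dict(type=OpenICLInferTask),
    ),
)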
opencompass/summarizers/summarizer_pretrain.py (new file, 337 lines)
@@ -0,0 +1,337 @@
# flake8: noqa
# yapf: disable
import getpass
import os.path as osp
from datetime import datetime
from typing import List, Optional

import mmengine
import pytz
import tabulate
from mmengine import ConfigDict

from opencompass.utils import (LarkReporter, dataset_abbr_from_cfg,
                               get_infer_output_path, get_logger,
                               model_abbr_from_cfg)
from opencompass.utils.prompt import get_prompt_hash

METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth']
METRIC_BLACKLIST = ['bp', 'sys_len', 'ref_len']

class PretrainSummarizer:
    """"""

    def __init__(self, config: ConfigDict, dataset_abbrs: Optional[List[str]] = None, summary_groups: List = [], prompt_db = None) -> None:
        self.tasks = []
        self.cfg = config
        self.logger = get_logger()

        # Enable lark bot if lark_url is presented
        self.lark_reporter = None
        if self.cfg.get('lark_bot_url', None):
            self.lark_reporter = LarkReporter(self.cfg['lark_bot_url'])

    def summarize(
            self,
            output_path: str = None,
            time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):  # noqa

        model_cfgs = self.cfg['models']
        dataset_cfgs = self.cfg['datasets']
        summarizer_cfg = self.cfg.get('summarizer', {})
        work_dir = self.cfg['work_dir']

        # pick up results
        raw_results = {}
        parsed_results = {}
        dataset_metrics = {}

        model_abbrs = [model_abbr_from_cfg(model) for model in model_cfgs]
        for model in model_cfgs:
            model_abbr = model_abbr_from_cfg(model)
            parsed_results[model_abbr] = {}
            raw_results[model_abbr] = {}
            for dataset in dataset_cfgs:
                dataset_abbr = dataset_abbr_from_cfg(dataset)
                filepath = get_infer_output_path(model, dataset, osp.join(work_dir, 'results'))
                if not osp.exists(filepath):
                    continue
                result = mmengine.load(filepath)
                raw_results[model_abbr][dataset_abbr] = result
                if 'error' in result:
                    self.debug(f'error in {model_abbr} {dataset_abbr} {result["error"]}')
                    continue
                else:
                    parsed_results[model_abbr][dataset_abbr] = []
                    dataset_metrics[dataset_abbr] = []
                    for metric, score in result.items():
                        if metric not in METRIC_BLACKLIST and isinstance(score, (int, float)):
                            parsed_results[model_abbr][dataset_abbr].append(score)
                            dataset_metrics[dataset_abbr].append(metric)
                        else:
                            continue
                    if len(parsed_results[model_abbr][dataset_abbr]) == 0:
                        self.logger.warning(f'unknown result format: {result}, continue')
                        del parsed_results[model_abbr][dataset_abbr]
                        del dataset_metrics[dataset_abbr]
                        continue
                    indice = sorted(
                        list(range(len(dataset_metrics[dataset_abbr]))),
                        key=lambda i: (
                            METRIC_WHITELIST.index(dataset_metrics[dataset_abbr][i])
                            if dataset_metrics[dataset_abbr][i] in METRIC_WHITELIST
                            else len(METRIC_WHITELIST)
                        )
                    )
                    parsed_results[model_abbr][dataset_abbr] = [parsed_results[model_abbr][dataset_abbr][i] for i in indice]
                    dataset_metrics[dataset_abbr] = [dataset_metrics[dataset_abbr][i] for i in indice]

        # parse eval mode
        dataset_eval_mode = {}
        for dataset in dataset_cfgs:
            inferencer = dataset.get('infer_cfg', {}).get('inferencer', {}).get('type', '')
            inferencer = inferencer if isinstance(inferencer, str) else inferencer.__name__
            dataset_abbr = dataset_abbr_from_cfg(dataset)
            if 'GenInferencer' in inferencer:
                dataset_eval_mode[dataset_abbr] = 'gen'
            elif 'PPLInferencer' in inferencer:
                dataset_eval_mode[dataset_abbr] = 'ppl'
            else:
                dataset_eval_mode[dataset_abbr] = 'unknown'
                self.logger.warning(f'unknown inferencer: {inferencer} - {dataset_abbr}')

        # calculate group metrics
        summary_groups = summarizer_cfg.get('summary_groups', [])
        for sg in summary_groups:
            for model_abbr in model_abbrs:
                results = {}
                eval_modes = []
                for dataset_abbr in sg['subsets']:
                    if dataset_abbr in parsed_results[model_abbr]:
                        results[dataset_abbr] = parsed_results[model_abbr][dataset_abbr][0]
                        eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown'))
                if len(results) == len(sg['subsets']):
                    if 'weights' in sg:
                        numerator = sum(results[k] * sg['weights'][k] for k in sg['weights'])
                        denominator = sum(sg['weights'].values())
                        metric = 'weighted_average'
                    else:
                        numerator = sum(results[k] for k in results)
                        denominator = len(results)
                        metric = 'naive_average'
                    results[metric] = numerator / denominator
                    eval_modes = list(set(eval_modes))
                    eval_mode = eval_modes[0] if len(eval_modes) == 1 else 'mixed'

                    # add to global results
                    raw_results[model_abbr][sg['name']] = results
                    parsed_results[model_abbr][sg['name']] = [numerator / denominator]
                    dataset_metrics[sg['name']] = [metric]
                    dataset_eval_mode[sg['name']] = eval_mode
                elif len(results) == 0:
                    continue
                else:
                    raw_results[model_abbr][sg['name']] = {'error': 'missing datasets: {}'.format(set(sg['subsets']) - set(results.keys()))}

        prompt_version = {dataset_abbr_from_cfg(d): get_prompt_hash(d)[:6] for d in dataset_cfgs}

        # format table
        summarizer_dataset_abbrs = []
        if summarizer_cfg.get('dataset_abbrs') is None:
            for dataset in dataset_cfgs:
                dataset_abbr = dataset_abbr_from_cfg(dataset)
                if dataset_abbr in dataset_metrics:
                    for metric in dataset_metrics[dataset_abbr]:
                        summarizer_dataset_abbrs.append((dataset_abbr, metric))
                else:
                    summarizer_dataset_abbrs.append((dataset_abbr, None))
            for dataset_abbr in dataset_metrics:
                for metric in dataset_metrics[dataset_abbr]:
                    if (dataset_abbr, metric) not in summarizer_dataset_abbrs:
                        summarizer_dataset_abbrs.append((dataset_abbr, metric))
        else:
            for item in summarizer_cfg['dataset_abbrs']:
                if isinstance(item, str):
                    summarizer_dataset_abbrs.append((item, None))
                elif isinstance(item, (list, tuple)):
                    summarizer_dataset_abbrs.append((item[0], item[1]))
        table = []
        checkpoints = [model_abbr.rsplit('_', 1)[1] if '_' in model_abbr else model_abbr for model_abbr in model_abbrs]
        # model_abbrs = [model_abbr.rsplit("_", 1)[0] for model_abbr in model_abbrs]
        header = ['dataset', 'version', 'metric', 'mode'] + model_abbrs
        time_zone = pytz.timezone('Asia/Shanghai')
        now = datetime.now(time_zone)
        time = now.strftime('%m/%d %H:%M')
        times = [time] * len(model_abbrs)
        table.append(header)
        table.append(['dataset', 'version', 'metric', 'mode'] + times)
        table.append(['dataset', 'version', 'metric', 'mode'] + checkpoints)
        dataset_score = [0] * len(model_abbrs)
        dataset_num = [0] * len(model_abbrs)

        for dataset_abbr, metric in summarizer_dataset_abbrs:
            # if dataset_abbr not in dataset_metrics:
            #     table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(model_abbrs))
            #     continue
            if metric is None and dataset_abbr in dataset_metrics:
                index = 0
                metric = dataset_metrics[dataset_abbr][0]
            elif dataset_abbr in dataset_metrics and metric in dataset_metrics[dataset_abbr]:
                index = dataset_metrics[dataset_abbr].index(metric)
            elif not dataset_abbr.startswith('---'):
                table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(model_abbrs))
                continue
            if dataset_abbr.startswith('---'):
                row = [dataset_abbr, '-', '-', '-']
            else:
                row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric, dataset_eval_mode.get(dataset_abbr, '-')]
            for i, model_abbr in enumerate(model_abbrs):
                if dataset_abbr in parsed_results[model_abbr]:
                    if index == 0:
                        row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][index]))
                        dataset_score[i] += parsed_results[model_abbr][dataset_abbr][index]
                        dataset_num[i] += 1
                    # row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][index]))
                else:
                    if dataset_abbr.startswith('---') and dataset_num[i] != 0:
                        row.append('{:.02f}'.format(dataset_score[i] / dataset_num[i]))
                        dataset_score[i] = 0
                        dataset_num[i] = 0
                    else:
                        row.append('-')
            table.append(row)

        # format raw txt
        raw_dataset_abbrs = []
        for model_abbr in model_abbrs:
            for dataset_abbr in raw_results[model_abbr]:
                if dataset_abbr not in raw_dataset_abbrs:
                    raw_dataset_abbrs.append(dataset_abbr)
        raw_txts = []
        for model_abbr in model_abbrs:
            raw_txts.append('-------------------------------')
            raw_txts.append(f'Model: {model_abbr}')
            for dataset_abbr in raw_dataset_abbrs:
                result = raw_results[model_abbr].get(dataset_abbr, '{}')
                raw_txts.append(f'{dataset_abbr}: {result}')
        raw_txts = '\n'.join(raw_txts)

        # output to screen
        print(tabulate.tabulate(table, headers='firstrow'))

        # output to file
        if output_path is None:
            output_path = osp.join(work_dir, 'summary', f'summary_{time_str}.txt')
            output_csv_path = osp.join(work_dir, 'summary', f'summary_{time_str}.csv')
        else:
            output_csv_path = output_path.replace('.txt', '.csv')

        output_dir = osp.split(output_path)[0]
        mmengine.mkdir_or_exist(output_dir)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(time_str + '\n')
            f.write('tabulate format\n')
            f.write('^' * 128 + '\n')
            f.write(tabulate.tabulate(table, headers='firstrow') + '\n')
            f.write('$' * 128 + '\n')
            f.write('\n' + '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n')
            f.write('csv format\n')
            f.write('^' * 128 + '\n')
            f.write('\n'.join([','.join(row) for row in table]) + '\n')
            f.write('$' * 128 + '\n')
            f.write('\n' + '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n')
            f.write('raw format\n')
            f.write('^' * 128 + '\n')
            f.write(raw_txts + '\n')
            f.write('$' * 128 + '\n')
        self.logger.info(f'write summary to {osp.abspath(output_path)}')

        if self.lark_reporter:
            content = f'{getpass.getuser()} 的'
            content += f'详细评测汇总已输出至 {osp.abspath(output_path)}'
            self.lark_reporter.post(content)

        with open(output_csv_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join([','.join(row) for row in table]) + '\n')
        self.logger.info(f'write csv to {osp.abspath(output_csv_path)}')


        summary_groups = summarizer_cfg.get('summary_groups', [])
        for sg in summary_groups:
            for model_abbr in model_abbrs:
                results = {}
                eval_modes = []
                for dataset_abbr in sg['subsets']:
                    if dataset_abbr in parsed_results[model_abbr]:
                        results[dataset_abbr] = (parsed_results[model_abbr][dataset_abbr][-1], parsed_results[model_abbr][dataset_abbr][-2])
                        eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown'))

                if len(results) == len(sg['subsets']):
                    numerator1 = sum(results[k][0] for k in results)
                    numerator2 = sum(results[k][1] for k in results)
                    denominator = len(results)
                    metric = 'correct_bpb-incorrect_bpb'

                    count_ppl = eval_modes.count('ppl')
                    count_gen = len(eval_modes) - count_ppl
                    if count_ppl == 0:
                        results[metric] = -1
                    else:
                        results[metric] = (numerator1 + count_gen) / count_ppl
                    eval_modes = list(set(eval_modes))
                    eval_mode = eval_modes[0] if len(eval_modes) == 1 else 'mixed'
                    # add to global results

                    raw_results[model_abbr][sg['name']] = results
                    parsed_results[model_abbr][sg['name']] = [((numerator1 + count_gen) / count_ppl) if count_ppl != 0 else -1, ((numerator2 + count_gen) / count_ppl) if count_ppl != 0 else -1]
                    dataset_metrics[sg['name']] = ['incorrect_bpb', 'correct_bpb']
                    dataset_eval_mode[sg['name']] = eval_mode

                elif len(results) == 0:
                    continue
                else:
                    raw_results[model_abbr][sg['name']] = {'error': 'missing datasets: {}'.format(set(sg['subsets']) - set(results.keys()))}

        table = []
        table.append(['', '', '', ''] + [''] * len(model_abbrs))
        table.append(['', '', '', ''] + [''] * len(model_abbrs))
        table.append(['', '', '', ''] + [''] * len(model_abbrs))
        for dataset_abbr, metric in summarizer_dataset_abbrs:
            incorrect_bpb = -1
            correct_bpb = -1
            if dataset_abbr not in dataset_metrics:
                table.append([dataset_abbr, '', '', ''] + [''] * len(model_abbrs))
                continue
            if metric is None:
                index = 0
                try:
                    incorrect_bpb = dataset_metrics[dataset_abbr].index('incorrect_bpb')
                    correct_bpb = dataset_metrics[dataset_abbr].index('correct_bpb')
                except ValueError:
                    try:
                        incorrect_bpb = dataset_metrics[dataset_abbr].index('wrong_bpb')
                        correct_bpb = dataset_metrics[dataset_abbr].index('right_bpb')
                    except ValueError:
                        incorrect_bpb = -1
                        correct_bpb = -1
                metric = 'correct_bpb-incorrect_bpb'
            elif metric in dataset_metrics[dataset_abbr]:
                index = dataset_metrics[dataset_abbr].index(metric)
            else:
                table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(model_abbrs))
                continue

            row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric,
                   dataset_eval_mode.get(dataset_abbr, '-')]
            for model_abbr in model_abbrs:
                if dataset_abbr in parsed_results[model_abbr]:
                    if incorrect_bpb != -1 and correct_bpb != -1:
                        row.append('{:.02f}/{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][correct_bpb],
                                                            parsed_results[model_abbr][dataset_abbr][incorrect_bpb]))
                    else:
                        row.append('{:.02f}'.format(-1))
                else:
                    row.append('-')
            table.append(row)
        with open(output_csv_path, 'a', encoding='utf-8') as f:
            f.write('\n'.join([','.join(row) for row in table]) + '\n')
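The new summarizer takes the full run config in its constructor and is driven by the `summarizer` section of that config, like the default summarizer. The snippet below is a hypothetical sketch of how a config might select it, not part of this commit; it assumes `PretrainSummarizer` is exported from `opencompass.summarizers` (the export is not shown in this diff), and the dataset abbreviations are placeholders.

# Hypothetical config snippet; names below are placeholders, not from the diff.
from opencompass.summarizers import PretrainSummarizer

summarizer = dict(
    type=PretrainSummarizer,
    dataset_abbrs=['gsm8k', 'openai_humaneval'],  # placeholder dataset abbreviations
    summary_groups=[],                            # optional grouped averages
)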
@@ -287,7 +287,7 @@ class OpenICLEvalTask(BaseTask):
                result['prompt'] = origin_prediction['origin_prompt']
                result['origin_prediction'] = pred_dicts[i]['prediction']
                result['predictions'] = details[i]['pred']
                result['references'] = details[i]['answers']
                result['references'] = details[i]['answer']
                result['correct'] = details[i]['correct']
                results[str(i)] = result
            return results
@@ -324,7 +324,7 @@ class OpenICLEvalTask(BaseTask):
                    bpbs = [value['BPB'] for value in values]
                    incorrect_bpb_list.append(
                        (sum(bpbs) - min(bpbs)) / (len(bpbs) - 1))
                    bpb_list.append(statistics.mean(bpbs))
                    bpb_list.append(min(bpbs))

        def filters(origins):
            targets = [target for target in origins if not math.isnan(target)]