diff --git a/.gitignore b/.gitignore index f2eab368..8e81c083 100644 --- a/.gitignore +++ b/.gitignore @@ -89,3 +89,4 @@ docs/zh_cn/_build/ # sft config ignore list configs/sft_cfg/*B_* +configs/cky/ diff --git a/configs/datasets/GLUE_MRPC/GLUE_MRPC_ppl_96564c.py b/configs/datasets/GLUE_MRPC/GLUE_MRPC_ppl_96564c.py index f9a29619..e6399b82 100644 --- a/configs/datasets/GLUE_MRPC/GLUE_MRPC_ppl_96564c.py +++ b/configs/datasets/GLUE_MRPC/GLUE_MRPC_ppl_96564c.py @@ -22,8 +22,8 @@ MRPC_infer_cfg = dict( }, ice_token='', ), - retriever=dict(type=FixKRetriever), - inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4])) + retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]), + inferencer=dict(type=PPLInferencer)) MRPC_eval_cfg = dict(evaluator=dict(type=AccEvaluator), ) diff --git a/configs/datasets/humaneval/humaneval_gen_4a6eef.py b/configs/datasets/humaneval/humaneval_gen_4a6eef.py new file mode 100644 index 00000000..f528c7ce --- /dev/null +++ b/configs/datasets/humaneval/humaneval_gen_4a6eef.py @@ -0,0 +1,35 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import HFDataset, HumanEvaluator, humaneval_postprocess + +humaneval_reader_cfg = dict( + input_columns=['prompt'], output_column='task_id', train_split='test') + +# TODO: allow empty output-column +humaneval_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict( + role='HUMAN', + prompt='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a Python script for this problem:\n{prompt}\n\n### Response:\n'), + ])), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512)) + +humaneval_eval_cfg = dict( + evaluator=dict(type=HumanEvaluator), + pred_role='BOT', + k=[1, 10, 100], # the parameter only for humaneval + pred_postprocessor=dict(type=humaneval_postprocess), +) + +humaneval_datasets = [ + dict( + type=HFDataset, + path='openai_humaneval', + reader_cfg=humaneval_reader_cfg, + infer_cfg=humaneval_infer_cfg, + eval_cfg=humaneval_eval_cfg) +] diff --git a/configs/models/aquila/hf_aquila2_34b.py b/configs/models/aquila/hf_aquila2_34b.py new file mode 100644 index 00000000..e0194a5a --- /dev/null +++ b/configs/models/aquila/hf_aquila2_34b.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFaceCausalLM + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='aquila2-34b-hf', + path="BAAI/Aquila2-34B", + tokenizer_path='BAAI/Aquila2-34B', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=2, num_procs=1), + ) +] diff --git a/configs/models/aquila/hf_aquila2_7b.py b/configs/models/aquila/hf_aquila2_7b.py new file mode 100644 index 00000000..95af1f7d --- /dev/null +++ b/configs/models/aquila/hf_aquila2_7b.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFaceCausalLM + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='aquila2-7b-hf', + path="BAAI/Aquila2-7B", + tokenizer_path='BAAI/Aquila2-7B', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + 
trust_remote_code=True, + use_fast=False, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/aquila/hf_aquilachat2_34b.py b/configs/models/aquila/hf_aquilachat2_34b.py new file mode 100644 index 00000000..112b39df --- /dev/null +++ b/configs/models/aquila/hf_aquilachat2_34b.py @@ -0,0 +1,33 @@ +from opencompass.models import HuggingFaceCausalLM + +_meta_template = dict( + round=[ + dict(role='HUMAN', begin='### Human: ', end='\n'), + dict(role='BOT', begin='### Assistant: ', end='', generate=True), + ], + eos_token_id=100007, +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='aquilachat2-34b-hf', + path="BAAI/AquilaChat2-34B", + tokenizer_path='BAAI/AquilaChat2-34B', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=2, num_procs=1), + ) +] diff --git a/configs/models/aquila/hf_aquilachat2_34b_16k.py b/configs/models/aquila/hf_aquilachat2_34b_16k.py new file mode 100644 index 00000000..ccf28dde --- /dev/null +++ b/configs/models/aquila/hf_aquilachat2_34b_16k.py @@ -0,0 +1,34 @@ +from opencompass.models import HuggingFaceCausalLM + +_meta_template = dict( + begin='###', + round=[ + dict(role='HUMAN', begin='Human: ', end='###'), + dict(role='BOT', begin='Assistant: ', end='', generate=True), + ], + eos_token_id=100007, +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='aquilachat2-34b-16k-hf', + path="BAAI/AquilaChat2-34B-16K", + tokenizer_path='BAAI/AquilaChat2-34B-16K', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=2, num_procs=1), + ) +] diff --git a/configs/models/aquila/hf_aquilachat2_7b.py b/configs/models/aquila/hf_aquilachat2_7b.py new file mode 100644 index 00000000..ff964d05 --- /dev/null +++ b/configs/models/aquila/hf_aquilachat2_7b.py @@ -0,0 +1,33 @@ +from opencompass.models import HuggingFaceCausalLM + +_meta_template = dict( + round=[ + dict(role='HUMAN', begin='<|startofpiece|>', end=''), + dict(role='BOT', begin='<|endofpiece|>', end='', generate=True), + ], + eos_token_id=2, +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='aquilachat2-7b-hf', + path="BAAI/AquilaChat2-7B", + tokenizer_path='BAAI/AquilaChat2-7B', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/aquila/hf_aquilachat2_7b_16k.py b/configs/models/aquila/hf_aquilachat2_7b_16k.py new file mode 100644 index 00000000..55794259 --- /dev/null +++ b/configs/models/aquila/hf_aquilachat2_7b_16k.py @@ -0,0 +1,34 @@ +from opencompass.models import HuggingFaceCausalLM + +_meta_template = dict( + begin='###', + round=[ + dict(role='HUMAN', begin='Human: ', end='###'), + dict(role='BOT', begin='Assistant: ', end='', generate=True), + ], + eos_token_id=100007, +) + +models = [ + dict( + 
type=HuggingFaceCausalLM, + abbr='aquilachat2-7b-16k-hf', + path="BAAI/AquilaChat2-7B-16K", + tokenizer_path='BAAI/AquilaChat2-7B-16K', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/chatglm/hf_chatglm2_6b.py b/configs/models/chatglm/hf_chatglm2_6b.py index ad3c00f6..e59ae099 100644 --- a/configs/models/chatglm/hf_chatglm2_6b.py +++ b/configs/models/chatglm/hf_chatglm2_6b.py @@ -7,15 +7,18 @@ models = [ abbr='chatglm2-6b-hf', path='THUDM/chatglm2-6b', tokenizer_path='THUDM/chatglm2-6b', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), tokenizer_kwargs=dict( padding_side='left', truncation_side='left', trust_remote_code=True, ), max_out_len=100, - max_seq_len=2048, + max_seq_len=4096, batch_size=8, - model_kwargs=dict(trust_remote_code=True, device_map='auto', revision='a6d54fac46dff2db65d53416c207a4485ca6bd40'), run_cfg=dict(num_gpus=1, num_procs=1), ) ] diff --git a/configs/models/chatglm/hf_chatglm3_6b.py b/configs/models/chatglm/hf_chatglm3_6b.py new file mode 100644 index 00000000..8088bcd0 --- /dev/null +++ b/configs/models/chatglm/hf_chatglm3_6b.py @@ -0,0 +1,31 @@ +from opencompass.models import HuggingFaceChatGLM3 + +api_meta_template = dict( + round=[ + dict(role='HUMAN', api_role='HUMAN'), + dict(role='BOT', api_role='BOT', generate=True), + ] +) + +models = [ + dict( + type=HuggingFaceChatGLM3, + abbr='chatglm3-6b-hf', + path='THUDM/chatglm3-6b', + tokenizer_path='THUDM/chatglm3-6b', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + meta_template=api_meta_template, + max_out_len=100, + max_seq_len=4096, + batch_size=1, + run_cfg=dict(num_gpus=1, num_procs=1) + ) +] diff --git a/configs/models/chatglm/hf_chatglm3_6b_base.py b/configs/models/chatglm/hf_chatglm3_6b_base.py new file mode 100644 index 00000000..17f5d5ba --- /dev/null +++ b/configs/models/chatglm/hf_chatglm3_6b_base.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFace + + +models = [ + dict( + type=HuggingFace, + abbr='chatglm3-6b-base-hf', + path='THUDM/chatglm3-6b-base', + tokenizer_path='THUDM/chatglm3-6b-base', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=4096, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/chatglm/hf_chatglm_6b.py b/configs/models/chatglm/hf_chatglm_6b.py index f51f00a1..e013de2e 100644 --- a/configs/models/chatglm/hf_chatglm_6b.py +++ b/configs/models/chatglm/hf_chatglm_6b.py @@ -7,6 +7,10 @@ models = [ abbr='chatglm-6b-hf', path='THUDM/chatglm-6b', tokenizer_path='THUDM/chatglm-6b', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), tokenizer_kwargs=dict( padding_side='left', truncation_side='left', @@ -15,7 +19,6 @@ models = [ max_out_len=100, max_seq_len=2048, batch_size=8, - model_kwargs=dict(trust_remote_code=True, device_map='auto', revision='1d240ba371910e9282298d4592532d7f0f3e9f3e'), run_cfg=dict(num_gpus=1, num_procs=1), ) ] diff --git a/configs/models/hf_internlm/hf_internlm_20b.py 
b/configs/models/hf_internlm/hf_internlm_20b.py new file mode 100644 index 00000000..9af67533 --- /dev/null +++ b/configs/models/hf_internlm/hf_internlm_20b.py @@ -0,0 +1,22 @@ +from opencompass.models import HuggingFaceCausalLM + + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='internlm-20b-hf', + path="internlm/internlm-20b", + tokenizer_path='internlm/internlm-20b', + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + use_fast=False, + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + model_kwargs=dict(trust_remote_code=True, device_map='auto'), + run_cfg=dict(num_gpus=2, num_procs=1), + ) +] diff --git a/configs/models/hf_internlm/hf_internlm_7b.py b/configs/models/hf_internlm/hf_internlm_7b.py index 31ec8484..649e0c75 100644 --- a/configs/models/hf_internlm/hf_internlm_7b.py +++ b/configs/models/hf_internlm/hf_internlm_7b.py @@ -7,6 +7,10 @@ models = [ abbr='internlm-7b-hf', path="internlm/internlm-7b", tokenizer_path='internlm/internlm-7b', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), tokenizer_kwargs=dict( padding_side='left', truncation_side='left', @@ -16,7 +20,6 @@ models = [ max_out_len=100, max_seq_len=2048, batch_size=8, - model_kwargs=dict(trust_remote_code=True, device_map='auto'), run_cfg=dict(num_gpus=1, num_procs=1), ) ] diff --git a/configs/models/hf_internlm/hf_internlm_chat_20b.py b/configs/models/hf_internlm/hf_internlm_chat_20b.py new file mode 100644 index 00000000..b5f82b04 --- /dev/null +++ b/configs/models/hf_internlm/hf_internlm_chat_20b.py @@ -0,0 +1,33 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + round=[ + dict(role='HUMAN', begin='<|User|>:', end='\n'), + dict(role='BOT', begin='<|Bot|>:', end='\n', generate=True), + ], +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='internlm-chat-20b-hf', + path="internlm/internlm-chat-20b", + tokenizer_path='internlm/internlm-chat-20b', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + use_fast=False, + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + meta_template=_meta_template, + run_cfg=dict(num_gpus=2, num_procs=1), + ) +] diff --git a/configs/models/hf_internlm/hf_internlm_chat_7b.py b/configs/models/hf_internlm/hf_internlm_chat_7b.py index 0d0fc61f..8a6fd9a2 100644 --- a/configs/models/hf_internlm/hf_internlm_chat_7b.py +++ b/configs/models/hf_internlm/hf_internlm_chat_7b.py @@ -14,21 +14,20 @@ models = [ abbr='internlm-chat-7b-hf', path="internlm/internlm-chat-7b", tokenizer_path='internlm/internlm-chat-7b', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), tokenizer_kwargs=dict( padding_side='left', truncation_side='left', use_fast=False, trust_remote_code=True, - revision="1a6328795c6e207904e1eb58177e03ad24ae06f3" ), max_out_len=100, max_seq_len=2048, batch_size=8, meta_template=_meta_template, - model_kwargs=dict( - trust_remote_code=True, - device_map='auto', - revision="1a6328795c6e207904e1eb58177e03ad24ae06f3"), run_cfg=dict(num_gpus=1, num_procs=1), ) ] diff --git a/configs/models/hf_internlm/hf_internlm_chat_7b_8k.py b/configs/models/hf_internlm/hf_internlm_chat_7b_8k.py index 19b9a757..e3907c7c 100644 --- a/configs/models/hf_internlm/hf_internlm_chat_7b_8k.py +++ b/configs/models/hf_internlm/hf_internlm_chat_7b_8k.py @@ -14,6 +14,10 @@ models = [ abbr='internlm-chat-7b-8k-hf', 
path="internlm/internlm-chat-7b-8k", tokenizer_path='internlm/internlm-chat-7b-8k', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), tokenizer_kwargs=dict( padding_side='left', truncation_side='left', @@ -24,7 +28,6 @@ models = [ max_seq_len=2048, batch_size=8, meta_template=_meta_template, - model_kwargs=dict(trust_remote_code=True, device_map='auto'), run_cfg=dict(num_gpus=1, num_procs=1), ) ] diff --git a/configs/models/lingowhale/hf_lingowhale_8b.py b/configs/models/lingowhale/hf_lingowhale_8b.py new file mode 100644 index 00000000..45544e75 --- /dev/null +++ b/configs/models/lingowhale/hf_lingowhale_8b.py @@ -0,0 +1,25 @@ +from opencompass.models import HuggingFace + + +models = [ + dict( + type=HuggingFace, + abbr='lingowhale-8b-hf', + path='deeplang-ai/LingoWhale-8B', + tokenizer_path='deeplang-ai/LingoWhale-8B', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + torch_dtype='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/mistral/hf_mistral_7b.py b/configs/models/mistral/hf_mistral_7b.py new file mode 100644 index 00000000..bae2ce32 --- /dev/null +++ b/configs/models/mistral/hf_mistral_7b.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFaceCausalLM + + +models = [ + dict( + abbr='mistral-7b-v0.1-hf', + type=HuggingFaceCausalLM, + path='mistralai/Mistral-7B-v0.1', + tokenizer_path='mistralai/Mistral-7B-v0.1', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/mistral/hf_mistral_7b_instruct.py b/configs/models/mistral/hf_mistral_7b_instruct.py new file mode 100644 index 00000000..3f8256f4 --- /dev/null +++ b/configs/models/mistral/hf_mistral_7b_instruct.py @@ -0,0 +1,34 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + begin="", + round=[ + dict(role="HUMAN", begin='[INST]', end='[/INST]'), + dict(role="BOT", begin="", end='', generate=True), + ], + eos_token_id=2 +) + +models = [ + dict( + abbr='mistral-7b-instruct-v0.1-hf', + type=HuggingFaceCausalLM, + path='mistralai/Mistral-7B-Instruct-v0.1', + tokenizer_path='mistralai/Mistral-7B-Instruct-v0.1', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/qwen/hf_qwen_14b.py b/configs/models/qwen/hf_qwen_14b.py new file mode 100644 index 00000000..3e4541f2 --- /dev/null +++ b/configs/models/qwen/hf_qwen_14b.py @@ -0,0 +1,25 @@ +from opencompass.models import HuggingFaceCausalLM + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='qwen-14b-hf', + path="Qwen/Qwen-14B", + tokenizer_path='Qwen/Qwen-14B', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False, + ), + pad_token_id=151643, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + 
) +] diff --git a/configs/models/qwen/hf_qwen_14b_chat.py b/configs/models/qwen/hf_qwen_14b_chat.py new file mode 100644 index 00000000..5d5bdac7 --- /dev/null +++ b/configs/models/qwen/hf_qwen_14b_chat.py @@ -0,0 +1,33 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + round=[ + dict(role="HUMAN", begin='\n<|im_start|>user\n', end='<|im_end|>'), + dict(role="BOT", begin="\n<|im_start|>assistant\n", end='<|im_end|>', generate=True), + ], +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='qwen-14b-chat-hf', + path="Qwen/Qwen-14B-Chat", + tokenizer_path='Qwen/Qwen-14B-Chat', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False,), + pad_token_id=151643, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + meta_template=_meta_template, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/qwen/hf_qwen_7b.py b/configs/models/qwen/hf_qwen_7b.py index 64d765aa..9e535cad 100644 --- a/configs/models/qwen/hf_qwen_7b.py +++ b/configs/models/qwen/hf_qwen_7b.py @@ -1,33 +1,25 @@ from opencompass.models import HuggingFaceCausalLM -# Please note that we have specified the revision here. Recently (on 20230827), -# during our evaluations, we found that the newer revision models have a drop -# of more than 5 points on datasets like GaokaoBench / mbpp. -# We are not yet sure whether this drop is due to incorrect logic in OpenCompass -# calling qwen or some other reasons. We would like to highlight this. - models = [ dict( type=HuggingFaceCausalLM, abbr='qwen-7b-hf', path="Qwen/Qwen-7B", tokenizer_path='Qwen/Qwen-7B', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), tokenizer_kwargs=dict( padding_side='left', truncation_side='left', trust_remote_code=True, use_fast=False, - revision='39fc5fdcb95c8c367bbdb3bfc0db71d96266de09' ), pad_token_id=151643, max_out_len=100, max_seq_len=2048, batch_size=8, - model_kwargs=dict( - device_map='auto', - trust_remote_code=True, - revision='39fc5fdcb95c8c367bbdb3bfc0db71d96266de09' - ), run_cfg=dict(num_gpus=1, num_procs=1), ) ] diff --git a/configs/models/qwen/hf_qwen_7b_chat.py b/configs/models/qwen/hf_qwen_7b_chat.py index b84ea5ca..a1dd8e84 100644 --- a/configs/models/qwen/hf_qwen_7b_chat.py +++ b/configs/models/qwen/hf_qwen_7b_chat.py @@ -14,6 +14,10 @@ models = [ abbr='qwen-7b-chat-hf', path="Qwen/Qwen-7B-Chat", tokenizer_path='Qwen/Qwen-7B-Chat', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True + ), tokenizer_kwargs=dict( padding_side='left', truncation_side='left', @@ -24,7 +28,6 @@ models = [ max_seq_len=2048, batch_size=8, meta_template=_meta_template, - model_kwargs=dict(device_map='auto', trust_remote_code=True), run_cfg=dict(num_gpus=1, num_procs=1), ) ] diff --git a/configs/models/skywork/hf_skywork_13b.py b/configs/models/skywork/hf_skywork_13b.py new file mode 100644 index 00000000..495a3392 --- /dev/null +++ b/configs/models/skywork/hf_skywork_13b.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFaceCausalLM + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='skywork-13b-hf', + path="Skywork/Skywork-13B-base", + tokenizer_path='Skywork/Skywork-13B-base', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + use_fast=False, + ), + max_out_len=100, + max_seq_len=2048, + 
batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/tigerbot/hf_tigerbot_70b_base.py b/configs/models/tigerbot/hf_tigerbot_70b_base.py new file mode 100644 index 00000000..12cfd4d0 --- /dev/null +++ b/configs/models/tigerbot/hf_tigerbot_70b_base.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFaceCausalLM + + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='tigerbot-70b-base-v1-hf', + path='TigerResearch/tigerbot-70b-base', + tokenizer_path='TigerResearch/tigerbot-70b-base', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=4, num_procs=1), + ), +] diff --git a/configs/models/tigerbot/hf_tigerbot_70b_chat_v2.py b/configs/models/tigerbot/hf_tigerbot_70b_chat_v2.py new file mode 100644 index 00000000..7bdb7d15 --- /dev/null +++ b/configs/models/tigerbot/hf_tigerbot_70b_chat_v2.py @@ -0,0 +1,29 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + round=[ + dict(role='HUMAN', begin='\n\n### Instruction:\n'), + dict(role='BOT', begin='\n\n### Response:\n', generate=True), + ], +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='tigerbot-70b-chat-v2-hf', + path="TigerResearch/tigerbot-70b-chat-v2", + tokenizer_path='TigerResearch/tigerbot-70b-chat-v2', + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + meta_template=_meta_template, + model_kwargs=dict(trust_remote_code=True, device_map='auto'), + run_cfg=dict(num_gpus=4, num_procs=1), + ) +] diff --git a/configs/models/tigerbot/hf_tigerbot_70b_chat_v3.py b/configs/models/tigerbot/hf_tigerbot_70b_chat_v3.py new file mode 100644 index 00000000..55911755 --- /dev/null +++ b/configs/models/tigerbot/hf_tigerbot_70b_chat_v3.py @@ -0,0 +1,32 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + round=[ + dict(role='HUMAN', begin='\n\n### Instruction:\n'), + dict(role='BOT', begin='\n\n### Response:\n', generate=True), + ], +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='tigerbot-70b-chat-v3-hf', + path="TigerResearch/tigerbot-70b-chat-v3", + tokenizer_path='TigerResearch/tigerbot-70b-chat-v3', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=4, num_procs=1), + ) +] diff --git a/configs/models/vicuna/hf_vicuna_13b_v15.py b/configs/models/vicuna/hf_vicuna_13b_v15.py index 958e9101..ede90573 100644 --- a/configs/models/vicuna/hf_vicuna_13b_v15.py +++ b/configs/models/vicuna/hf_vicuna_13b_v15.py @@ -17,6 +17,6 @@ models = [ batch_size=8, model_kwargs=dict(device_map='auto'), batch_padding=False, # if false, inference with for-loop without batch padding - run_cfg=dict(num_gpus=2, num_procs=1) + run_cfg=dict(num_gpus=1, num_procs=1) ) ] diff --git a/configs/models/yi/hf_yi_34b.py b/configs/models/yi/hf_yi_34b.py new file mode 100644 index 00000000..3f20f416 --- /dev/null +++ b/configs/models/yi/hf_yi_34b.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFace + + +models = [ + dict( + type=HuggingFace, + abbr='yi-34b-hf', + path='01-ai/Yi-34B', + 
tokenizer_path='01-ai/Yi-34B', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=4, num_procs=1), + ) +] diff --git a/configs/models/yi/hf_yi_6b.py b/configs/models/yi/hf_yi_6b.py new file mode 100644 index 00000000..c376d868 --- /dev/null +++ b/configs/models/yi/hf_yi_6b.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFace + + +models = [ + dict( + type=HuggingFace, + abbr='yi-6b-hf', + path='01-ai/Yi-6B', + tokenizer_path='01-ai/Yi-6B', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/opencompass/datasets/gsm8k.py b/opencompass/datasets/gsm8k.py index d440b2ec..d7655e82 100644 --- a/opencompass/datasets/gsm8k.py +++ b/opencompass/datasets/gsm8k.py @@ -66,7 +66,7 @@ class Gsm8kEvaluator(BaseEvaluator): count = 0 details = [] for i, j in zip(predictions, references): - detail = {'pred': i, 'answers': j, 'correct': False} + detail = {'pred': i, 'answer': j, 'correct': False} count += 1 if i == j: correct += 1 diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index bf653c14..b727e72c 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -4,6 +4,7 @@ from .claude_api import Claude # noqa: F401 from .glm import GLM130B # noqa: F401, F403 from .huggingface import HuggingFace # noqa: F401, F403 from .huggingface import HuggingFaceCausalLM # noqa: F401, F403 +from .huggingface import HuggingFaceChatGLM3 # noqa: F401, F403 from .intern_model import InternLM # noqa: F401, F403 from .llama2 import Llama2, Llama2Chat # noqa: F401, F403 from .minimax_api import MiniMax # noqa: F401 diff --git a/opencompass/models/huggingface.py b/opencompass/models/huggingface.py index ffa9a8aa..c38c2d25 100644 --- a/opencompass/models/huggingface.py +++ b/opencompass/models/huggingface.py @@ -5,6 +5,7 @@ import numpy as np import torch from opencompass.models.base import BaseModel +from opencompass.models.base_api import APITemplateParser from opencompass.registry import MODELS from opencompass.utils.logging import get_logger from opencompass.utils.prompt import PromptList @@ -442,3 +443,85 @@ class HuggingFaceCausalLM(HuggingFace): is_trainable=False) self.model.eval() self.model.generation_config.do_sample = False + + +class HuggingFaceChatGLM3(HuggingFace): + """Model wrapper around HuggingFace's ChatGLM3. Details available in + `https://huggingface.co/THUDM/chatglm3-6b`. + + model.chat() is used for inference. 
+ """ + + def __init__(self, + path: str, + hf_cache_dir: Optional[str] = None, + max_seq_len: int = 2048, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: dict = dict(), + peft_path: Optional[str] = None, + tokenizer_only: bool = False, + model_kwargs: dict = dict(device_map='auto'), + meta_template: Optional[Dict] = None, + extract_pred_after_decode: bool = False, + batch_padding: bool = False, + pad_token_id: Optional[int] = None, + mode: str = 'none', + num_extra_tokens: int = 50): + super().__init__(path=path, + hf_cache_dir=hf_cache_dir, + max_seq_len=max_seq_len, + tokenizer_path=tokenizer_path, + tokenizer_kwargs=tokenizer_kwargs, + peft_path=peft_path, + tokenizer_only=tokenizer_only, + model_kwargs=model_kwargs, + meta_template=meta_template, + extract_pred_after_decode=extract_pred_after_decode, + batch_padding=batch_padding, + pad_token_id=pad_token_id, + mode=mode) + self.template_parser = APITemplateParser(meta_template) + # used to compensate for #tokens occupied by sth like system prompt + self.num_extra_tokens = num_extra_tokens + + def generate(self, + inputs: List[str or PromptList], + max_out_len: int = 512, + temperature: float = 0.6) -> str: + """Generate response from input prompt. + + Args: + inputs (list): input prompt + max_out_len (int): max output length + temperature (float): temperature for sampling + """ + responses = [] + for _input in inputs: + assert isinstance(_input, (str, PromptList)) + if isinstance(_input, str): + history = [{'role': 'user', 'content': _input}] + else: + history = [] + for item in _input: + msg = { + 'content': item['prompt'], + 'role': { + 'HUMAN': 'user', + 'BOT': 'assistant', + 'SYSTEM': 'system' + }[item['role']] + } + history.append(msg) + user_content = history[-1]['content'] + history = history[:-1] + try: + response, history = self.model.chat(self.tokenizer, + user_content, + history=history) + responses.append(response) + except Exception: + responses.append('') + return responses + + def get_token_len(self, prompt: str) -> int: + return len(self.tokenizer.encode(prompt)) + self.num_extra_tokens diff --git a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py index 2a355592..d1a3d8d2 100644 --- a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py @@ -266,7 +266,13 @@ class EDAccEvaluator(AccEvaluator): for i in range(len(predictions)): pred, ref = predictions[i], references[i] - dists = [self.dist(pred, cand) for cand in ref['candidates']] + dists = [] + for cands in ref['candidates']: + if isinstance(cands, str): + d = self.dist(pred, cands) + else: + d = np.min([self.dist(pred, cand) for cand in cands]) + dists.append(d) preds.append(np.argmin(dists)) golds.append(ref['label']) diff --git a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py index 0fa60bee..e82a015d 100644 --- a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py @@ -190,7 +190,7 @@ class PPLInferencer(BaseInferencer): label, prompt.replace(ice_str, ''), prompt, res, index) output_handler.results_dict[str( index)][f'label: {str(label)}'][ - 'BPB'] = res * token_num_list[idx] / len( + 'BPB'] = res * token_num_list[index] / len( prompt.replace(ice_str, '').encode()) index = index + 1 ppl.append(sub_ppl_list) diff --git a/opencompass/runners/__init__.py b/opencompass/runners/__init__.py index 
5e2bb1ef..f4a3207e 100644
--- a/opencompass/runners/__init__.py
+++ b/opencompass/runners/__init__.py
@@ -1,3 +1,4 @@
 from .dlc import *  # noqa: F401, F403
 from .local import *  # noqa: F401, F403
 from .slurm import *  # noqa: F401, F403
+from .slurm_sequential import *  # noqa: F401, F403
diff --git a/opencompass/runners/slurm_sequential.py b/opencompass/runners/slurm_sequential.py
new file mode 100644
index 00000000..aa3f5493
--- /dev/null
+++ b/opencompass/runners/slurm_sequential.py
@@ -0,0 +1,242 @@
+import os
+import os.path as osp
+import re
+import subprocess
+import time
+import traceback
+from functools import partial
+from multiprocessing import Pipe, Pool
+from typing import Any, Dict, List, Tuple
+
+import mmengine
+from mmengine.config import ConfigDict
+from tqdm import tqdm
+
+from opencompass.registry import RUNNERS, TASKS
+from opencompass.utils import get_logger
+
+from .base import BaseRunner
+
+
+@RUNNERS.register_module()
+class SlurmSequentialRunner(BaseRunner):
+    """Distributed runner based on Slurm. It runs tasks in parallel using
+    the `srun` command.
+
+    This runner submits tasks one by one: a new task is launched only when
+    the number of running workers is below max_num_workers and the previous
+    task has been successfully allocated to a machine. Therefore, unlike
+    `SlurmRunner`, at most one task is in the PENDING state at any time
+    during a run, which makes the random_sleep strategy unnecessary. In
+    addition, this runner automatically cancels all launched jobs by their
+    job_id on exit.
+
+    The runner obtains the job_id by reading srun output similar to
+    `srun: Job 123456 scheduled successfully!`. If the srun output does not
+    match this pattern, the runner will not work properly.
+
+    Args:
+        task (ConfigDict): Task type config.
+        max_num_workers (int): Max number of workers to run in parallel.
+            Defaults to 32.
+        retry (int): Number of retries if the job failed. Defaults to 2.
+        partition (str): Slurm partition name. Defaults to None.
+        quotatype (str): Slurm quota type. Defaults to None.
+        qos (str): Slurm quality of service. Defaults to None.
+        debug (bool): Whether to run in debug mode. Defaults to False.
+        lark_bot_url (str): Lark bot url. Defaults to None.
+ """ + + def __init__(self, + task: ConfigDict, + task_prefix: str = '', + max_num_workers: int = 32, + retry: int = 2, + partition: str = None, + quotatype: str = None, + qos: str = None, + debug: bool = False, + lark_bot_url: str = None): + super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url) + self.max_num_workers = max_num_workers + self.retry = retry + self.partition = partition + self.quotatype = quotatype + self.qos = qos + self.task_prefix = task_prefix + + logger = get_logger() + if self.quotatype in ['spot', 'auto']: + logger.warning( + 'Quotatype spot or auto may cause stability issues, ' + 'reserved is recommended.') + + def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: + if not self.debug: + return self._launch_wo_debug(tasks) + else: + return [self._launch(task) for task in tasks] + + def _launch_wo_debug(self, + tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: + launched_bar = tqdm(total=len(tasks), desc='Launched') + finished_bar = tqdm(total=len(tasks), desc='Finished') + job_ids = [] + status = [] + + def _update(result): + finished_bar.update() + status.append(result) + return result + + def _err_update(err): + finished_bar.update() + traceback.print_exc() + status.append(('', -1)) + + try: + parent_conns = [] + num_workers = min(self.max_num_workers, len(tasks)) + with Pool(processes=num_workers) as pool: + for task in tasks: + parent_conn, child_conn = Pipe() + _ = pool.apply_async(self._launch, + kwds={ + 'cfg': task, + 'child_conn': child_conn + }, + callback=_update, + error_callback=_err_update) + time.sleep(0.5) + + job_id = parent_conn.recv() + launched_bar.update() + parent_conns.append(parent_conn) + job_ids.append(job_id) + + pool.close() + pool.join() + return status + except KeyboardInterrupt: + raise + finally: + launched_bar.close() + finished_bar.close() + for parent_conn in parent_conns: + while parent_conn.poll(): + try: + job_id = parent_conn.recv() + job_ids.append(job_id) + except EOFError: + break + parent_conn.close() + + for job_id in tqdm(job_ids, desc='clear sruns'): + if job_id is None: + continue + cmd = f'scancel {job_id}' + p = subprocess.Popen(cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + p.wait() + + def _launch(self, cfg: ConfigDict, child_conn: Pipe = None): + logger = get_logger() + + task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type'])) + num_gpus = task.num_gpus + task_name = task.name + task_name = self.task_prefix + task_name + + # Dump task config to file + mmengine.mkdir_or_exist('tmp/') + param_file = f'tmp/{os.getpid()}_params.py' + process = None + try: + cfg.dump(param_file) + + # Build up slurm command + tmpl = 'srun' + if self.partition: + tmpl += f' -p {self.partition}' + if self.quotatype: + tmpl += f' --quotatype={self.quotatype}' + if self.qos: + tmpl += f' --qos={self.qos}' + if num_gpus > 0: + tmpl += f' --gres=gpu:{num_gpus}' + tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}' + get_cmd = partial(task.get_command, + cfg_path=param_file, + template=tmpl) + cmd = get_cmd() + + logger.debug(f'Running command: {cmd}') + + retry = self.retry + output_paths = task.get_output_paths() + + if self.debug: + while True: + process = subprocess.Popen(cmd, shell=True, text=True) + process.communicate() + process.wait() + if self._job_failed(process.returncode, output_paths): + if retry > 0: + logger.warning( + f'task {task_name} failed, retrying...') + retry -= 1 + cmd = get_cmd() + else: + break + else: + break + else: + out_path = 
task.get_log_path(file_extension='out') + mmengine.mkdir_or_exist(osp.split(out_path)[0]) + stdout = open(out_path, 'w', encoding='utf-8') + stderr = subprocess.PIPE + while True: + process = subprocess.Popen(cmd, + shell=True, + text=True, + stdout=stdout, + stderr=stderr) + job_id = None + while True: + line = process.stderr.readline() + if not line: + break + match = re.search( + r'srun: Job (\d+) scheduled successfully!', line) + if match and job_id is None: + job_id = match.group(1) + child_conn.send(job_id) + stdout.write(line) + process.wait() + if self._job_failed(process.returncode, output_paths): + if retry > 0: + retry -= 1 + cmd = get_cmd() + else: + logger.warning( + f'task {task_name} fail, see\n{out_path}') + break + else: + break + except KeyboardInterrupt: + raise + finally: + # Clean up + if child_conn is not None: + child_conn.send(None) + child_conn.close() + if process is not None: + process.kill() + os.remove(param_file) + return task_name, process.returncode + + def _job_failed(self, return_code: int, output_paths: List[str]) -> bool: + return return_code != 0 or not all( + osp.exists(output_path) for output_path in output_paths) diff --git a/opencompass/summarizers/summarizer_pretrain.py b/opencompass/summarizers/summarizer_pretrain.py new file mode 100644 index 00000000..c63cfc43 --- /dev/null +++ b/opencompass/summarizers/summarizer_pretrain.py @@ -0,0 +1,337 @@ +# flake8: noqa +# yapf: disable +import getpass +import os.path as osp +from datetime import datetime +from typing import List, Optional + +import mmengine +import pytz +import tabulate +from mmengine import ConfigDict + +from opencompass.utils import (LarkReporter, dataset_abbr_from_cfg, + get_infer_output_path, get_logger, + model_abbr_from_cfg) +from opencompass.utils.prompt import get_prompt_hash + +METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth'] +METRIC_BLACKLIST = ['bp', 'sys_len', 'ref_len'] + +class PretrainSummarizer: + """""" + + def __init__(self, config: ConfigDict, dataset_abbrs: Optional[List[str]] = None, summary_groups: List = [], prompt_db = None) -> None: + self.tasks = [] + self.cfg = config + self.logger = get_logger() + + # Enable lark bot if lark_url is presented + self.lark_reporter = None + if self.cfg.get('lark_bot_url', None): + self.lark_reporter = LarkReporter(self.cfg['lark_bot_url']) + + def summarize( + self, + output_path: str = None, + time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')): # noqa + + model_cfgs = self.cfg['models'] + dataset_cfgs = self.cfg['datasets'] + summarizer_cfg = self.cfg.get('summarizer', {}) + work_dir = self.cfg['work_dir'] + + # pick up results + raw_results = {} + parsed_results = {} + dataset_metrics = {} + + model_abbrs = [model_abbr_from_cfg(model) for model in model_cfgs] + for model in model_cfgs: + model_abbr = model_abbr_from_cfg(model) + parsed_results[model_abbr] = {} + raw_results[model_abbr] = {} + for dataset in dataset_cfgs: + dataset_abbr = dataset_abbr_from_cfg(dataset) + filepath = get_infer_output_path(model, dataset, osp.join(work_dir, 'results')) + if not osp.exists(filepath): + continue + result = mmengine.load(filepath) + raw_results[model_abbr][dataset_abbr] = result + if 'error' in result: + self.debug(f'error in {model_abbr} {dataset_abbr} {result["error"]}') + continue + else: + parsed_results[model_abbr][dataset_abbr] = [] + dataset_metrics[dataset_abbr] = [] + for metric, score in result.items(): + if metric 
not in METRIC_BLACKLIST and isinstance(score, (int, float)): + parsed_results[model_abbr][dataset_abbr].append(score) + dataset_metrics[dataset_abbr].append(metric) + else: + continue + if len(parsed_results[model_abbr][dataset_abbr]) == 0: + self.logger.warning(f'unknown result format: {result}, continue') + del parsed_results[model_abbr][dataset_abbr] + del dataset_metrics[dataset_abbr] + continue + indice = sorted( + list(range(len(dataset_metrics[dataset_abbr]))), + key=lambda i: ( + METRIC_WHITELIST.index(dataset_metrics[dataset_abbr][i]) + if dataset_metrics[dataset_abbr][i] in METRIC_WHITELIST + else len(METRIC_WHITELIST) + ) + ) + parsed_results[model_abbr][dataset_abbr] = [parsed_results[model_abbr][dataset_abbr][i] for i in indice] + dataset_metrics[dataset_abbr] = [dataset_metrics[dataset_abbr][i] for i in indice] + + # parse eval mode + dataset_eval_mode = {} + for dataset in dataset_cfgs: + inferencer = dataset.get('infer_cfg', {}).get('inferencer', {}).get('type', '') + inferencer = inferencer if isinstance(inferencer, str) else inferencer.__name__ + dataset_abbr = dataset_abbr_from_cfg(dataset) + if 'GenInferencer' in inferencer: + dataset_eval_mode[dataset_abbr] = 'gen' + elif 'PPLInferencer' in inferencer: + dataset_eval_mode[dataset_abbr] = 'ppl' + else: + dataset_eval_mode[dataset_abbr] = 'unknown' + self.logger.warning(f'unknown inferencer: {inferencer} - {dataset_abbr}') + + # calculate group metrics + summary_groups = summarizer_cfg.get('summary_groups', []) + for sg in summary_groups: + for model_abbr in model_abbrs: + results = {} + eval_modes = [] + for dataset_abbr in sg['subsets']: + if dataset_abbr in parsed_results[model_abbr]: + results[dataset_abbr] = parsed_results[model_abbr][dataset_abbr][0] + eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown')) + if len(results) == len(sg['subsets']): + if 'weights' in sg: + numerator = sum(results[k] * sg['weights'][k] for k in sg['weights']) + denominator = sum(sg['weights'].values()) + metric = 'weighted_average' + else: + numerator = sum(results[k] for k in results) + denominator = len(results) + metric = 'naive_average' + results[metric] = numerator / denominator + eval_modes = list(set(eval_modes)) + eval_mode = eval_modes[0] if len(eval_modes) == 1 else 'mixed' + + # add to global results + raw_results[model_abbr][sg['name']] = results + parsed_results[model_abbr][sg['name']] = [numerator / denominator] + dataset_metrics[sg['name']] = [metric] + dataset_eval_mode[sg['name']] = eval_mode + elif len(results) == 0: + continue + else: + raw_results[model_abbr][sg['name']] = {'error': 'missing datasets: {}'.format(set(sg['subsets']) - set(results.keys()))} + + prompt_version = {dataset_abbr_from_cfg(d): get_prompt_hash(d)[:6] for d in dataset_cfgs} + + # format table + summarizer_dataset_abbrs = [] + if summarizer_cfg.get('dataset_abbrs') is None: + for dataset in dataset_cfgs: + dataset_abbr = dataset_abbr_from_cfg(dataset) + if dataset_abbr in dataset_metrics: + for metric in dataset_metrics[dataset_abbr]: + summarizer_dataset_abbrs.append((dataset_abbr, metric)) + else: + summarizer_dataset_abbrs.append((dataset_abbr, None)) + for dataset_abbr in dataset_metrics: + for metric in dataset_metrics[dataset_abbr]: + if (dataset_abbr, metric) not in summarizer_dataset_abbrs: + summarizer_dataset_abbrs.append((dataset_abbr, metric)) + else: + for item in summarizer_cfg['dataset_abbrs']: + if isinstance(item, str): + summarizer_dataset_abbrs.append((item, None)) + elif isinstance(item, (list, tuple)): + 
summarizer_dataset_abbrs.append((item[0], item[1])) + table = [] + checkpoints = [model_abbr.rsplit('_', 1)[1] if '_' in model_abbr else model_abbr for model_abbr in model_abbrs] + # model_abbrs = [model_abbr.rsplit("_", 1)[0] for model_abbr in model_abbrs] + header = ['dataset', 'version', 'metric', 'mode'] + model_abbrs + time_zone = pytz.timezone('Asia/Shanghai') + now = datetime.now(time_zone) + time = now.strftime('%m/%d %H:%M') + times = [time] * len(model_abbrs) + table.append(header) + table.append(['dataset', 'version', 'metric', 'mode'] + times) + table.append(['dataset', 'version', 'metric', 'mode']+ checkpoints) + dataset_score = [0]* len(model_abbrs) + dataset_num = [0] * len(model_abbrs) + + for dataset_abbr, metric in summarizer_dataset_abbrs: + # if dataset_abbr not in dataset_metrics: + # table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(model_abbrs)) + # continue + if metric is None and dataset_abbr in dataset_metrics: + index = 0 + metric = dataset_metrics[dataset_abbr][0] + elif dataset_abbr in dataset_metrics and metric in dataset_metrics[dataset_abbr]: + index = dataset_metrics[dataset_abbr].index(metric) + elif not dataset_abbr.startswith('---'): + table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(model_abbrs)) + continue + if dataset_abbr.startswith('---'): + row = [dataset_abbr,'-','-','-'] + else: + row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric, dataset_eval_mode.get(dataset_abbr, '-')] + for i, model_abbr in enumerate(model_abbrs): + if dataset_abbr in parsed_results[model_abbr]: + if index == 0: + row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][index])) + dataset_score[i] += parsed_results[model_abbr][dataset_abbr][index] + dataset_num[i] += 1 + # row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][index])) + else: + if dataset_abbr.startswith('---') and dataset_num[i] != 0: + row.append('{:.02f}'.format(dataset_score[i] / dataset_num[i])) + dataset_score[i] = 0 + dataset_num[i] = 0 + else: + row.append('-') + table.append(row) + + # format raw txt + raw_dataset_abbrs = [] + for model_abbr in model_abbrs: + for dataset_abbr in raw_results[model_abbr]: + if dataset_abbr not in raw_dataset_abbrs: + raw_dataset_abbrs.append(dataset_abbr) + raw_txts = [] + for model_abbr in model_abbrs: + raw_txts.append('-------------------------------') + raw_txts.append(f'Model: {model_abbr}') + for dataset_abbr in raw_dataset_abbrs: + result = raw_results[model_abbr].get(dataset_abbr, '{}') + raw_txts.append(f'{dataset_abbr}: {result}') + raw_txts = '\n'.join(raw_txts) + + # output to screean + print(tabulate.tabulate(table, headers='firstrow')) + + # output to file + if output_path is None: + output_path = osp.join(work_dir, 'summary', f'summary_{time_str}.txt') + output_csv_path = osp.join(work_dir, 'summary', f'summary_{time_str}.csv') + else: + output_csv_path = output_path.replace('.txt', '.csv') + + output_dir = osp.split(output_path)[0] + mmengine.mkdir_or_exist(output_dir) + with open(output_path, 'w', encoding='utf-8') as f: + f.write(time_str + '\n') + f.write('tabulate format\n') + f.write('^' * 128 + '\n') + f.write(tabulate.tabulate(table, headers='firstrow') + '\n') + f.write('$' * 128 + '\n') + f.write('\n' + '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n') + f.write('csv format\n') + f.write('^' * 128 + '\n') + f.write('\n'.join([','.join(row) for row in table]) + '\n') + f.write('$' * 128 + '\n') + f.write('\n' + '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n') + f.write('raw 
format\n') + f.write('^' * 128 + '\n') + f.write(raw_txts + '\n') + f.write('$' * 128 + '\n') + self.logger.info(f'write summary to {osp.abspath(output_path)}') + + if self.lark_reporter: + content = f'{getpass.getuser()} 的' + content += f'详细评测汇总已输出至 {osp.abspath(output_path)}' + self.lark_reporter.post(content) + + with open(output_csv_path, 'w', encoding='utf-8') as f: + f.write('\n'.join([','.join(row) for row in table]) + '\n') + self.logger.info(f'write csv to {osp.abspath(output_csv_path)}') + + + summary_groups = summarizer_cfg.get('summary_groups', []) + for sg in summary_groups: + for model_abbr in model_abbrs: + results = {} + eval_modes = [] + for dataset_abbr in sg['subsets']: + if dataset_abbr in parsed_results[model_abbr]: + results[dataset_abbr] = (parsed_results[model_abbr][dataset_abbr][-1],parsed_results[model_abbr][dataset_abbr][-2]) + eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown')) + + if len(results) == len(sg['subsets']): + numerator1 = sum(results[k][0] for k in results) + numerator2 = sum(results[k][1] for k in results) + denominator = len(results) + metric = 'correct_bpb-incorrect_bpb' + + count_ppl = eval_modes.count('ppl') + count_gen = len(eval_modes)-count_ppl + if count_ppl==0: + results[metric] = -1 + else: + results[metric] = (numerator1+count_gen) / count_ppl + eval_modes = list(set(eval_modes)) + eval_mode = eval_modes[0] if len(eval_modes) == 1 else 'mixed' + # add to global results + + raw_results[model_abbr][sg['name']] = results + parsed_results[model_abbr][sg['name']] = [((numerator1+count_gen) / count_ppl) if count_ppl != 0 else -1, ((numerator2+count_gen) / count_ppl) if count_ppl != 0 else -1] + dataset_metrics[sg['name']] = ['incorrect_bpb','correct_bpb'] + dataset_eval_mode[sg['name']] = eval_mode + + elif len(results) == 0: + continue + else: + raw_results[model_abbr][sg['name']] = {'error': 'missing datasets: {}'.format(set(sg['subsets']) - set(results.keys()))} + + table = [] + table.append(['', '', '', ''] + [''] * len(model_abbrs)) + table.append(['', '', '', ''] + [''] * len(model_abbrs)) + table.append(['', '', '', ''] + [''] * len(model_abbrs)) + for dataset_abbr, metric in summarizer_dataset_abbrs: + incorrect_bpb = -1 + correct_bpb = -1 + if dataset_abbr not in dataset_metrics: + table.append([dataset_abbr, '', '', ''] + [''] * len(model_abbrs)) + continue + if metric is None: + index = 0 + try: + incorrect_bpb = dataset_metrics[dataset_abbr].index('incorrect_bpb') + correct_bpb = dataset_metrics[dataset_abbr].index('correct_bpb') + except ValueError: + try: + incorrect_bpb = dataset_metrics[dataset_abbr].index('wrong_bpb') + correct_bpb = dataset_metrics[dataset_abbr].index('right_bpb') + except ValueError: + incorrect_bpb = -1 + correct_bpb = -1 + metric = 'correct_bpb-incorrect_bpb' + elif metric in dataset_metrics[dataset_abbr]: + index = dataset_metrics[dataset_abbr].index(metric) + else: + table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(model_abbrs)) + continue + + row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric, + dataset_eval_mode.get(dataset_abbr, '-')] + for model_abbr in model_abbrs: + if dataset_abbr in parsed_results[model_abbr]: + if incorrect_bpb != -1 and correct_bpb != -1: + row.append('{:.02f}/{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][correct_bpb], + parsed_results[model_abbr][dataset_abbr][incorrect_bpb])) + else: + row.append('{:.02f}'.format(-1)) + else: + row.append('-') + table.append(row) + with open(output_csv_path, 'a', encoding='utf-8') as f: + 
f.write('\n'.join([','.join(row) for row in table]) + '\n') diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 68be3d27..de76d01d 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -287,7 +287,7 @@ class OpenICLEvalTask(BaseTask): result['prompt'] = origin_prediction['origin_prompt'] result['origin_prediction'] = pred_dicts[i]['prediction'] result['predictions'] = details[i]['pred'] - result['references'] = details[i]['answers'] + result['references'] = details[i]['answer'] result['correct'] = details[i]['correct'] results[str(i)] = result return results @@ -324,7 +324,7 @@ class OpenICLEvalTask(BaseTask): bpbs = [value['BPB'] for value in values] incorrect_bpb_list.append( (sum(bpbs) - min(bpbs)) / (len(bpbs) - 1)) - bpb_list.append(statistics.mean(bpbs)) + bpb_list.append(min(bpbs)) def filters(origins): targets = [target for target in origins if not math.isnan(target)]
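For reference, a minimal sketch of how the pieces added in this patch could be wired together in a top-level config: one of the new model configs, the new HumanEval dataset config, and the SlurmSequentialRunner. The file name, partition name, quota type, and task size below are illustrative assumptions, not values taken from this patch.

```python
# Sketch of a top-level eval config (e.g. configs/eval_chatglm3_humaneval.py);
# partition/quotatype/max_task_size are assumed values, adjust to your cluster.
from mmengine.config import read_base
from opencompass.partitioners import SizePartitioner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .models.chatglm.hf_chatglm3_6b import models
    from .datasets.humaneval.humaneval_gen_4a6eef import humaneval_datasets

datasets = [*humaneval_datasets]

infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=2000),
    runner=dict(
        type=SlurmSequentialRunner,
        partition='my-partition',  # assumed Slurm partition name
        quotatype='reserved',      # 'reserved' is recommended by the runner
        max_num_workers=32,
        retry=2,
        task=dict(type=OpenICLInferTask),
    ),
)
```

Such a config would then be launched the usual way, e.g. `python run.py configs/eval_chatglm3_humaneval.py`.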
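The new HuggingFaceChatGLM3 wrapper converts an OpenCompass PromptList into the history format expected by `model.chat()`. A small self-contained illustration of that role mapping, with invented conversation content and the final chat call left commented out since it needs a loaded model and tokenizer:

```python
# Illustration only: mirrors the HUMAN/BOT/SYSTEM -> user/assistant/system
# mapping in HuggingFaceChatGLM3.generate(); the prompts are invented.
prompt_list = [
    {'role': 'SYSTEM', 'prompt': 'You are a helpful assistant.'},
    {'role': 'HUMAN', 'prompt': 'Summarize what a Slurm partition is.'},
]
role_map = {'HUMAN': 'user', 'BOT': 'assistant', 'SYSTEM': 'system'}
history = [{'role': role_map[item['role']], 'content': item['prompt']}
           for item in prompt_list]
user_content = history[-1]['content']  # the last turn becomes the query
history = history[:-1]                 # earlier turns are passed as history
# response, history = model.chat(tokenizer, user_content, history=history)
```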
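Several of the new chat model configs (AquilaChat2, InternLM-Chat, Qwen-Chat, Mistral-Instruct, TigerBot) rely on a `_meta_template` to wrap each dialogue turn. Roughly, each round's prompt is surrounded by the configured `begin`/`end` strings, and the prompt stops at the `generate=True` role so the model continues from there. A rough, self-contained sketch of that idea using the Qwen-Chat template from this patch and an invented question; the real assembly is done by OpenCompass's template parser, this only shows the resulting prompt shape.

```python
# Rough illustration of meta_template rendering; not the actual parser code.
meta_template = dict(round=[
    dict(role='HUMAN', begin='\n<|im_start|>user\n', end='<|im_end|>'),
    dict(role='BOT', begin='\n<|im_start|>assistant\n', end='<|im_end|>',
         generate=True),
])

def render(turns, meta):
    """Wrap each turn with begin/end; stop at the generating role."""
    specs = {item['role']: item for item in meta['round']}
    out = ''
    for turn in turns:
        spec = specs[turn['role']]
        if spec.get('generate', False):
            out += spec['begin']  # the model generates from here
            break
        out += spec['begin'] + turn['prompt'] + spec['end']
    return out

print(render([dict(role='HUMAN', prompt='What does trust_remote_code do?'),
              dict(role='BOT', prompt='')], meta_template))
```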