From 3f601f420b7561d0f0c3d73653eff4c6a67e620d Mon Sep 17 00:00:00 2001 From: Yike Yuan <32432002+yyk-wew@users.noreply.github.com> Date: Fri, 25 Aug 2023 15:44:32 +0800 Subject: [PATCH] [Feat] Support public dataset of visualglm and llava. (#265) * [Feat] Add public dataset support of VisualGLM. * [Feat] Refactor LLaVA. * [Feat] Add public dataset support of LlaVA. * [Fix] Add arg. --- .../multimodal/llava/llava_7b_coco_caption.py | 50 +++++ .../multimodal/llava/llava_7b_flickr30k.py | 52 ++++++ configs/multimodal/llava/llava_7b_gqa.py | 49 +++++ configs/multimodal/llava/llava_7b_mmbench.py | 4 + configs/multimodal/llava/llava_7b_ocr_vqa.py | 49 +++++ configs/multimodal/llava/llava_7b_ok_vqa.py | 51 ++++++ .../multimodal/llava/llava_7b_scienceqa.py | 50 +++++ configs/multimodal/llava/llava_7b_textvqa.py | 50 +++++ configs/multimodal/llava/llava_7b_vizwiz.py | 48 +++++ configs/multimodal/llava/llava_7b_vqav2.py | 50 +++++ configs/multimodal/llava/llava_7b_vsr.py | 48 +++++ .../visualglm/visualglm_6b_coco_caption.py | 45 +++++ .../visualglm/visualglm_6b_flickr30k.py | 46 +++++ .../multimodal/visualglm/visualglm_6b_gqa.py | 42 +++++ .../visualglm/visualglm_6b_mmbench.py | 6 +- .../visualglm/visualglm_6b_ocr_vqa.py | 43 +++++ .../visualglm/visualglm_6b_ok_vqa.py | 45 +++++ .../visualglm/visualglm_6b_scienceqa.py | 44 +++++ .../visualglm/visualglm_6b_textvqa.py | 44 +++++ .../visualglm/visualglm_6b_vizwiz.py | 42 +++++ .../visualglm/visualglm_6b_vqav2.py | 44 +++++ .../multimodal/visualglm/visualglm_6b_vsr.py | 43 +++++ .../multimodal/models/llava/__init__.py | 11 +- opencompass/multimodal/models/llava/llava.py | 43 +++-- .../multimodal/models/llava/post_processor.py | 28 +++ .../models/llava/prompt_constructor.py | 110 +++++++++-- .../multimodal/models/visualglm/__init__.py | 16 +- .../models/visualglm/post_processor.py | 21 ++- .../models/visualglm/prompt_constructor.py | 171 +++++++++++++++++- .../multimodal/models/visualglm/visualglm.py | 13 +- 30 files changed, 1313 insertions(+), 45 deletions(-) create mode 100644 configs/multimodal/llava/llava_7b_coco_caption.py create mode 100644 configs/multimodal/llava/llava_7b_flickr30k.py create mode 100644 configs/multimodal/llava/llava_7b_gqa.py create mode 100644 configs/multimodal/llava/llava_7b_ocr_vqa.py create mode 100644 configs/multimodal/llava/llava_7b_ok_vqa.py create mode 100644 configs/multimodal/llava/llava_7b_scienceqa.py create mode 100644 configs/multimodal/llava/llava_7b_textvqa.py create mode 100644 configs/multimodal/llava/llava_7b_vizwiz.py create mode 100644 configs/multimodal/llava/llava_7b_vqav2.py create mode 100644 configs/multimodal/llava/llava_7b_vsr.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_coco_caption.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_flickr30k.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_gqa.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_ocr_vqa.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_ok_vqa.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_scienceqa.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_textvqa.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_vizwiz.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_vqav2.py create mode 100644 configs/multimodal/visualglm/visualglm_6b_vsr.py create mode 100644 opencompass/multimodal/models/llava/post_processor.py diff --git a/configs/multimodal/llava/llava_7b_coco_caption.py 
b/configs/multimodal/llava/llava_7b_coco_caption.py new file mode 100644 index 00000000..e0793494 --- /dev/null +++ b/configs/multimodal/llava/llava_7b_coco_caption.py @@ -0,0 +1,50 @@ +from opencompass.multimodal.models.llava import LLaVABasePromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id']), +] + + +dataset = dict(type='mmpretrain.COCOCaption', + data_root='data/coco', + data_prefix=dict(img_path='images'), + ann_file='annotations/coco_karpathy_val.json', + pipeline=val_pipeline) + +llava_coco_caption_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_coco_caption_model = dict( + type='llava', + model_path='/path/to/llava', + is_caption_task=True, + prompt_constructor=dict(type=LLaVABasePromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_coco_caption_evaluator = [ + dict( + type='mmpretrain.COCOCaption', + ann_file='data/coco/annotations/coco_karpathy_val_gt.json', + ) # noqa +] + diff --git a/configs/multimodal/llava/llava_7b_flickr30k.py b/configs/multimodal/llava/llava_7b_flickr30k.py new file mode 100644 index 00000000..cdb151b3 --- /dev/null +++ b/configs/multimodal/llava/llava_7b_flickr30k.py @@ -0,0 +1,52 @@ +from opencompass.multimodal.models.llava import LLaVABasePromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id']), +] + + +dataset = dict(type='mmpretrain.Flickr30kCaption', + data_root='data/flickr30k', + ann_file='annotations/dataset_flickr30k.json', + data_prefix='images', + split='val', + pipeline=val_pipeline) + +llava_flickr30k_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_flickr30k_model = dict( + type='llava', + model_path='/path/to/llava', + is_caption_task=True, + prompt_constructor=dict(type=LLaVABasePromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_flickr30k_evaluator = [ + dict( + type='mmpretrain.COCOCaption', + ann_file='data/flickr30k/annotations/flickr30k_val_gt.json', + ) # noqa +] + + diff --git a/configs/multimodal/llava/llava_7b_gqa.py b/configs/multimodal/llava/llava_7b_gqa.py new file mode 100644 index 00000000..fe80ac22 --- /dev/null +++ b/configs/multimodal/llava/llava_7b_gqa.py @@ -0,0 +1,49 @@ +from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor + +# dataloader settings 
+val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + + +dataset = dict(type='mmpretrain.GQA', + data_root='data/gqa', + data_prefix='images', + ann_file='annotations/testdev_balanced_questions.json', + pipeline=val_pipeline) + +llava_gqa_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_gqa_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAVQAPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')] + + diff --git a/configs/multimodal/llava/llava_7b_mmbench.py b/configs/multimodal/llava/llava_7b_mmbench.py index 9bef7e8f..2722e391 100644 --- a/configs/multimodal/llava/llava_7b_mmbench.py +++ b/configs/multimodal/llava/llava_7b_mmbench.py @@ -1,3 +1,5 @@ +from opencompass.multimodal.models.llava import LLaVAMMBenchPromptConstructor, LLaVABasePostProcessor + # dataloader settings val_pipeline = [ dict(type='mmpretrain.torchvision/Resize', @@ -34,6 +36,8 @@ mmbench_dataloader = dict( llava_model = dict( type='llava', model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAMMBenchPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) ) # noqa # evaluation settings diff --git a/configs/multimodal/llava/llava_7b_ocr_vqa.py b/configs/multimodal/llava/llava_7b_ocr_vqa.py new file mode 100644 index 00000000..9926128f --- /dev/null +++ b/configs/multimodal/llava/llava_7b_ocr_vqa.py @@ -0,0 +1,49 @@ +from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict(type='mmpretrain.OCRVQA', + data_root='data/ocrvqa', + ann_file='annotations/dataset.json', + split='test', + data_prefix='images', + pipeline=val_pipeline) + +llava_ocrvqa_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_ocrvqa_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAVQAPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_ocrvqa_evaluator = [dict(type='mmpretrain.VQAAcc')] + + diff --git a/configs/multimodal/llava/llava_7b_ok_vqa.py b/configs/multimodal/llava/llava_7b_ok_vqa.py new 
file mode 100644 index 00000000..f2d79cee --- /dev/null +++ b/configs/multimodal/llava/llava_7b_ok_vqa.py @@ -0,0 +1,51 @@ +from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict( + type='mmpretrain.COCOVQA', + data_root='data/okvqa', + question_file='annotations/OpenEnded_mscoco_val2014_questions.json', + ann_file='annotations/mscoco_val2014_annotations.json', + pipeline=val_pipeline, + data_prefix='images/val2014', +) + +llava_okvqa_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_okvqa_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAVQAPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_okvqa_evaluator = [dict(type='mmpretrain.VQAAcc')] + + diff --git a/configs/multimodal/llava/llava_7b_scienceqa.py b/configs/multimodal/llava/llava_7b_scienceqa.py new file mode 100644 index 00000000..9e48870a --- /dev/null +++ b/configs/multimodal/llava/llava_7b_scienceqa.py @@ -0,0 +1,50 @@ +from opencompass.multimodal.models.llava import LLaVAScienceQAPromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict(type='mmpretrain.PackInputs', + algorithm_keys=[ + 'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution' + ]) +] + +dataset = dict(type='mmpretrain.ScienceQA', + data_root='./data/scienceqa', + split='val', + split_file='pid_splits.json', + ann_file='problems.json', + image_only=True, + data_prefix=dict(img_path='val'), + pipeline=val_pipeline) + +llava_scienceqa_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_scienceqa_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAScienceQAPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')] + + diff --git a/configs/multimodal/llava/llava_7b_textvqa.py b/configs/multimodal/llava/llava_7b_textvqa.py new file mode 100644 index 00000000..52dbb030 --- /dev/null +++ b/configs/multimodal/llava/llava_7b_textvqa.py @@ -0,0 +1,50 @@ +from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + 
dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict( + type='mmpretrain.TextVQA', + data_root='data/textvqa', + ann_file='annotations/TextVQA_0.5.1_val.json', + pipeline=val_pipeline, + data_prefix='images/train_images', +) + +llava_textvqa_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_textvqa_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAVQAPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')] + + diff --git a/configs/multimodal/llava/llava_7b_vizwiz.py b/configs/multimodal/llava/llava_7b_vizwiz.py new file mode 100644 index 00000000..5a26176b --- /dev/null +++ b/configs/multimodal/llava/llava_7b_vizwiz.py @@ -0,0 +1,48 @@ +from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict(type='mmpretrain.VizWiz', + data_root='data/vizwiz/', + data_prefix='Images/val', + ann_file='Annotations/val.json', + pipeline=val_pipeline) + +llava_vizwiz_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_vizwiz_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAVQAPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')] + + diff --git a/configs/multimodal/llava/llava_7b_vqav2.py b/configs/multimodal/llava/llava_7b_vqav2.py new file mode 100644 index 00000000..22a322c5 --- /dev/null +++ b/configs/multimodal/llava/llava_7b_vqav2.py @@ -0,0 +1,50 @@ +from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 
'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict( + type='mmpretrain.COCOVQA', + data_root='data/coco', + data_prefix='images/val2014', + question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json', + ann_file='annotations/v2_mscoco_val2014_annotations.json', + pipeline=val_pipeline) + +llava_vqav2_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_vqav2_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAVQAPromptConstructor), + post_processor=dict(type=LLaVABasePostProcessor) +) # noqa + +# evaluation settings +llava_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')] + + diff --git a/configs/multimodal/llava/llava_7b_vsr.py b/configs/multimodal/llava/llava_7b_vsr.py new file mode 100644 index 00000000..7985d143 --- /dev/null +++ b/configs/multimodal/llava/llava_7b_vsr.py @@ -0,0 +1,48 @@ +from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVAVSRPostProcessor + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict( + type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict(type='mmpretrain.VSR', + data_root='data/vsr/', + data_prefix='images/', + ann_file='annotations/test.json', + pipeline=val_pipeline) + +llava_vsr_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False), +) + +# model settings +llava_vsr_model = dict( + type='llava', + model_path='/path/to/llava', + prompt_constructor=dict(type=LLaVAVQAPromptConstructor), + post_processor=dict(type=LLaVAVSRPostProcessor) +) # noqa + +# evaluation settings +llava_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')] + + diff --git a/configs/multimodal/visualglm/visualglm_6b_coco_caption.py b/configs/multimodal/visualglm/visualglm_6b_coco_caption.py new file mode 100644 index 00000000..e2fbceca --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_coco_caption.py @@ -0,0 +1,45 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMBasePromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id']) +] + + +dataset = dict(type='mmpretrain.COCOCaption', + data_root='data/coco', + data_prefix=dict(img_path='images'), + ann_file='annotations/coco_karpathy_val.json', + pipeline=val_pipeline) + +visualglm_coco_caption_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + 
sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_coco_caption_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + is_caption_task=True, + prompt_constructor=dict(type=VisualGLMBasePromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_coco_caption_evaluator = [ + dict( + type='mmpretrain.COCOCaption', + ann_file='data/coco/annotations/coco_karpathy_val_gt.json', + ) # noqa +] diff --git a/configs/multimodal/visualglm/visualglm_6b_flickr30k.py b/configs/multimodal/visualglm/visualglm_6b_flickr30k.py new file mode 100644 index 00000000..b88e519f --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_flickr30k.py @@ -0,0 +1,46 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMBasePromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id']) +] + + +dataset = dict(type='mmpretrain.Flickr30kCaption', + data_root='data/flickr30k', + ann_file='annotations/dataset_flickr30k.json', + data_prefix='images', + split='val', + pipeline=val_pipeline) + +visualglm_flickr30k_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_flickr30k_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + is_caption_task=True, + prompt_constructor=dict(type=VisualGLMBasePromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_flickr30k_evaluator = [ + dict( + type='mmpretrain.COCOCaption', + ann_file='data/flickr30k/annotations/flickr30k_val_gt.json', + ) # noqa +] diff --git a/configs/multimodal/visualglm/visualglm_6b_gqa.py b/configs/multimodal/visualglm/visualglm_6b_gqa.py new file mode 100644 index 00000000..c812afbb --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_gqa.py @@ -0,0 +1,42 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict(type='mmpretrain.GQA', + data_root='data/gqa', + data_prefix='images', + ann_file='annotations/testdev_balanced_questions.json', + pipeline=val_pipeline) + +visualglm_gqa_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_gqa_model = dict( + type='visualglm', 
+ pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMVQAPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')] diff --git a/configs/multimodal/visualglm/visualglm_6b_mmbench.py b/configs/multimodal/visualglm/visualglm_6b_mmbench.py index bd50b5b0..0dbbbd27 100644 --- a/configs/multimodal/visualglm/visualglm_6b_mmbench.py +++ b/configs/multimodal/visualglm/visualglm_6b_mmbench.py @@ -1,4 +1,4 @@ -from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor, VisualGLMPromptConstructor) +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMMMBenchPromptConstructor) # dataloader settings val_pipeline = [ @@ -30,8 +30,8 @@ mmbench_dataloader = dict(batch_size=1, visualglm_model = dict( type='visualglm', pretrained_path='/path/to/visualglm', # or Huggingface repo id - prompt_constructor=dict(type=VisualGLMPromptConstructor), - post_processor=dict(type=VisualGLMPostProcessor) + prompt_constructor=dict(type=VisualGLMMMBenchPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) ) # evaluation settings diff --git a/configs/multimodal/visualglm/visualglm_6b_ocr_vqa.py b/configs/multimodal/visualglm/visualglm_6b_ocr_vqa.py new file mode 100644 index 00000000..5b991cfa --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_ocr_vqa.py @@ -0,0 +1,43 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict(type='mmpretrain.OCRVQA', + data_root='data/ocrvqa', + ann_file='annotations/dataset.json', + split='test', + data_prefix='images', + pipeline=val_pipeline) + +visualglm_ocrvqa_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_ocrvqa_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMVQAPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_ocrvqa_evaluator = [dict(type='mmpretrain.VQAAcc')] diff --git a/configs/multimodal/visualglm/visualglm_6b_ok_vqa.py b/configs/multimodal/visualglm/visualglm_6b_ok_vqa.py new file mode 100644 index 00000000..f3c7784b --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_ok_vqa.py @@ -0,0 +1,45 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + 
mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict( + type='mmpretrain.COCOVQA', + data_root='data/okvqa', + question_file='annotations/OpenEnded_mscoco_val2014_questions.json', + ann_file='annotations/mscoco_val2014_annotations.json', + pipeline=val_pipeline, + data_prefix='images/val2014', +) + +visualglm_okvqa_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_okvqa_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMVQAPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_okvqa_evaluator = [dict(type='mmpretrain.VQAAcc')] diff --git a/configs/multimodal/visualglm/visualglm_6b_scienceqa.py b/configs/multimodal/visualglm/visualglm_6b_scienceqa.py new file mode 100644 index 00000000..79f29828 --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_scienceqa.py @@ -0,0 +1,44 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMScienceQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict(type='mmpretrain.PackInputs', + algorithm_keys=[ + 'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution' + ]) +] + +dataset = dict(type='mmpretrain.ScienceQA', + data_root='./data/scienceqa', + split='val', + split_file='pid_splits.json', + ann_file='problems.json', + image_only=True, + data_prefix=dict(img_path='val'), + pipeline=val_pipeline) + +visualglm_vizwiz_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_scienceqa_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMScienceQAPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')] \ No newline at end of file diff --git a/configs/multimodal/visualglm/visualglm_6b_textvqa.py b/configs/multimodal/visualglm/visualglm_6b_textvqa.py new file mode 100644 index 00000000..20774938 --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_textvqa.py @@ -0,0 +1,44 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict( + 
type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict( + type='mmpretrain.TextVQA', + data_root='data/textvqa', + ann_file='annotations/TextVQA_0.5.1_val.json', + pipeline=val_pipeline, + data_prefix='images/train_images', +) + +visualglm_textvqa_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMVQAPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')] diff --git a/configs/multimodal/visualglm/visualglm_6b_vizwiz.py b/configs/multimodal/visualglm/visualglm_6b_vizwiz.py new file mode 100644 index 00000000..b49a8c6e --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_vizwiz.py @@ -0,0 +1,42 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict(type='mmpretrain.VizWiz', + data_root='data/vizwiz/', + data_prefix='Images/val', + ann_file='Annotations/val.json', + pipeline=val_pipeline) + +visualglm_vizwiz_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMVQAPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')] diff --git a/configs/multimodal/visualglm/visualglm_6b_vqav2.py b/configs/multimodal/visualglm/visualglm_6b_vqav2.py new file mode 100644 index 00000000..4bbb8426 --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_vqav2.py @@ -0,0 +1,44 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + +dataset = dict( + type='mmpretrain.COCOVQA', + data_root='data/coco', + data_prefix='images/val2014', + 
question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json', + ann_file='annotations/v2_mscoco_val2014_annotations.json', + pipeline=val_pipeline) + +visualglm_vqav2_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMVQAPromptConstructor), + post_processor=dict(type=VisualGLMBasePostProcessor) +) + +# evaluation settings +visualglm_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')] diff --git a/configs/multimodal/visualglm/visualglm_6b_vsr.py b/configs/multimodal/visualglm/visualglm_6b_vsr.py new file mode 100644 index 00000000..69664835 --- /dev/null +++ b/configs/multimodal/visualglm/visualglm_6b_vsr.py @@ -0,0 +1,43 @@ +from opencompass.multimodal.models.visualglm import (VisualGLMVSRPostProcessor, VisualGLMVQAPromptConstructor) + +# dataloader settings +val_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict(type='mmpretrain.ToPIL', to_rgb=True), + dict(type='mmpretrain.torchvision/Resize', + size=(224, 224), + interpolation=3), + dict(type='mmpretrain.torchvision/ToTensor'), + dict(type='mmpretrain.torchvision/Normalize', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'], + meta_keys=['question_id', 'image_id'], + ) +] + + +dataset = dict(type='mmpretrain.VSR', + data_root='data/vsr/', + data_prefix='images/', + ann_file='annotations/test.json', + pipeline=val_pipeline) + +visualglm_vsr_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +visualglm_model = dict( + type='visualglm', + pretrained_path='/path/to/visualglm', # or Huggingface repo id + prompt_constructor=dict(type=VisualGLMVQAPromptConstructor), + post_processor=dict(type=VisualGLMVSRPostProcessor) +) + +# evaluation settings +visualglm_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')] diff --git a/opencompass/multimodal/models/llava/__init__.py b/opencompass/multimodal/models/llava/__init__.py index 5c367473..4fc919fa 100644 --- a/opencompass/multimodal/models/llava/__init__.py +++ b/opencompass/multimodal/models/llava/__init__.py @@ -1,3 +1,12 @@ from .llava import LLaVA +from .post_processor import LLaVABasePostProcessor, LLaVAVSRPostProcessor +from .prompt_constructor import (LLaVABasePromptConstructor, + LLaVAMMBenchPromptConstructor, + LLaVAScienceQAPromptConstructor, + LLaVAVQAPromptConstructor) -__all__ = ['LLaVA'] +__all__ = [ + 'LLaVA', 'LLaVABasePromptConstructor', 'LLaVAMMBenchPromptConstructor', + 'LLaVABasePostProcessor', 'LLaVAVQAPromptConstructor', + 'LLaVAScienceQAPromptConstructor', 'LLaVAVSRPostProcessor' +] diff --git a/opencompass/multimodal/models/llava/llava.py b/opencompass/multimodal/models/llava/llava.py index 046fbad8..54835a78 100644 --- a/opencompass/multimodal/models/llava/llava.py +++ b/opencompass/multimodal/models/llava/llava.py @@ -2,6 +2,7 @@ import importlib import os import sys +import mmengine import torch import torch.nn as nn from mmengine.device import get_device @@ -9,8 +10,6 @@ from transformers import StoppingCriteria from opencompass.registry import MM_MODELS -from .prompt_constructor import 
LLaVAMMBenchPromptConstructor - IMAGE_TOKEN_INDEX = -200 @@ -53,19 +52,27 @@ class LLaVA(nn.Module): Args: model_path (str): The path of llava checkpoint. + prompt_constructor (dict): The config of prompt constructor. + post_processor (dict): The config of post processor. + is_caption_task (bool): Whether the task is caption task. + Defaults to False. """ - def __init__(self, model_path: str) -> None: + def __init__( + self, + model_path: str, + prompt_constructor: dict, + post_processor: dict, + is_caption_task: bool = False, + ) -> None: super().__init__() self.dtype = torch.float16 + self.is_caption_task = is_caption_task # load LLaVA modules load_package() mm_utils = importlib.import_module('llava.mm_utils') builder = importlib.import_module('llava.model.builder') - conversation = importlib.import_module('llava.conversation') - self.SeparatorStyle = conversation.SeparatorStyle - self.conv_templates = conversation.conv_templates # load pretrained LLaVA # Note: When encounters with device related errors, @@ -86,13 +93,16 @@ class LLaVA(nn.Module): conv_mode = 'multimodal' mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False) - + prompt_constructor.update({ + 'conv_mode': conv_mode, + 'mm_use_im_start_end': mm_use_im_start_end + }) + self.prompt_constructor = mmengine.registry.build_from_cfg( + prompt_constructor, MM_MODELS) + self.post_processor = mmengine.registry.build_from_cfg( + post_processor, MM_MODELS) self.model = model self.tokenizer = tokenizer - self.prompt_constructor = LLaVAMMBenchPromptConstructor( - conv_templates=conversation.conv_templates, - conv_mode=conv_mode, - mm_use_im_start_end=mm_use_im_start_end) def generate(self, batch): @@ -133,12 +143,13 @@ class LLaVA(nn.Module): ) outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] - outputs = outputs.strip() - if outputs.endswith(stop_str): - outputs = outputs[:-len(stop_str)] - output_text = outputs.strip() - data_sample.pred_answer = output_text + output_text = self.post_processor(outputs, stop_str) + + if self.is_caption_task: + data_sample.pred_caption = output_text + else: + data_sample.pred_answer = output_text return data_sample def forward(self, batch): diff --git a/opencompass/multimodal/models/llava/post_processor.py b/opencompass/multimodal/models/llava/post_processor.py new file mode 100644 index 00000000..51066182 --- /dev/null +++ b/opencompass/multimodal/models/llava/post_processor.py @@ -0,0 +1,28 @@ +class LLaVABasePostProcessor: + """Base post processor for LLaVA on MMBench.""" + + def __init__(self) -> None: + pass + + def __call__(self, outputs: str, stop_str: str) -> str: + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + output_text = outputs.strip() + return output_text + + +class LLaVAVSRPostProcessor(LLaVABasePostProcessor): + """VSR post processor for LLaVA on MMBench.""" + + def __init__(self) -> None: + super().__init__() + + def __call__(self, outputs: str, stop_str: str) -> str: + output_text = super().__call__(outputs, stop_str) + if 'yes' in output_text.lower(): + return 'yes' + elif 'no' in output_text.lower(): + return 'no' + else: + return 'unknown' diff --git a/opencompass/multimodal/models/llava/prompt_constructor.py b/opencompass/multimodal/models/llava/prompt_constructor.py index c055c207..c25496a4 100644 --- a/opencompass/multimodal/models/llava/prompt_constructor.py +++ b/opencompass/multimodal/models/llava/prompt_constructor.py @@ -1,5 +1,4 @@ import importlib -from 
typing import Any DEFAULT_IMAGE_TOKEN = '<image>' DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>' @@ -7,23 +6,26 @@ DEFAULT_IM_START_TOKEN = '<im_start>' DEFAULT_IM_END_TOKEN = '<im_end>' -class LLaVAMMBenchPromptConstructor: - """Prompt constructor for LLaVA on MMBench. +class LLaVABasePromptConstructor: + """Base prompt constructor for LLaVA. Args: - conv_templates (Any): Conversation class to build prompt. conv_mode (str): Version control args for different version of LLaVA. mm_use_im_start_end (bool): Config arg. Use start and end token when build prompt or not. + reply_prompt (str): Reply prompt added at the end. (Default: '') """ - def __init__(self, conv_templates: Any, conv_mode: str, - mm_use_im_start_end: bool) -> None: - self.conv_templates = conv_templates + def __init__(self, + conv_mode: str, + mm_use_im_start_end: bool, + reply_prompt: str = '') -> None: + conversation = importlib.import_module('llava.conversation') + self.conv_templates = conversation.conv_templates self.conv_mode = conv_mode self.mm_use_im_start_end = mm_use_im_start_end - conversation = importlib.import_module('llava.conversation') self.SeparatorStyle = conversation.SeparatorStyle + self.reply_prompt = reply_prompt def __call__(self, inputs: dict) -> tuple: """Construct prompt. @@ -36,13 +38,7 @@ class LLaVAMMBenchPromptConstructor: """ data_samples = inputs['data_samples'] assert len(data_samples) == 1 - question = data_samples[0].get('question') - options = data_samples[0].get('options') - context = data_samples[0].get('context') - if context is not None: - prompt = context + ' ' + question + ' ' + options - else: - prompt = question + ' ' + options + prompt = self._build_prompt(data_samples[0]) if self.mm_use_im_start_end: prompt = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt) @@ -57,3 +53,87 @@ class LLaVAMMBenchPromptConstructor: stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2 # noqa return output_prompt, stop_str + + def _build_prompt(self, data_sample): + return self.reply_prompt + + +class LLaVAMMBenchPromptConstructor(LLaVABasePromptConstructor): + """MMBench prompt constructor for LLaVA. + + Args: + conv_mode (str): Version control args for different version of LLaVA. + mm_use_im_start_end (bool): + Config arg. Use start and end token when build prompt or not. + reply_prompt (str): Reply prompt added at the end. (Default: '') + """ + + def __init__(self, + conv_mode: str, + mm_use_im_start_end: bool, + reply_prompt: str = '') -> None: + super().__init__(conv_mode, mm_use_im_start_end, reply_prompt) + + def _build_prompt(self, data_sample): + question = data_sample.get('question') + options = data_sample.get('options') + context = data_sample.get('context') + if context is not None: + prompt = context + ' ' + question + ' ' + options + else: + prompt = question + ' ' + options + prompt += self.reply_prompt + return prompt + + +class LLaVAVQAPromptConstructor(LLaVABasePromptConstructor): + """VQA prompt constructor for LLaVA. + + Args: + conv_mode (str): Version control args for different version of LLaVA. + mm_use_im_start_end (bool): + Config arg. Use start and end token when build prompt or not. + reply_prompt (str): Reply prompt added at the end.
(Default: '') + """ + + def __init__(self, + conv_mode: str, + mm_use_im_start_end: bool, + reply_prompt: str = '') -> None: + super().__init__(conv_mode, mm_use_im_start_end, reply_prompt) + + def _build_prompt(self, data_sample): + prompt = data_sample.get('question') + prompt += self.reply_prompt + return prompt + + +class LLaVAScienceQAPromptConstructor(LLaVABasePromptConstructor): + """ScienceQA prompt constructor for LLaVA. + + Args: + conv_mode (str): Version control args for different version of LLaVA. + mm_use_im_start_end (bool): + Config arg. Use start and end token when build prompt or not. + reply_prompt (str): Reply prompt added at the end. (Default: '') + """ + + choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'} + + def __init__(self, + conv_mode: str, + mm_use_im_start_end: bool, + reply_prompt: str = '') -> None: + super().__init__(conv_mode, mm_use_im_start_end, reply_prompt) + + def _build_prompt(self, data_sample): + question = data_sample.get('question') + choices = data_sample.get('choices') + choices = [ + f'({self.choice_mapping[i]}) ' + item + for i, item in enumerate(choices) + ] + choices = 'Choices: ' + ' '.join(choices) + '\n' + context = 'Context: ' + data_sample.get('hint') + '\n' + prompt = context + question + choices + self.reply_prompt + return prompt diff --git a/opencompass/multimodal/models/visualglm/__init__.py b/opencompass/multimodal/models/visualglm/__init__.py index 69b12b4a..e2d6753a 100644 --- a/opencompass/multimodal/models/visualglm/__init__.py +++ b/opencompass/multimodal/models/visualglm/__init__.py @@ -1,5 +1,15 @@ -from .post_processor import VisualGLMPostProcessor -from .prompt_constructor import VisualGLMPromptConstructor +from .post_processor import (VisualGLMBasePostProcessor, + VisualGLMVSRPostProcessor) +from .prompt_constructor import (VisualGLMBasePromptConstructor, + VisualGLMIconQAPromptConstructor, + VisualGLMMMBenchPromptConstructor, + VisualGLMScienceQAPromptConstructor, + VisualGLMVQAPromptConstructor) from .visualglm import VisualGLM -__all__ = ['VisualGLM', 'VisualGLMPostProcessor', 'VisualGLMPromptConstructor'] +__all__ = [ + 'VisualGLM', 'VisualGLMBasePostProcessor', 'VisualGLMVSRPostProcessor', + 'VisualGLMBasePromptConstructor', 'VisualGLMMMBenchPromptConstructor', + 'VisualGLMVQAPromptConstructor', 'VisualGLMScienceQAPromptConstructor', + 'VisualGLMIconQAPromptConstructor' +] diff --git a/opencompass/multimodal/models/visualglm/post_processor.py b/opencompass/multimodal/models/visualglm/post_processor.py index ce048ea9..4ff3b4a8 100644 --- a/opencompass/multimodal/models/visualglm/post_processor.py +++ b/opencompass/multimodal/models/visualglm/post_processor.py @@ -3,8 +3,8 @@ from typing import Any import torch -class VisualGLMPostProcessor: - """"Post processor for VisualGLM on MMBench.""" +class VisualGLMBasePostProcessor: + """Base post processor for VisualGLM.""" def __init__(self) -> None: pass @@ -12,3 +12,20 @@ class VisualGLMPostProcessor: def __call__(self, output_token: torch.tensor, tokenizer: Any, input_len: int) -> str: return tokenizer.decode(output_token[input_len:]) + + +class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor): + """VSR post processor for VisualGLM.""" + + def __init__(self) -> None: + super().__init__() + + def __call__(self, output_token: torch.tensor, tokenizer: Any, + input_len: int) -> str: + output_text = tokenizer.decode(output_token[input_len:]) + if 'yes' in output_text.lower(): + return 'yes' + elif 'no' in output_text.lower(): + return 'no' + else: + 
return 'unknown' diff --git a/opencompass/multimodal/models/visualglm/prompt_constructor.py b/opencompass/multimodal/models/visualglm/prompt_constructor.py index 3ff50f17..ea644c85 100644 --- a/opencompass/multimodal/models/visualglm/prompt_constructor.py +++ b/opencompass/multimodal/models/visualglm/prompt_constructor.py @@ -1,8 +1,8 @@ import torch -class VisualGLMPromptConstructor: - """Prompt constructor for VisualGLM. +class VisualGLMMMBenchPromptConstructor: + """MMBench prompt constructor for VisualGLM. The overall prompt will be formulated as "system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt". @@ -30,7 +30,7 @@ class VisualGLMPromptConstructor: batch (dict): Input data containing image and data_samples. Returns: - tuple: A tuple containing prompt, images and data_samples. + A tuple containing images, prompt, data_samples and image_position. """ images = batch.pop('inputs') @@ -53,3 +53,168 @@ class VisualGLMPromptConstructor: image_position = 5 return images, prompt, data_samples, image_position + + +class VisualGLMBasePromptConstructor: + """Base prompt constructor for VisualGLM. + + The prompt will concat <img> and the given system prompt. + Args: + system_prompt (str): System prompt. (Default: '') + """ + + def __init__(self, system_prompt='') -> None: + self.prompt = system_prompt + + def __call__(self, batch: dict) -> tuple: + """Construct prompt. + + Args: + batch (dict): Input data containing image and data_samples. + + Returns: + A tuple containing images, prompt, data_samples and image_position. + """ + + images = batch.pop('inputs') + images = torch.stack(images, dim=0) + data_samples = batch.pop('data_samples') + + # generate text prompt + img_prompt = '<img></img>' + prompt = img_prompt + self.prompt + image_position = prompt.rfind('<img>') + 5 + + image_position = 5 + + return images, prompt, data_samples, image_position + + +class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor): + """VQA prompt constructor for VisualGLM. + + The prompt will concat <img>, the question and the system prompt. + Args: + system_prompt (str): System prompt. (Default: '') + """ + + def __init__(self, system_prompt='') -> None: + super().__init__(system_prompt) + + def __call__(self, batch: dict) -> tuple: + """Construct prompt. + + Args: + batch (dict): Input data containing image and data_samples. + + Returns: + A tuple containing images, prompt, data_samples and image_position. + """ + + images = batch.pop('inputs') + images = torch.stack(images, dim=0) + data_samples = batch.pop('data_samples') + questions = [sample.get('question') for sample in data_samples] + + # generate text prompt + prompt = [ + '<img></img>Q:{} {}\nA:'.format(question, self.prompt) + for question in questions + ] + image_position = 5 + + return images, prompt, data_samples, image_position + + +class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor): + """ScienceQA prompt constructor for VisualGLM. + + The prompt will concat image and all terms in a question. + Args: + system_prompt (str): System prompt. (Default: '') + """ + + choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'} + + def __init__(self, system_prompt='') -> None: + super().__init__(system_prompt) + + def __call__(self, batch: dict) -> tuple: + """Construct prompt. + + Args: + batch (dict): Input data containing image and data_samples. + + Returns: + A tuple containing images, prompt, data_samples and image_position.
+ """ + + images = batch.pop('inputs') + images = torch.stack(images, dim=0) + data_samples = batch.pop('data_samples') + questions = [ + 'Q: ' + sample.get('question') + '\n' for sample in data_samples + ] + choices = [sample.get('choices') for sample in data_samples] + choices = [[ + f'({self.choice_mapping[i]}) ' + item + for i, item in enumerate(choice) + ] for choice in choices] + choices = [ + 'Choices: ' + ' '.join(choice) + '\n' for choice in choices + ] # noqa + contexts = [ + 'Context: ' + data_sample.get('hint') + '\n' + for data_sample in data_samples + ] # noqa + + # generate text prompt + prompt = [ + '<img></img>' + context + question + choice + self.prompt + for context, question, choice in zip(contexts, questions, choices) + ] + image_position = 5 + + return images, prompt, data_samples, image_position + + +class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor): + """IconQA prompt constructor for VisualGLM. + + The prompt will concat <img>, the question and the system prompt. + Args: + system_prompt (str): System prompt. (Default: '') + """ + + def __init__(self, system_prompt='') -> None: + super().__init__(system_prompt) + + def __call__(self, batch: dict) -> tuple: + """Construct prompt. + + Args: + batch (dict): Input data containing image and data_samples. + + Returns: + A tuple containing images, prompt, data_samples and image_position. + """ + + images = batch.pop('inputs') + images = torch.stack(images, dim=0) + data_samples = batch.pop('data_samples') + questions = [ + 'Q: ' + sample.get('question') + '\n' for sample in data_samples + ] + choices = [sample.get('choices') for sample in data_samples] + choices = [ + 'Options: ' + ', '.join(choice) + '.\n' for choice in choices + ] # noqa + + # generate text prompt + prompt = [ + '<img></img>' + question + choice + self.prompt + for question, choice in zip(questions, choices) + ] + image_position = 5 + + return images, prompt, data_samples, image_position diff --git a/opencompass/multimodal/models/visualglm/visualglm.py b/opencompass/multimodal/models/visualglm/visualglm.py index e5b103bc..1bb99853 100644 --- a/opencompass/multimodal/models/visualglm/visualglm.py +++ b/opencompass/multimodal/models/visualglm/visualglm.py @@ -18,13 +18,17 @@ class VisualGLM(nn.Module): pretrained_path (str): Path to visualGLM checkpoint or repo id. prompt_constructor (dict): The config of prompt constructor. post_processor (dict): The config of post processor. + is_caption_task (bool): Whether the task is caption task. + Defaults to False. gen_kwargs (dict): Customize generate function arguments. + Defaults to None. """ def __init__(self, pretrained_path: str, prompt_constructor: dict, post_processor: dict, + is_caption_task: bool = False, gen_kwargs: Optional[dict] = None) -> None: super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path, @@ -40,6 +44,7 @@ class VisualGLM(nn.Module): self.gen_kwargs = gen_kwargs else: self.gen_kwargs = dict() + self.is_caption_task = is_caption_task def encode_by_tokenizer(self, multi_prompts, image_position): input_ids = [] @@ -89,8 +94,12 @@ class VisualGLM(nn.Module): # format output outputs = outputs.tolist() for i, sample in enumerate(data_sample): - data_sample[i].pred_answer = self.post_processor( - outputs[i], self.tokenizer, input_all.shape[1]) + answer = self.post_processor(outputs[i], self.tokenizer, + input_all.shape[1]) + if self.is_caption_task: + data_sample[i].pred_caption = answer + else: + data_sample[i].pred_answer = answer return data_sample
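
Both model families in this patch are refactored around the same pattern: a dataset-agnostic base prompt constructor owns the shared wrapping (the image placeholder plus the system/reply prompt), and each dataset-specific subclass only overrides how the question text is assembled (_build_prompt for LLaVA, __call__ for VisualGLM). The standalone sketch below illustrates that pattern outside of OpenCompass; the Toy* class names and the single-sample interface are simplified stand-ins, while the '<img></img>Q:... A:' template and the lettered-choice formatting are taken from the VisualGLM constructors above.

# Standalone illustration of the prompt-constructor hierarchy introduced in this
# patch. These are not the OpenCompass classes: the real constructors consume a
# batch dict and also return images, data_samples and image_position.


class ToyBasePromptConstructor:
    """Shared wrapping: image placeholder + question text + system prompt."""

    def __init__(self, system_prompt: str = '') -> None:
        self.system_prompt = system_prompt

    def build_question(self, sample: dict) -> str:
        # Dataset-specific subclasses override only this step.
        return sample.get('question', '')

    def __call__(self, sample: dict) -> str:
        # '<img></img>' marks where VisualGLM splices in the image tokens;
        # image_position = 5 in the patch points just past '<img>'.
        return '<img></img>Q:{} {}\nA:'.format(
            self.build_question(sample), self.system_prompt)


class ToyScienceQAPromptConstructor(ToyBasePromptConstructor):
    """Adds the hint and lettered choices, like the ScienceQA constructors."""

    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def build_question(self, sample: dict) -> str:
        choices = ' '.join(f'({self.choice_mapping[i]}) {c}'
                           for i, c in enumerate(sample['choices']))
        return 'Context: {}\n{}\nChoices: {}'.format(
            sample.get('hint', ''), sample['question'], choices)


if __name__ == '__main__':
    sample = {
        'question': 'Which property do these two objects have in common?',
        'choices': ['hard', 'soft'],
        'hint': 'Select the better answer.',
    }
    print(ToyBasePromptConstructor('Answer briefly.')(sample))
    print(ToyScienceQAPromptConstructor('Answer with the letter only.')(sample))

With this split, supporting a new benchmark only needs a new question-assembly override plus a config file, which is exactly what the VQA, ScienceQA and IconQA variants above add; VSR reuses the VQA constructor and only swaps in a VSR post processor.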
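
It may also help to see how the prompt_constructor=dict(type=...) and post_processor=dict(type=...) entries in the configs above are consumed: the refactored LLaVA.__init__ hands those dicts to mmengine.registry.build_from_cfg, which instantiates the class given in type (VisualGLM accepts the same config dicts). The snippet below reproduces that wiring with a toy registry and a toy post processor; TOY_MM_MODELS and ToyPostProcessor are stand-ins for illustration, not OpenCompass's MM_MODELS registry or its shipped classes, while the stop-string stripping mirrors LLaVABasePostProcessor from this patch.

from mmengine.registry import Registry, build_from_cfg

# Stand-in for opencompass.registry.MM_MODELS, used only in this sketch.
TOY_MM_MODELS = Registry('toy_mm_models')


@TOY_MM_MODELS.register_module()
class ToyPostProcessor:
    """Same stop-string stripping as LLaVABasePostProcessor above."""

    def __call__(self, outputs: str, stop_str: str) -> str:
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        return outputs.strip()


# A config dict in the same shape as the model configs in this patch; the
# configs above pass the class object itself, which build_from_cfg also accepts.
cfg = dict(type='ToyPostProcessor')
post_processor = build_from_cfg(cfg, TOY_MM_MODELS)
print(post_processor('The answer is a cat.###', '###'))  # -> The answer is a cat.

Keeping the post processor behind a config like this is what lets one model class serve caption, VQA and yes/no tasks: the config swaps the processor (for example LLaVAVSRPostProcessor) without touching the model code.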