diff --git a/configs/multimodal/instructblip/instructblip_coco_caption.py b/configs/multimodal/instructblip/instructblip_coco_caption.py
new file mode 100644
index 00000000..54ec3d2b
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_coco_caption.py
@@ -0,0 +1,54 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipCOCOCaptionPromptConstructor,
+ InstructBlipCOCOCaptionPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
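+ # 384x384 bicubic resize (interpolation=3 is PIL's BICUBIC) followed by CLIP-style normalization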
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(384, 384),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
+]
+
+dataset = dict(type='mmpretrain.COCOCaption',
+ data_root='data/coco',
+ data_prefix=dict(img_path='images'),
+ ann_file='annotations/coco_karpathy_val.json',
+ pipeline=val_pipeline)
+
+instruct_blip_coco_caption_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+instruct_blip_coco_caption_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipCOCOCaptionPromptConstructor),
+ post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ img_size=384,
+ is_caption_task=True,
+)
+
+# evaluation settings
+instruct_blip_coco_caption_evaluator = [
+ dict(
+ type='mmpretrain.COCOCaption',
+ ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+ ) # noqa
+]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_flickr30k.py b/configs/multimodal/instructblip/instructblip_flickr30k.py
new file mode 100644
index 00000000..76e0f6f3
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_flickr30k.py
@@ -0,0 +1,54 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipCOCOCaptionPromptConstructor,
+ InstructBlipCOCOCaptionPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(384, 384),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
+]
+
+dataset = dict(type='mmpretrain.Flickr30kCaption',
+ data_root='data/flickr30k',
+ ann_file='annotations/dataset_flickr30k.json',
+ data_prefix='images',
+ split='val',
+ pipeline=val_pipeline)
+
+instruct_blip_flickr30k_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+instruct_blip_flickr30k_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipCOCOCaptionPromptConstructor),
+ post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ img_size=384,
+ is_caption_task=True,
+)
+
+# evaluation settings
+instruct_blip_flickr30k_evaluator = [
+ dict(
+ type='mmpretrain.COCOCaption',
+ ann_file='data/flickr30k/annotations/flickr30k_val_gt.json',
+ ) # noqa
+]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_gqa.py b/configs/multimodal/instructblip/instructblip_gqa.py
new file mode 100644
index 00000000..beb1e626
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_gqa.py
@@ -0,0 +1,51 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipVQAPromptConstructor,
+ InstructBlipVQAPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(
+ type='mmpretrain.PackInputs',
+ algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+ meta_keys=['question_id', 'image_id'],
+ )
+]
+
+dataset = dict(type='mmpretrain.GQA',
+ data_root='data/gqa',
+ data_prefix='images',
+ ann_file='annotations/testdev_balanced_questions.json',
+ pipeline=val_pipeline)
+
+instruct_blip_gqa_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+instruct_blip_gqa_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
+ post_processor=dict(type=InstructBlipVQAPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ max_output_txt_len=10,
+)
+
+# evaluation settings
+instruct_blip_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_ocr_vqa.py b/configs/multimodal/instructblip/instructblip_ocr_vqa.py
new file mode 100644
index 00000000..3c46266c
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_ocr_vqa.py
@@ -0,0 +1,51 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipVQAPromptConstructor,
+ InstructBlipVQAPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(
+ type='mmpretrain.PackInputs',
+ algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+ meta_keys=['question_id', 'image_id'],
+ )
+]
+
+dataset = dict(type='mmpretrain.OCRVQA',
+ data_root='data/ocrvqa',
+ ann_file='annotations/dataset.json',
+ split='test',
+ data_prefix='images',
+ pipeline=val_pipeline)
+
+instruct_blip_ocr_vqa_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+instruct_blip_ocr_vqa_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
+ post_processor=dict(type=InstructBlipVQAPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+)
+
+# evaluation settings
+instruct_blip_ocr_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_ok_vqa.py b/configs/multimodal/instructblip/instructblip_ok_vqa.py
new file mode 100644
index 00000000..7d45e265
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_ok_vqa.py
@@ -0,0 +1,54 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipVQAPromptConstructor,
+ InstructBlipVQAPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(
+ type='mmpretrain.PackInputs',
+ algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+ meta_keys=['question_id', 'image_id'],
+ )
+]
+
+dataset = dict(
+ type='mmpretrain.COCOVQA',
+ data_root='data/okvqa',
+ question_file='annotations/OpenEnded_mscoco_val2014_questions.json',
+ ann_file='annotations/mscoco_val2014_annotations.json',
+ pipeline=val_pipeline,
+ data_prefix='images/val2014',
+)
+
+instruct_blip_ok_vqa_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+instruct_blip_ok_vqa_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
+ post_processor=dict(type=InstructBlipVQAPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ max_output_txt_len=10,
+)
+
+# evaluation settings
+instruct_blip_ok_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_scienceqa.py b/configs/multimodal/instructblip/instructblip_scienceqa.py
new file mode 100644
index 00000000..66302597
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_scienceqa.py
@@ -0,0 +1,53 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipScienceQAPromptConstructor,
+ InstructBlipScienceQAPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs',
+ algorithm_keys=[
+ 'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
+ ])
+]
+
+dataset = dict(type='mmpretrain.ScienceQA',
+ data_root='./data/scienceqa',
+ split='val',
+ split_file='pid_splits.json',
+ ann_file='problems.json',
+ image_only=True,
+ data_prefix=dict(img_path='val'),
+ pipeline=val_pipeline)
+
+instruct_blip_scienceqa_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+instruct_blip_scienceqa_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipScienceQAPromptConstructor),
+ post_processor=dict(type=InstructBlipScienceQAPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ max_output_txt_len=10,
+)
+
+# evaluation settings
+instruct_blip_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_textvqa.py b/configs/multimodal/instructblip/instructblip_textvqa.py
new file mode 100644
index 00000000..6b59aaec
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_textvqa.py
@@ -0,0 +1,53 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipVQAPromptConstructor,
+ InstructBlipVQAPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(
+ type='mmpretrain.PackInputs',
+ algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+ meta_keys=['question_id', 'image_id'],
+ )
+]
+
+dataset = dict(
+ type='mmpretrain.TextVQA',
+ data_root='data/textvqa',
+ ann_file='annotations/TextVQA_0.5.1_val.json',
+ pipeline=val_pipeline,
+ data_prefix='images/train_images',
+)
+
+instruct_blip_textvqa_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+instruct_blip_textvqa_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
+ post_processor=dict(type=InstructBlipVQAPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ max_output_txt_len=10,
+)
+
+# evaluation settings
+instruct_blip_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_vizwiz.py b/configs/multimodal/instructblip/instructblip_vizwiz.py
new file mode 100644
index 00000000..00ca79f8
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_vizwiz.py
@@ -0,0 +1,51 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipVQAPromptConstructor,
+ InstructBlipVQAPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(
+ type='mmpretrain.PackInputs',
+ algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+ meta_keys=['question_id', 'image_id'],
+ )
+]
+
+dataset = dict(type='mmpretrain.VizWiz',
+ data_root='data/vizwiz/',
+ data_prefix='Images/val',
+ ann_file='Annotations/val.json',
+ pipeline=val_pipeline)
+
+instruct_blip_vizwiz_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+instruct_blip_vizwiz_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
+ post_processor=dict(type=InstructBlipVQAPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ max_output_txt_len=10,
+)
+
+# evaluation settings
+instruct_blip_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_vqav2.py b/configs/multimodal/instructblip/instructblip_vqav2.py
new file mode 100644
index 00000000..0dbc56a3
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_vqav2.py
@@ -0,0 +1,53 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipVQAPromptConstructor,
+ InstructBlipVQAPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(
+ type='mmpretrain.PackInputs',
+ algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+ meta_keys=['question_id', 'image_id'],
+ )
+]
+
+dataset = dict(
+ type='mmpretrain.COCOVQA',
+ data_root='data/coco',
+ data_prefix='images/val2014',
+ question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
+ ann_file='annotations/v2_mscoco_val2014_annotations.json',
+ pipeline=val_pipeline)
+
+instruct_blip_vqav2_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+instruct_blip_vqav2_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
+ post_processor=dict(type=InstructBlipVQAPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ max_output_txt_len=10,
+)
+
+# evaluation settings
+instruct_blip_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/configs/multimodal/instructblip/instructblip_vsr.py b/configs/multimodal/instructblip/instructblip_vsr.py
new file mode 100644
index 00000000..083527a2
--- /dev/null
+++ b/configs/multimodal/instructblip/instructblip_vsr.py
@@ -0,0 +1,51 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipVSRPromptConstructor,
+ InstructBlipVSRPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(
+ type='mmpretrain.PackInputs',
+ algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+ meta_keys=['question_id', 'image_id'],
+ )
+]
+
+dataset = dict(type='mmpretrain.VSR',
+ data_root='data/vsr/',
+ data_prefix='images/',
+ ann_file='annotations/test.json',
+ pipeline=val_pipeline)
+
+instruct_blip_vsr_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+instruct_blip_vsr_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipVSRPromptConstructor),
+ post_processor=dict(type=InstructBlipVSRPostProcessor),
+ freeze_vit=True,
+ low_resource=False,
+ llm_model='/path/to/vicuna-7b/',
+ max_output_txt_len=10,
+)
+
+# evaluation settings
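+# VSR answers are plain yes/no strings, so the generic GQA accuracy metric is reused for scoring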
+instruct_blip_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]
+
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
diff --git a/opencompass/multimodal/models/instructblip/__init__.py b/opencompass/multimodal/models/instructblip/__init__.py
index af926280..6505ec42 100644
--- a/opencompass/multimodal/models/instructblip/__init__.py
+++ b/opencompass/multimodal/models/instructblip/__init__.py
@@ -1,8 +1,25 @@
from .blip2_vicuna_instruct import InstructBlipInferencer
-from .post_processor import InstructBlipMMBenchPostProcessor
-from .prompt_constructor import InstructBlipMMBenchPromptConstructor
+from .post_processor import (InstructBlipCOCOCaptionPostProcessor,
+ InstructBlipMMBenchPostProcessor,
+ InstructBlipScienceQAPostProcessor,
+ InstructBlipVQAPostProcessor,
+ InstructBlipVSRPostProcessor)
+from .prompt_constructor import (InstructBlipCOCOCaptionPromptConstructor,
+ InstructBlipMMBenchPromptConstructor,
+ InstructBlipScienceQAPromptConstructor,
+ InstructBlipVQAPromptConstructor,
+ InstructBlipVSRPromptConstructor)
__all__ = [
- 'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
- 'InstructBlipMMBenchPostProcessor'
+ 'InstructBlipInferencer',
+ 'InstructBlipMMBenchPromptConstructor',
+ 'InstructBlipMMBenchPostProcessor',
+ 'InstructBlipCOCOCaptionPromptConstructor',
+ 'InstructBlipCOCOCaptionPostProcessor',
+ 'InstructBlipVQAPromptConstructor',
+ 'InstructBlipVQAPostProcessor',
+ 'InstructBlipScienceQAPromptConstructor',
+ 'InstructBlipScienceQAPostProcessor',
+ 'InstructBlipVSRPromptConstructor',
+ 'InstructBlipVSRPostProcessor',
]
diff --git a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
index bc08a31d..0b91cf24 100644
--- a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
+++ b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
@@ -34,6 +34,7 @@ class InstructBlipInferencer(Blip2Base):
qformer_text_input: bool = True,
low_resource: bool = False,
mode: str = 'generation',
+ is_caption_task: bool = False,
):
super().__init__()
self.mode = mode
@@ -96,6 +97,7 @@ class InstructBlipInferencer(Blip2Base):
self.max_output_txt_len = max_output_txt_len
self.sys_prompt = sys_prompt
self.prompt = prompt
+ self.is_caption_task = is_caption_task
self._lemmatizer = None
@@ -228,7 +230,7 @@ class InstructBlipInferencer(Blip2Base):
top_p=top_p,
temperature=temperature,
num_beams=num_beams,
- max_length=max_length,
+ max_length=self.max_output_txt_len,
min_length=min_length,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
@@ -238,6 +240,10 @@
for i, data_sample in enumerate(data_samples):
output_token = outputs[i]
output_text = self.post_processor(output_token, self.llm_tokenizer)
- data_sample.pred_answer = output_text
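+ # caption tasks store the decoded text in pred_caption; other tasks keep using pred_answer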
+ if self.is_caption_task:
+ data_sample.pred_caption = output_text
+ else:
+ data_sample.pred_answer = output_text
data_samples[i] = data_sample
return data_samples
diff --git a/opencompass/multimodal/models/instructblip/post_processor.py b/opencompass/multimodal/models/instructblip/post_processor.py
index 0b124a6f..b67949f7 100644
--- a/opencompass/multimodal/models/instructblip/post_processor.py
+++ b/opencompass/multimodal/models/instructblip/post_processor.py
@@ -1,3 +1,4 @@
+import random
import re
import torch
@@ -29,3 +30,85 @@
if len(res) > 0:
output_text = res[0][:-1]
return output_text
+
+
+class InstructBlipCOCOCaptionPostProcessor:
+ """"Post processor for InstructBlip on COCO Caption."""
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
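+ # convert output id 0 to 2 (the eos token id) before decoding, as in the reference InstructBLIP code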
+ output_token[output_token == 0] = 2
+ output_text = tokenizer.decode(output_token,
+ add_special_tokens=False) # noqa
+ output_text = output_text.split('###')[0]
+ output_text = output_text.split('Assistant:')[-1].strip()
+ output_text = output_text.strip('</s><s>')
+ output_text = output_text.strip('</s>')
+ output_text = output_text.strip()
+ return output_text
+
+
+class InstructBlipVQAPostProcessor:
+ """"Post processor for InstructBlip on VQA."""
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+ output_token[output_token == 0] = 2
+ output_text = tokenizer.decode(output_token,
+ add_special_tokens=False) # noqa
+ output_text = output_text.split('###')[0]
+ output_text = output_text.split('Assistant:')[-1].strip()
+ output_text = output_text.strip('</s><s>')
+ output_text = output_text.strip('</s>')
+ output_text = output_text.strip()
+ return output_text
+
+
+class InstructBlipScienceQAPostProcessor:
+ """"Post processor for InstructBlip on ScienceQA."""
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+ output_token[output_token == 0] = 2
+ output_text = tokenizer.decode(output_token,
+ add_special_tokens=False) # noqa
+ output_text = output_text.split('###')[0]
+ output_text = output_text.split('Assistant:')[-1].strip()
+ output_text = output_text.strip('</s><s>')
+ output_text = output_text.strip('</s>')
+ output_text = output_text.strip()
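+ # pull the option letter out of patterns like "(A)"; fall back to a random letter if none is found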
+ pattern = re.compile(r'\(([A-Z])\)')
+ output_text = pattern.findall(output_text)
+ if len(output_text) == 0:
+ output_text = random.choice(['A', 'B', 'C', 'D'])
+ else:
+ output_text = output_text[0]
+ return output_text
+
+
+class InstructBlipVSRPostProcessor:
+ """"Post processor for InstructBlip on VSR."""
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+ output_token[output_token == 0] = 2
+ output_text = tokenizer.decode(output_token, add_special_tokens=False)
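+ # reduce the free-form output to a lowercase yes/no answer for accuracy scoring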
+ pattern = r'yes|no|Yes|No'
+ output_text = re.findall(pattern, output_text)
+ if len(output_text) > 0:
+ output_text = output_text[0].lower()
+ return output_text
diff --git a/opencompass/multimodal/models/instructblip/prompt_constructor.py b/opencompass/multimodal/models/instructblip/prompt_constructor.py
index f617e929..818b7a93 100644
--- a/opencompass/multimodal/models/instructblip/prompt_constructor.py
+++ b/opencompass/multimodal/models/instructblip/prompt_constructor.py
@@ -53,3 +53,70 @@ class InstructBlipMMBenchPromptConstructor:
else:
prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
return prompt
+
+
+class InstructBlipCOCOCaptionPromptConstructor(
+ InstructBlipMMBenchPromptConstructor):
+ """Prompt constructor for InstructBlip on COCO Caption."""
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ prompt = self.image_prompt + ' ' + 'a photo of' + self.reply_prompt
+ return prompt
+
+
+class InstructBlipVQAPromptConstructor(InstructBlipMMBenchPromptConstructor):
+ """Prompt constructor for InstructBlip on VQA."""
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ questions = [
+ data_sample.get('question') for data_sample in data_samples
+ ]
+ question = questions[0]
+ prompt = self.image_prompt + ' ' + question + ' ' + 'Answer this question in a single word.' + ' ' + self.reply_prompt # noqa
+ return prompt
+
+
+class InstructBlipScienceQAPromptConstructor(
+ InstructBlipMMBenchPromptConstructor):
+ """Prompt constructor for InstructBlip on ScienceQA."""
+
+ choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ questions = [
+ 'Question: ' + data_sample.get('question') + '\n'
+ for data_sample in data_samples
+ ] # noqa
+ choices = [data_sample.get('choices') for data_sample in data_samples]
+ choices = [[
+ f'({self.choice_mapping[i]}) ' + item
+ for i, item in enumerate(choice)
+ ] for choice in choices]
+ choices = [
+ 'Choices: ' + ' '.join(choice) + '\n' for choice in choices
+ ] # noqa
+ contexts = [
+ 'Context: ' + data_sample.get('hint') + '\n'
+ for data_sample in data_samples
+ ] # noqa
+ question = questions[0]
+ choice = choices[0]
+ context = contexts[0]
+ prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + choice + self.reply_prompt + ' ' + 'The answer is' # noqa
+ return prompt
+
+
+class InstructBlipVSRPromptConstructor(InstructBlipMMBenchPromptConstructor):
+ """Prompt constructor for InstructBlip on VSR."""
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ questions = [
+ data_sample.get('question') for data_sample in data_samples
+ ]
+ question = questions[0]
+ prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt # noqa
+ return prompt