Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Add open source dataset eval config of instruct-blip (#370)
* add configs
* refactor model
* add post processor and prompt constructor
This commit is contained in:
parent 49c467458f
commit fada77a31c
configs/multimodal/instructblip/instructblip_coco_caption.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipCOCOCaotionPromptConstructor,
    InstructBlipCOCOCaptionPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(384, 384),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]

dataset = dict(type='mmpretrain.COCOCaption',
               data_root='data/coco',
               data_prefix=dict(img_path='images'),
               ann_file='annotations/coco_karpathy_val.json',
               pipeline=val_pipeline)

instruct_blip_coco_caption_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
instruct_blip_coco_caption_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipCOCOCaotionPromptConstructor),
    post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    img_size=384,
    is_caption_task=True,
)

# evaluation settings
instruct_blip_coco_caption_evaluator = [
    dict(
        type='mmpretrain.COCOCaption',
        ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
    )  # noqa
]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
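A minimal sketch of how one of these configs can be loaded and inspected before launching an evaluation, assuming OpenCompass (and therefore mmengine and mmpretrain) is installed and the placeholder paths above point at real checkpoints; the printed values simply echo the config contents:

from mmengine.config import Config

# Loading executes the config file, so opencompass.multimodal must be importable.
cfg = Config.fromfile(
    'configs/multimodal/instructblip/instructblip_coco_caption.py')

# Each task config exposes three top-level objects: a dataloader,
# a model dict and an evaluator list.
print(cfg.instruct_blip_coco_caption_dataloader['batch_size'])  # -> 1
print(cfg.instruct_blip_coco_caption_model['type'])             # -> blip2-vicuna-instruct
print(cfg.instruct_blip_coco_caption_evaluator[0]['type'])      # -> mmpretrain.COCOCaption

The remaining configs below follow the same pattern and differ only in dataset type, image size, prompt constructor, post processor and evaluator.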
configs/multimodal/instructblip/instructblip_flickr30k.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipCOCOCaotionPromptConstructor,
    InstructBlipCOCOCaptionPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(384, 384),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]

dataset = dict(type='mmpretrain.Flickr30kCaption',
               data_root='data/flickr30k',
               ann_file='annotations/dataset_flickr30k.json',
               data_prefix='images',
               split='val',
               pipeline=val_pipeline)

instruct_blip_flickr30k_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
instruct_blip_flickr30k_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipCOCOCaotionPromptConstructor),
    post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    img_size=384,
    is_caption_task=True,
)

# evaluation settings
instruct_blip_flickr30k_evaluator = [
    dict(
        type='mmpretrain.COCOCaption',
        ann_file='data/flickr30k/annotations/flickr30k_val_gt.json',
    )  # noqa
]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_gqa.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.GQA',
               data_root='data/gqa',
               data_prefix='images',
               ann_file='annotations/testdev_balanced_questions.json',
               pipeline=val_pipeline)

instruct_blip_gqa_dataloader = dict(batch_size=1,
                                    num_workers=4,
                                    dataset=dataset,
                                    collate_fn=dict(type='pseudo_collate'),
                                    sampler=dict(type='DefaultSampler',
                                                 shuffle=False))

# model settings
instruct_blip_gqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
# evaluation settings
instruct_blip_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_ocr_vqa.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.OCRVQA',
               data_root='data/ocrvqa',
               ann_file='annotations/dataset.json',
               split='test',
               data_prefix='images',
               pipeline=val_pipeline)

instruct_blip_ocr_vqa_dataloader = dict(batch_size=1,
                                        num_workers=4,
                                        dataset=dataset,
                                        collate_fn=dict(type='pseudo_collate'),
                                        sampler=dict(type='DefaultSampler',
                                                     shuffle=False))

# model settings
instruct_blip_ocr_vqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
)

# evaluation settings
instruct_blip_ocr_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_ok_vqa.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.COCOVQA',
    data_root='data/okvqa',
    question_file='annotations/OpenEnded_mscoco_val2014_questions.json',
    ann_file='annotations/mscoco_val2014_annotations.json',
    pipeline=val_pipeline,
    data_prefix='images/val2014',
)

instruct_blip_ok_vqa_dataloader = dict(batch_size=1,
                                       num_workers=4,
                                       dataset=dataset,
                                       collate_fn=dict(type='pseudo_collate'),
                                       sampler=dict(type='DefaultSampler',
                                                    shuffle=False))

# model settings
instruct_blip_ok_vqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_ok_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_scienceqa.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipScienceQAPromptConstructor,
    InstructBlipScienceQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
         ])
]

dataset = dict(type='mmpretrain.ScienceQA',
               data_root='./data/scienceqa',
               split='val',
               split_file='pid_splits.json',
               ann_file='problems.json',
               image_only=True,
               data_prefix=dict(img_path='val'),
               pipeline=val_pipeline)

instruct_blip_scienceqa_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
instruct_blip_scienceqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipScienceQAPromptConstructor),
    post_processor=dict(type=InstructBlipScienceQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_textvqa.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.TextVQA',
    data_root='data/textvqa',
    ann_file='annotations/TextVQA_0.5.1_val.json',
    pipeline=val_pipeline,
    data_prefix='images/train_images',
)

instruct_blip_textvqa_dataloader = dict(batch_size=1,
                                        num_workers=4,
                                        dataset=dataset,
                                        collate_fn=dict(type='pseudo_collate'),
                                        sampler=dict(type='DefaultSampler',
                                                     shuffle=False))

# model settings
instruct_blip_textvqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_vizwiz.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.VizWiz',
               data_root='data/vizwiz/',
               data_prefix='Images/val',
               ann_file='Annotations/val.json',
               pipeline=val_pipeline)

instruct_blip_vizwiz_dataloader = dict(batch_size=1,
                                       num_workers=4,
                                       dataset=dataset,
                                       collate_fn=dict(type='pseudo_collate'),
                                       sampler=dict(type='DefaultSampler',
                                                    shuffle=False))

# model settings
instruct_blip_vizwiz_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_vqav2.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.COCOVQA',
    data_root='data/coco',
    data_prefix='images/val2014',
    question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
    ann_file='annotations/v2_mscoco_val2014_annotations.json',
    pipeline=val_pipeline)

instruct_blip_vqav2_dataloader = dict(batch_size=1,
                                      num_workers=4,
                                      dataset=dataset,
                                      collate_fn=dict(type='pseudo_collate'),
                                      sampler=dict(type='DefaultSampler',
                                                   shuffle=False))

# model settings
instruct_blip_vqav2_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_vsr.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVSRPromptConstructor,
    InstructBlipVSRPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.VSR',
               data_root='data/vsr/',
               data_prefix='images/',
               ann_file='annotations/test.json',
               pipeline=val_pipeline)

instruct_blip_vsr_dataloader = dict(batch_size=1,
                                    num_workers=4,
                                    dataset=dataset,
                                    collate_fn=dict(type='pseudo_collate'),
                                    sampler=dict(type='DefaultSampler',
                                                 shuffle=False))

# model settings
instruct_blip_vsr_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVSRPromptConstructor),
    post_processor=dict(type=InstructBlipVSRPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
opencompass/multimodal/models/instructblip/__init__.py
@@ -1,8 +1,25 @@
 from .blip2_vicuna_instruct import InstructBlipInferencer
-from .post_processor import InstructBlipMMBenchPostProcessor
-from .prompt_constructor import InstructBlipMMBenchPromptConstructor
+from .post_processor import (InstructBlipCOCOCaptionPostProcessor,
+                             InstructBlipMMBenchPostProcessor,
+                             InstructBlipScienceQAPostProcessor,
+                             InstructBlipVQAPostProcessor,
+                             InstructBlipVSRPostProcessor)
+from .prompt_constructor import (InstructBlipCOCOCaotionPromptConstructor,
+                                 InstructBlipMMBenchPromptConstructor,
+                                 InstructBlipScienceQAPromptConstructor,
+                                 InstructBlipVQAPromptConstructor,
+                                 InstructBlipVSRPromptConstructor)

 __all__ = [
-    'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
-    'InstructBlipMMBenchPostProcessor'
+    'InstructBlipInferencer',
+    'InstructBlipMMBenchPromptConstructor',
+    'InstructBlipMMBenchPostProcessor',
+    'InstructBlipCOCOCaotionPromptConstructor',
+    'InstructBlipCOCOCaptionPostProcessor',
+    'InstructBlipVQAPromptConstructor',
+    'InstructBlipVQAPostProcessor',
+    'InstructBlipScienceQAPromptConstructor',
+    'InstructBlipScienceQAPostProcessor',
+    'InstructBlipVSRPromptConstructor',
+    'InstructBlipVSRPostProcessor',
 ]
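With the expanded __all__, the new constructors and post processors are importable directly from the package, which is what the configs above rely on. A quick sanity check, assuming OpenCompass with this commit applied is installed:

# Illustrative check only; the names come from the __init__.py change above.
import opencompass.multimodal.models.instructblip as instructblip

print('InstructBlipVSRPostProcessor' in instructblip.__all__)  # -> True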
opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
@@ -34,6 +34,7 @@ class InstructBlipInferencer(Blip2Base):
         qformer_text_input: bool = True,
         low_resource: bool = False,
         mode: str = 'generation',
+        is_caption_task=False,
     ):
         super().__init__()
         self.mode = mode
@@ -96,6 +97,7 @@ class InstructBlipInferencer(Blip2Base):
         self.max_output_txt_len = max_output_txt_len
         self.sys_prompt = sys_prompt
         self.prompt = prompt
+        self.is_caption_task = is_caption_task

         self._lemmatizer = None

@@ -228,7 +230,7 @@ class InstructBlipInferencer(Blip2Base):
             top_p=top_p,
             temperature=temperature,
             num_beams=num_beams,
-            max_length=max_length,
+            max_length=self.max_output_txt_len,
             min_length=min_length,
             repetition_penalty=repetition_penalty,
             length_penalty=length_penalty,
@@ -238,6 +240,9 @@ class InstructBlipInferencer(Blip2Base):
         for i, data_sample in enumerate(data_samples):
             output_token = outputs[i]
             output_text = self.post_processor(output_token, self.llm_tokenizer)
-            data_sample.pred_answer = output_text
+            if self.is_caption_task:
+                data_sample.pred_caption = output_text
+            else:
+                data_sample.pred_answer = output_text
             data_samples[i] = data_sample
         return data_samples
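The effect of the new is_caption_task flag is easiest to see in isolation. The following standalone sketch uses a mock object rather than the real DataSample, but mirrors the branch added above: caption tasks store the generated text under pred_caption (the field the caption evaluators above are expected to read), everything else under pred_answer.

from types import SimpleNamespace

def attach_prediction(data_sample, output_text, is_caption_task=False):
    # Same branching as added to InstructBlipInferencer above.
    if is_caption_task:
        data_sample.pred_caption = output_text
    else:
        data_sample.pred_answer = output_text
    return data_sample

sample = attach_prediction(SimpleNamespace(), 'a dog on a skateboard',
                           is_caption_task=True)
print(sample.pred_caption)  # -> a dog on a skateboard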
opencompass/multimodal/models/instructblip/post_processor.py
@@ -1,3 +1,4 @@
+import random
 import re

 import torch
@@ -29,3 +30,82 @@ class InstructBlipMMBenchPostProcessor:
         if len(res) > 0:
             output_text = res[0][:-1]
         return output_text
+
+
+class InstructBlipCOCOCaptionPostProcessor:
+    """Post processor for InstructBlip on COCO Caption."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        return output_text
+
+
+class InstructBlipVQAPostProcessor:
+    """Post processor for InstructBlip on VQA."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        return output_text
+
+
+class InstructBlipScienceQAPostProcessor:
+    """Post processor for InstructBlip on ScienceQA."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        pattern = re.compile(r'\(([A-Z])\)')
+        output_text = pattern.findall(output_text)
+        if len(output_text) == 0:
+            output_text = random.choice(['A', 'B', 'C', 'D'])
+        else:
+            output_text = output_text[0]
+        return output_text
+
+
+class InstructBlipVSRPostProcessor:
+    """Post processor for InstructBlip on VSR."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token, add_special_tokens=False)
+        pattern = r'yes|no|Yes|No'
+        output_text = re.findall(pattern, output_text)
+        if len(output_text) > 0:
+            output_text = output_text[0].lower()
+        return output_text
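The tokenizer-independent part of these post processors is plain string handling, so it can be exercised without a model. The decoded strings below are invented purely for illustration; only the cleanup logic matches the code above.

import re

# Caption/VQA-style cleanup: keep the text between 'Assistant:' and '###'.
decoded = '</Img> Assistant: a man riding a horse ### Human: what else?'
text = decoded.split('###')[0]
text = text.split('Assistant:')[-1].strip()
text = text.strip('</s><s>').strip('</Img>').strip()
print(text)  # -> a man riding a horse

# VSR cleanup: collapse free-form output to a yes/no label.
vsr_raw = 'Yes, the cat is on the table.'
match = re.findall(r'yes|no|Yes|No', vsr_raw)
print(match[0].lower() if match else '')  # -> yes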
opencompass/multimodal/models/instructblip/prompt_constructor.py
@@ -53,3 +53,70 @@ class InstructBlipMMBenchPromptConstructor:
         else:
             prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
         return prompt
+
+
+class InstructBlipCOCOCaotionPromptConstructor(
+        InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on COCO Caption."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        prompt = self.image_prompt + ' ' + 'a photo of' + self.reply_prompt
+        return prompt
+
+
+class InstructBlipVQAPromptConstructor(InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on VQA."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + 'Answer this question in a single word.' + ' ' + self.reply_prompt  # noqa
+        return prompt
+
+
+class InstructBlipScienceQAPromptConstructor(
+        InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on ScienceQA."""
+
+    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            'Question: ' + data_sample.get('question') + '\n'
+            for data_sample in data_samples
+        ]  # noqa
+        choices = [data_sample.get('choices') for data_sample in data_samples]
+        choices = [[
+            f'({self.choice_mapping[i]}) ' + item
+            for i, item in enumerate(choice)
+        ] for choice in choices]
+        choices = [
+            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
+        ]  # noqa
+        contexts = [
+            'Context: ' + data_sample.get('hint') + '\n'
+            for data_sample in data_samples
+        ]  # noqa
+        question = questions[0]
+        choice = choices[0]
+        context = contexts[0]
+        prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + choice + self.reply_prompt + ' ' + 'The answer is'  # noqa
+        return prompt
+
+
+class InstructBlipVSRPromptConstructor(InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on VSR."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt  # noqa
+        return prompt
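The constructors above all reduce to simple string templates around the base class's image and reply prompts. A standalone sketch follows; IMAGE_PROMPT and REPLY_PROMPT are placeholders, not the actual strings defined on InstructBlipMMBenchPromptConstructor:

# Placeholder values, for illustration only.
IMAGE_PROMPT = '<Image>'
REPLY_PROMPT = 'Answer:'

def vqa_prompt(question: str) -> str:
    # Mirrors InstructBlipVQAPromptConstructor._process for a single sample.
    return (IMAGE_PROMPT + ' ' + question + ' ' +
            'Answer this question in a single word.' + ' ' + REPLY_PROMPT)

def vsr_prompt(statement: str) -> str:
    # Mirrors InstructBlipVSRPromptConstructor._process for a single sample.
    return (IMAGE_PROMPT + ' ' + statement + ' ' +
            'Is the above description correct? Answer yes or no.' + ' ' +
            REPLY_PROMPT)

print(vqa_prompt('What color is the bus?'))
# -> <Image> What color is the bus? Answer this question in a single word. Answer: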