Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

[Feature] Add open source dataset eval config of instruct-blip (#370)

* add configs
* refactor model
* add post processor and prompt constructor

commit fada77a31c
parent 49c467458f
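A quick sketch of how one of the new config files can be inspected (illustration only, not part of the commit). It assumes mmengine and OpenCompass with its multimodal dependencies are installed, since the config imports the prompt constructor and post processor classes at the top; the printed values follow directly from the config shown below.

from mmengine.config import Config

# Loading the config only builds Python dicts; no data or weights are read.
cfg = Config.fromfile(
    'configs/multimodal/instructblip/instructblip_coco_caption.py')

# Each config defines four task-prefixed entries: a dataloader, a model,
# an evaluator list, and a checkpoint path for the trimmed InstructBlip weights.
print(cfg.instruct_blip_coco_caption_dataloader['batch_size'])   # 1
print(cfg.instruct_blip_coco_caption_model['type'])              # 'blip2-vicuna-instruct'
print(cfg.instruct_blip_coco_caption_evaluator[0]['type'])       # 'mmpretrain.COCOCaption'
print(cfg.instruct_blip_load_from)                               # '/path/to/instruct_blip_vicuna7b_trimmed.pth'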
configs/multimodal/instructblip/instructblip_coco_caption.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipCOCOCaotionPromptConstructor,
    InstructBlipCOCOCaptionPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(384, 384),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]

dataset = dict(type='mmpretrain.COCOCaption',
               data_root='data/coco',
               data_prefix=dict(img_path='images'),
               ann_file='annotations/coco_karpathy_val.json',
               pipeline=val_pipeline)

instruct_blip_coco_caption_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
instruct_blip_coco_caption_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipCOCOCaotionPromptConstructor),
    post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    img_size=384,
    is_caption_task=True,
)

# evaluation settings
instruct_blip_coco_caption_evaluator = [
    dict(
        type='mmpretrain.COCOCaption',
        ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
    )  # noqa
]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_flickr30k.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipCOCOCaotionPromptConstructor,
    InstructBlipCOCOCaptionPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(384, 384),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]

dataset = dict(type='mmpretrain.Flickr30kCaption',
               data_root='data/flickr30k',
               ann_file='annotations/dataset_flickr30k.json',
               data_prefix='images',
               split='val',
               pipeline=val_pipeline)

instruct_blip_flickr30k_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
instruct_blip_flickr30k_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipCOCOCaotionPromptConstructor),
    post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    img_size=384,
    is_caption_task=True,
)

# evaluation settings
instruct_blip_flickr30k_evaluator = [
    dict(
        type='mmpretrain.COCOCaption',
        ann_file='data/flickr30k/annotations/flickr30k_val_gt.json',
    )  # noqa
]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_gqa.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.GQA',
               data_root='data/gqa',
               data_prefix='images',
               ann_file='annotations/testdev_balanced_questions.json',
               pipeline=val_pipeline)

instruct_blip_gqa_dataloader = dict(batch_size=1,
                                    num_workers=4,
                                    dataset=dataset,
                                    collate_fn=dict(type='pseudo_collate'),
                                    sampler=dict(type='DefaultSampler',
                                                 shuffle=False))

# model settings
instruct_blip_gqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
# evaluation settings
instruct_blip_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_ocr_vqa.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.OCRVQA',
               data_root='data/ocrvqa',
               ann_file='annotations/dataset.json',
               split='test',
               data_prefix='images',
               pipeline=val_pipeline)

instruct_blip_ocr_vqa_dataloader = dict(batch_size=1,
                                        num_workers=4,
                                        dataset=dataset,
                                        collate_fn=dict(type='pseudo_collate'),
                                        sampler=dict(type='DefaultSampler',
                                                     shuffle=False))

# model settings
instruct_blip_ocr_vqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
)

# evaluation settings
instruct_blip_ocr_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_ok_vqa.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.COCOVQA',
    data_root='data/okvqa',
    question_file='annotations/OpenEnded_mscoco_val2014_questions.json',
    ann_file='annotations/mscoco_val2014_annotations.json',
    pipeline=val_pipeline,
    data_prefix='images/val2014',
)

instruct_blip_ok_vqa_dataloader = dict(batch_size=1,
                                       num_workers=4,
                                       dataset=dataset,
                                       collate_fn=dict(type='pseudo_collate'),
                                       sampler=dict(type='DefaultSampler',
                                                    shuffle=False))

# model settings
instruct_blip_ok_vqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_ok_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_scienceqa.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipScienceQAPromptConstructor,
    InstructBlipScienceQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
         ])
]

dataset = dict(type='mmpretrain.ScienceQA',
               data_root='./data/scienceqa',
               split='val',
               split_file='pid_splits.json',
               ann_file='problems.json',
               image_only=True,
               data_prefix=dict(img_path='val'),
               pipeline=val_pipeline)

instruct_blip_scienceqa_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
instruct_blip_scienceqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipScienceQAPromptConstructor),
    post_processor=dict(type=InstructBlipScienceQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_textvqa.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.TextVQA',
    data_root='data/textvqa',
    ann_file='annotations/TextVQA_0.5.1_val.json',
    pipeline=val_pipeline,
    data_prefix='images/train_images',
)

instruct_blip_textvqa_dataloader = dict(batch_size=1,
                                        num_workers=4,
                                        dataset=dataset,
                                        collate_fn=dict(type='pseudo_collate'),
                                        sampler=dict(type='DefaultSampler',
                                                     shuffle=False))

# model settings
instruct_blip_textvqa_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_vizwiz.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.VizWiz',
               data_root='data/vizwiz/',
               data_prefix='Images/val',
               ann_file='Annotations/val.json',
               pipeline=val_pipeline)

instruct_blip_vizwiz_dataloader = dict(batch_size=1,
                                       num_workers=4,
                                       dataset=dataset,
                                       collate_fn=dict(type='pseudo_collate'),
                                       sampler=dict(type='DefaultSampler',
                                                    shuffle=False))

# model settings
instruct_blip_vizwiz_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_vqav2.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPromptConstructor,
    InstructBlipVQAPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.COCOVQA',
    data_root='data/coco',
    data_prefix='images/val2014',
    question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
    ann_file='annotations/v2_mscoco_val2014_annotations.json',
    pipeline=val_pipeline)

instruct_blip_vqav2_dataloader = dict(batch_size=1,
                                      num_workers=4,
                                      dataset=dataset,
                                      collate_fn=dict(type='pseudo_collate'),
                                      sampler=dict(type='DefaultSampler',
                                                   shuffle=False))

# model settings
instruct_blip_vqav2_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
    post_processor=dict(type=InstructBlipVQAPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
configs/multimodal/instructblip/instructblip_vsr.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipVSRPromptConstructor,
    InstructBlipVSRPostProcessor,
)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.VSR',
               data_root='data/vsr/',
               data_prefix='images/',
               ann_file='annotations/test.json',
               pipeline=val_pipeline)

instruct_blip_vsr_dataloader = dict(batch_size=1,
                                    num_workers=4,
                                    dataset=dataset,
                                    collate_fn=dict(type='pseudo_collate'),
                                    sampler=dict(type='DefaultSampler',
                                                 shuffle=False))

# model settings
instruct_blip_vsr_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipVSRPromptConstructor),
    post_processor=dict(type=InstructBlipVSRPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
    max_output_txt_len=10,
)

# evaluation settings
instruct_blip_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
@@ -1,8 +1,25 @@
 from .blip2_vicuna_instruct import InstructBlipInferencer
-from .post_processor import InstructBlipMMBenchPostProcessor
-from .prompt_constructor import InstructBlipMMBenchPromptConstructor
+from .post_processor import (InstructBlipCOCOCaptionPostProcessor,
+                             InstructBlipMMBenchPostProcessor,
+                             InstructBlipScienceQAPostProcessor,
+                             InstructBlipVQAPostProcessor,
+                             InstructBlipVSRPostProcessor)
+from .prompt_constructor import (InstructBlipCOCOCaotionPromptConstructor,
+                                 InstructBlipMMBenchPromptConstructor,
+                                 InstructBlipScienceQAPromptConstructor,
+                                 InstructBlipVQAPromptConstructor,
+                                 InstructBlipVSRPromptConstructor)

 __all__ = [
-    'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
-    'InstructBlipMMBenchPostProcessor'
+    'InstructBlipInferencer',
+    'InstructBlipMMBenchPromptConstructor',
+    'InstructBlipMMBenchPostProcessor',
+    'InstructBlipCOCOCaotionPromptConstructor',
+    'InstructBlipCOCOCaptionPostProcessor',
+    'InstructBlipVQAPromptConstructor',
+    'InstructBlipVQAPostProcessor',
+    'InstructBlipScienceQAPromptConstructor',
+    'InstructBlipScienceQAPostProcessor',
+    'InstructBlipVSRPromptConstructor',
+    'InstructBlipVSRPostProcessor',
 ]
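With the expanded exports above, the new helpers can be imported straight from the package namespace, exactly as the config files earlier in this commit do; a minimal illustration (not part of the commit):

# Illustrative only: import two of the newly exported helpers.
from opencompass.multimodal.models.instructblip import (
    InstructBlipVQAPostProcessor, InstructBlipVQAPromptConstructor)

post_processor = InstructBlipVQAPostProcessor()  # the post processors take no constructor arguments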
@@ -34,6 +34,7 @@ class InstructBlipInferencer(Blip2Base):
         qformer_text_input: bool = True,
         low_resource: bool = False,
         mode: str = 'generation',
+        is_caption_task=False,
     ):
         super().__init__()
         self.mode = mode
@@ -96,6 +97,7 @@ class InstructBlipInferencer(Blip2Base):
         self.max_output_txt_len = max_output_txt_len
         self.sys_prompt = sys_prompt
         self.prompt = prompt
+        self.is_caption_task = is_caption_task

         self._lemmatizer = None

@@ -228,7 +230,7 @@ class InstructBlipInferencer(Blip2Base):
             top_p=top_p,
             temperature=temperature,
             num_beams=num_beams,
-            max_length=max_length,
+            max_length=self.max_output_txt_len,
             min_length=min_length,
             repetition_penalty=repetition_penalty,
             length_penalty=length_penalty,
@@ -238,6 +240,9 @@ class InstructBlipInferencer(Blip2Base):
         for i, data_sample in enumerate(data_samples):
             output_token = outputs[i]
             output_text = self.post_processor(output_token, self.llm_tokenizer)
-            data_sample.pred_answer = output_text
+            if self.is_caption_task:
+                data_sample.pred_caption = output_text
+            else:
+                data_sample.pred_answer = output_text
             data_samples[i] = data_sample
         return data_samples
@@ -1,3 +1,4 @@
+import random
 import re

 import torch
@@ -29,3 +30,82 @@ class InstructBlipMMBenchPostProcessor:
         if len(res) > 0:
             output_text = res[0][:-1]
         return output_text
+
+
+class InstructBlipCOCOCaptionPostProcessor:
+    """"Post processor for InstructBlip on COCO Caption."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        return output_text
+
+
+class InstructBlipVQAPostProcessor:
+    """"Post processor for InstructBlip on VQA."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        return output_text
+
+
+class InstructBlipScienceQAPostProcessor:
+    """"Post processor for InstructBlip on ScienceQA."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        pattern = re.compile(r'\(([A-Z])\)')
+        output_text = pattern.findall(output_text)
+        if len(output_text) == 0:
+            output_text = random.choice(['A', 'B', 'C', 'D'])
+        else:
+            output_text = output_text[0]
+        return output_text
+
+
+class InstructBlipVSRPostProcessor:
+    """"Post processor for InstructBlip on VSR."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token, add_special_tokens=False)
+        pattern = r'yes|no|Yes|No'
+        output_text = re.findall(pattern, output_text)
+        if len(output_text) > 0:
+            output_text = output_text[0].lower()
+        return output_text
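A small, self-contained sketch (not part of the commit) of the string handling these post processors perform after tokenizer.decode; the decoded strings below are hypothetical model outputs used only for demonstration, and the empty-string fallback in vsr_yes_no is a simplification (the diff returns the raw findall result when nothing matches).

import random
import re


def vqa_style_cleanup(decoded: str) -> str:
    """Mirror of the VQA/COCO-Caption cleanup steps shown in the diff."""
    text = decoded.split('###')[0]
    text = text.split('Assistant:')[-1].strip()
    text = text.strip('</s><s>').strip('</Img>').strip()
    return text


def scienceqa_pick_choice(decoded: str) -> str:
    """Mirror of the ScienceQA step: extract '(X)' or fall back to a random letter."""
    letters = re.compile(r'\(([A-Z])\)').findall(vqa_style_cleanup(decoded))
    return letters[0] if letters else random.choice(['A', 'B', 'C', 'D'])


def vsr_yes_no(decoded: str) -> str:
    """Mirror of the VSR step: keep the first yes/no mention, lower-cased."""
    hits = re.findall(r'yes|no|Yes|No', decoded)
    return hits[0].lower() if hits else ''  # simplified fallback


print(vqa_style_cleanup('Assistant: a cat sitting on a sofa</s>'))  # 'a cat sitting on a sofa'
print(scienceqa_pick_choice('Assistant: The answer is (B).'))       # 'B'
print(vsr_yes_no('Yes, the cat is on the sofa.'))                   # 'yes'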
@@ -53,3 +53,70 @@ class InstructBlipMMBenchPromptConstructor:
         else:
             prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
         return prompt
+
+
+class InstructBlipCOCOCaotionPromptConstructor(
+        InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on COCO Caption."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        prompt = self.image_prompt + ' ' + 'a photo of' + self.reply_prompt
+        return prompt
+
+
+class InstructBlipVQAPromptConstructor(InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on VQA."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + 'Answer this question in a single word.' + ' ' + self.reply_prompt  # noqa
+        return prompt
+
+
+class InstructBlipScienceQAPromptConstructor(
+        InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on ScienceQA."""
+
+    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            'Question: ' + data_sample.get('question') + '\n'
+            for data_sample in data_samples
+        ]  # noqa
+        choices = [data_sample.get('choices') for data_sample in data_samples]
+        choices = [[
+            f'({self.choice_mapping[i]}) ' + item
+            for i, item in enumerate(choice)
+        ] for choice in choices]
+        choices = [
+            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
+        ]  # noqa
+        contexts = [
+            'Context: ' + data_sample.get('hint') + '\n'
+            for data_sample in data_samples
+        ]  # noqa
+        question = questions[0]
+        choice = choices[0]
+        context = contexts[0]
+        prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + choice + self.reply_prompt + ' ' + 'The answer is'  # noqa
+        return prompt
+
+
+class InstructBlipVSRPromptConstructor(InstructBlipMMBenchPromptConstructor):
+    """Prompt constructor for InstructBlip on VSR."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt  # noqa
+        return prompt
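The snippet below spells out, for illustration only, the strings the new VQA and VSR constructors assemble in _process; IMAGE_PROMPT and REPLY_PROMPT are placeholder values standing in for self.image_prompt and self.reply_prompt, which come from the base class and are not shown in this diff.

IMAGE_PROMPT = '<image>'   # placeholder, assumed for illustration
REPLY_PROMPT = 'Answer:'   # placeholder, assumed for illustration

question = 'What color is the bus?'

# VQA-style prompt, as assembled by InstructBlipVQAPromptConstructor._process
vqa_prompt = (IMAGE_PROMPT + ' ' + question + ' '
              + 'Answer this question in a single word.' + ' ' + REPLY_PROMPT)

# VSR-style prompt, as assembled by InstructBlipVSRPromptConstructor._process
vsr_prompt = (IMAGE_PROMPT + ' ' + 'The bus is left of the car.' + ' '
              + 'Is the above description correct? Answer yes or no.' + ' ' + REPLY_PROMPT)

print(vqa_prompt)
print(vsr_prompt)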