[Feature] Add open source dataset eval config of instruct-blip (#370)

* add configs

* refactor model

* add post processor and prompt constructor
Yixiao Fang 2023-09-08 15:07:09 +08:00 committed by GitHub
parent 49c467458f
commit fada77a31c
14 changed files with 700 additions and 6 deletions

@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipCOCOCaptionPromptConstructor,
InstructBlipCOCOCaptionPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(384, 384),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]
dataset = dict(type='mmpretrain.COCOCaption',
data_root='data/coco',
data_prefix=dict(img_path='images'),
ann_file='annotations/coco_karpathy_val.json',
pipeline=val_pipeline)
instruct_blip_coco_caption_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
instruct_blip_coco_caption_model = dict(
type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipCOCOCaptionPromptConstructor),
post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
img_size=384,
is_caption_task=True,
)
# evaluation settings
instruct_blip_coco_caption_evaluator = [
dict(
type='mmpretrain.COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
) # noqa
]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'
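Each per-dataset file like the one above only defines variables (`*_dataloader`, `*_model`, `*_evaluator` and `instruct_blip_load_from`); a separate top-level task config is expected to collect them and hand them to the multimodal runner. A minimal sketch of such a wrapper is given below, using mmengine's `read_base()` import pattern; the relative module path and the `models`/`datasets`/`evaluators`/`load_froms` aggregation names are assumptions for illustration, not part of this commit.

from mmengine.config import read_base

with read_base():
    # hypothetical path: adjust to wherever this wrapper config lives
    from .instructblip_coco_caption import (
        instruct_blip_coco_caption_dataloader,
        instruct_blip_coco_caption_evaluator,
        instruct_blip_coco_caption_model, instruct_blip_load_from)

models = [instruct_blip_coco_caption_model]
datasets = [instruct_blip_coco_caption_dataloader]
evaluators = [instruct_blip_coco_caption_evaluator]
load_froms = [instruct_blip_load_from]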

@@ -0,0 +1,54 @@
from opencompass.multimodal.models.instructblip import (
    InstructBlipCOCOCaptionPromptConstructor,
InstructBlipCOCOCaptionPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(384, 384),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]
dataset = dict(type='mmpretrain.Flickr30kCaption',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='val',
pipeline=val_pipeline)
instruct_blip_flickr30k_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
instruct_blip_flickr30k_model = dict(
type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipCOCOCaptionPromptConstructor),
post_processor=dict(type=InstructBlipCOCOCaptionPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
img_size=384,
is_caption_task=True,
)
# evaluation settings
instruct_blip_flickr30k_evaluator = [
dict(
type='mmpretrain.COCOCaption',
ann_file='data/flickr30k/annotations/flickr30k_val_gt.json',
) # noqa
]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,52 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipVQAPromptConstructor,
InstructBlipVQAPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/testdev_balanced_questions.json',
pipeline=val_pipeline)
instruct_blip_gqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
instruct_blip_gqa_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
post_processor=dict(type=InstructBlipVQAPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
max_output_txt_len=10,
)
# evaluation settings
instruct_blip_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipVQAPromptConstructor,
InstructBlipVQAPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.OCRVQA',
data_root='data/ocrvqa',
ann_file='annotations/dataset.json',
split='test',
data_prefix='images',
pipeline=val_pipeline)
instruct_blip_ocr_vqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
instruct_blip_ocr_vqa_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
post_processor=dict(type=InstructBlipVQAPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
)
# evaluation settings
instruct_blip_ocr_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,54 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipVQAPromptConstructor,
InstructBlipVQAPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.COCOVQA',
data_root='data/okvqa',
question_file='annotations/OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/mscoco_val2014_annotations.json',
pipeline=val_pipeline,
data_prefix='images/val2014',
)
instruct_blip_ok_vqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
instruct_blip_ok_vqa_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
post_processor=dict(type=InstructBlipVQAPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
max_output_txt_len=10,
)
# evaluation settings
instruct_blip_ok_vqa_evaluator = [dict(type='mmpretrain.VQAAcc')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipScienceQAPromptConstructor,
InstructBlipScienceQAPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
])
]
dataset = dict(type='mmpretrain.ScienceQA',
data_root='./data/scienceqa',
split='val',
split_file='pid_splits.json',
ann_file='problems.json',
image_only=True,
data_prefix=dict(img_path='val'),
pipeline=val_pipeline)
instruct_blip_scienceqa_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
instruct_blip_scienceqa_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipScienceQAPromptConstructor),
post_processor=dict(type=InstructBlipScienceQAPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
max_output_txt_len=10,
)
# evaluation settings
instruct_blip_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipVQAPromptConstructor,
InstructBlipVQAPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.TextVQA',
data_root='data/textvqa',
ann_file='annotations/TextVQA_0.5.1_val.json',
pipeline=val_pipeline,
data_prefix='images/train_images',
)
instruct_blip_textvqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
instruct_blip_textvqa_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
post_processor=dict(type=InstructBlipVQAPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
max_output_txt_len=10,
)
# evaluation settings
instruct_blip_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipVQAPromptConstructor,
InstructBlipVQAPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.VizWiz',
data_root='data/vizwiz/',
data_prefix='Images/val',
ann_file='Annotations/val.json',
pipeline=val_pipeline)
instruct_blip_vizwiz_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
instruct_blip_vizwiz_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
post_processor=dict(type=InstructBlipVQAPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
max_output_txt_len=10,
)
# evaluation settings
instruct_blip_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,53 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipVQAPromptConstructor,
InstructBlipVQAPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.COCOVQA',
data_root='data/coco',
data_prefix='images/val2014',
question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/v2_mscoco_val2014_annotations.json',
pipeline=val_pipeline)
instruct_blip_vqav2_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
instruct_blip_vqav2_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipVQAPromptConstructor),
post_processor=dict(type=InstructBlipVQAPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
max_output_txt_len=10,
)
# evaluation settings
instruct_blip_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -0,0 +1,51 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipVSRPromptConstructor,
InstructBlipVSRPostProcessor,
)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.VSR',
data_root='data/vsr/',
data_prefix='images/',
ann_file='annotations/test.json',
pipeline=val_pipeline)
instruct_blip_vsr_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
instruct_blip_vsr_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipVSRPromptConstructor),
post_processor=dict(type=InstructBlipVSRPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
max_output_txt_len=10,
)
# evaluation settings
instruct_blip_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'

@@ -1,8 +1,25 @@
from .blip2_vicuna_instruct import InstructBlipInferencer
from .post_processor import (InstructBlipCOCOCaptionPostProcessor,
InstructBlipMMBenchPostProcessor,
InstructBlipScienceQAPostProcessor,
InstructBlipVQAPostProcessor,
InstructBlipVSRPostProcessor)
from .prompt_constructor import (InstructBlipCOCOCaptionPromptConstructor,
InstructBlipMMBenchPromptConstructor,
InstructBlipScienceQAPromptConstructor,
InstructBlipVQAPromptConstructor,
InstructBlipVSRPromptConstructor)
__all__ = [
'InstructBlipInferencer',
'InstructBlipMMBenchPromptConstructor',
'InstructBlipMMBenchPostProcessor',
    'InstructBlipCOCOCaptionPromptConstructor',
'InstructBlipCOCOCaptionPostProcessor',
'InstructBlipVQAPromptConstructor',
'InstructBlipVQAPostProcessor',
'InstructBlipScienceQAPromptConstructor',
'InstructBlipScienceQAPostProcessor',
'InstructBlipVSRPromptConstructor',
'InstructBlipVSRPostProcessor',
]

@@ -34,6 +34,7 @@ class InstructBlipInferencer(Blip2Base):
qformer_text_input: bool = True,
low_resource: bool = False,
mode: str = 'generation',
        is_caption_task: bool = False,
):
super().__init__()
self.mode = mode
@@ -96,6 +97,7 @@ class InstructBlipInferencer(Blip2Base):
self.max_output_txt_len = max_output_txt_len
self.sys_prompt = sys_prompt
self.prompt = prompt
self.is_caption_task = is_caption_task
self._lemmatizer = None
@@ -228,7 +230,7 @@
top_p=top_p,
temperature=temperature,
num_beams=num_beams,
max_length=self.max_output_txt_len,
min_length=min_length,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
@@ -238,6 +240,9 @@
for i, data_sample in enumerate(data_samples):
output_token = outputs[i]
output_text = self.post_processor(output_token, self.llm_tokenizer)
if self.is_caption_task:
data_sample.pred_caption = output_text
else:
data_sample.pred_answer = output_text
data_samples[i] = data_sample
return data_samples
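The new `is_caption_task` flag only decides which field of the output `DataSample` receives the decoded text: the caption configs above (COCO Caption, Flickr30k) store it as `pred_caption`, the field their `mmpretrain.COCOCaption` evaluator consumes, while the VQA-style configs keep `pred_answer` for the accuracy metrics. A minimal standalone sketch of that routing, using mmpretrain's `DataSample` (the structure the dataloaders above already emit); the helper name is hypothetical.

from mmpretrain.structures import DataSample


def store_prediction(data_sample: DataSample, output_text: str,
                     is_caption_task: bool = False) -> DataSample:
    # Mirror of the branch added above: pick the field the evaluator expects.
    if is_caption_task:
        data_sample.pred_caption = output_text  # read by caption metrics
    else:
        data_sample.pred_answer = output_text  # read by VQA-style metrics
    return data_sample


sample = store_prediction(DataSample(), 'a cat sitting on a laptop', True)
print(sample.pred_caption)  # -> 'a cat sitting on a laptop'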

@@ -1,3 +1,4 @@
import random
import re
import torch
@@ -29,3 +30,82 @@ class InstructBlipMMBenchPostProcessor:
if len(res) > 0:
output_text = res[0][:-1]
return output_text
class InstructBlipCOCOCaptionPostProcessor:
""""Post processor for InstructBlip on COCO Caption."""
def __init__(self) -> None:
pass
    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
output_token[output_token == 0] = 2
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
return output_text
class InstructBlipVQAPostProcessor:
""""Post processor for InstructBlip on VQA."""
def __init__(self) -> None:
pass
    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
output_token[output_token == 0] = 2
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
return output_text
class InstructBlipScienceQAPostProcessor:
""""Post processor for InstructBlip on ScienceQA."""
def __init__(self) -> None:
pass
    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
output_token[output_token == 0] = 2
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
pattern = re.compile(r'\(([A-Z])\)')
output_text = pattern.findall(output_text)
if len(output_text) == 0:
output_text = random.choice(['A', 'B', 'C', 'D'])
else:
output_text = output_text[0]
return output_text
class InstructBlipVSRPostProcessor:
""""Post processor for InstructBlip on VSR."""
def __init__(self) -> None:
pass
    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
output_token[output_token == 0] = 2
output_text = tokenizer.decode(output_token, add_special_tokens=False)
pattern = r'yes|no|Yes|No'
output_text = re.findall(pattern, output_text)
if len(output_text) > 0:
output_text = output_text[0].lower()
return output_text
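All four new post processors share one contract: they take the generated token ids plus the Vicuna tokenizer and return a cleaned string. A quick standalone sketch of calling one outside the inferencer, assuming opencompass with this change is importable; the stand-in tokenizer and its decoded text are invented for illustration (the real pipeline passes `self.llm_tokenizer`).

import torch

from opencompass.multimodal.models.instructblip import InstructBlipVQAPostProcessor


class DummyTokenizer:
    # Stand-in for the Vicuna LlamaTokenizer used in the real pipeline.
    def decode(self, token_ids, add_special_tokens=False):
        return 'Assistant: two</s>'  # pretend these ids decode to this text


tokens = torch.tensor([319, 0, 1023])  # the 0 id is remapped to 2 (</s>) first
answer = InstructBlipVQAPostProcessor()(tokens, DummyTokenizer())
print(answer)  # -> 'two'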

@@ -53,3 +53,70 @@ class InstructBlipMMBenchPromptConstructor:
else:
prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
return prompt
class InstructBlipCOCOCaptionPromptConstructor(
InstructBlipMMBenchPromptConstructor):
"""Prompt constructor for InstructBlip on COCO Caption."""
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
prompt = self.image_prompt + ' ' + 'a photo of' + self.reply_prompt
return prompt
class InstructBlipVQAPromptConstructor(InstructBlipMMBenchPromptConstructor):
"""Prompt constructor for InstructBlip on VQA."""
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + 'Answer this question in a single word.' + ' ' + self.reply_prompt # noqa
return prompt
class InstructBlipScienceQAPromptConstructor(
InstructBlipMMBenchPromptConstructor):
"""Prompt constructor for InstructBlip on ScienceQA."""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
'Question: ' + data_sample.get('question') + '\n'
for data_sample in data_samples
] # noqa
choices = [data_sample.get('choices') for data_sample in data_samples]
choices = [[
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choice)
] for choice in choices]
choices = [
'Choices: ' + ' '.join(choice) + '\n' for choice in choices
] # noqa
contexts = [
'Context: ' + data_sample.get('hint') + '\n'
for data_sample in data_samples
] # noqa
question = questions[0]
choice = choices[0]
context = contexts[0]
prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + choice + self.reply_prompt + ' ' + 'The answer is' # noqa
return prompt
class InstructBlipVSRPromptConstructor(InstructBlipMMBenchPromptConstructor):
"""Prompt constructor for InstructBlip on VSR."""
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt # noqa
return prompt
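For reference, the prompt each constructor builds is easy to trace by hand. Taking `InstructBlipVQAPromptConstructor._process` above, and assuming `image_prompt` and `reply_prompt` are configured as `'###Human: <Img><ImageHere></Img>'` and `'###Assistant:'` (illustrative values only; the real ones come from the model config that instantiates the constructor), one VQA sample becomes:

image_prompt = '###Human: <Img><ImageHere></Img>'  # assumed config value
reply_prompt = '###Assistant:'  # assumed config value
question = 'What color is the bus?'

# Same concatenation as InstructBlipVQAPromptConstructor._process
prompt = (image_prompt + ' ' + question + ' ' +
          'Answer this question in a single word.' + ' ' + reply_prompt)
print(prompt)
# ###Human: <Img><ImageHere></Img> What color is the bus? Answer this question in a single word. ###Assistant: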