[Feature]: Add other public datasets (#206)
* [Feature]: Refactor class name
* [Feature]: Add minigpt-4 coco caption
* [Feature]: Update minigpt-4 coco caption
* [Feature]: Add MiniGPT-4 ScienceQA
* [Feature]: Add minigpt-4 vqav2
* [Feature]: Add VSR
* [Feature]: Revert task to previous version
parent 3a46b6c64f
commit 78df9bd0cb
configs/multimodal/minigpt_4/minigpt_4_7b_coco_caption.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from opencompass.multimodal.models.minigpt_4 import (
+    MiniGPT4COCOCaptionPromptConstructor,
+    MiniGPT4COCOCaptionPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.LoadImageFromFile'),
+    dict(type='mmpretrain.ToPIL', to_rgb=True),
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(384, 384),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
+]
+
+dataset = dict(type='mmpretrain.COCOCaption',
+               data_root='data/coco',
+               data_prefix=dict(img_path='images'),
+               ann_file='annotations/coco_karpathy_val.json',
+               pipeline=val_pipeline)
+
+minigpt_4_coco_caption_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    dataset=dataset,
+    collate_fn=dict(type='pseudo_collate'),
+    sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+minigpt_4_coco_caption_model = dict(
+    type='minigpt-4',
+    low_resource=False,
+    img_size=384,
+    llama_model='/path/to/vicuna-7b/',
+    prompt_constructor=dict(type=MiniGPT4COCOCaptionPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=dict(type=MiniGPT4COCOCaptionPostProcessor))
+
+# evaluation settings
+minigpt_4_coco_caption_evaluator = [
+    dict(
+        type='mmpretrain.COCOCaption',
+        ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+    )  # noqa
+]
+
+minigpt_4_coco_caption_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
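The Resize/ToTensor/Normalize trio above reproduces CLIP-style preprocessing: the mean/std values are CLIP's image statistics, and interpolation=3 is PIL's bicubic mode. As a rough illustration only (file name hypothetical; the real transforms are built by mmpretrain from the dict configs), the pipeline corresponds to:

# Plain-torchvision sketch of val_pipeline above, for illustration only.
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((384, 384),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])
image = preprocess(Image.open('demo.jpg').convert('RGB'))  # 3x384x384 tensor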
configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
@@ -1,5 +1,5 @@
 from opencompass.multimodal.models.minigpt_4 import (
-    MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)
+    MiniGPT4MMBenchPromptConstructor, MiniGPT4MMBenchPostProcessor)
 
 # dataloader settings
 val_pipeline = [
@@ -29,13 +29,13 @@ minigpt_4_dataloader = dict(batch_size=1,
 
 # model settings
 minigpt_4_model = dict(
-    type='minigpt-4-mmbench',
+    type='minigpt-4',
     low_resource=False,
     llama_model='/path/to/vicuna-7b/',
     prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
-    post_processor=dict(type=MiniGPT4PostProcessor))
+    post_processor=dict(type=MiniGPT4MMBenchPostProcessor))
 
 # evaluation settings
 minigpt_4_evaluator = [
configs/multimodal/minigpt_4/minigpt_4_7b_scienceqa.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from opencompass.multimodal.models import (MiniGPT4ScienceQAPromptConstructor,
+                                           MiniGPT4ScienceQAPostProcessor)
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.LoadImageFromFile'),
+    dict(type='mmpretrain.ToPIL', to_rgb=True),
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
+         ])
+]
+
+dataset = dict(type='mmpretrain.ScienceQA',
+               data_root='./data/scienceqa',
+               split='val',
+               split_file='pid_splits.json',
+               ann_file='problems.json',
+               image_only=True,
+               data_prefix=dict(img_path='val'),
+               pipeline=val_pipeline)
+
+minigpt_4_scienceqa_dataloader = dict(batch_size=1,
+                                      num_workers=4,
+                                      dataset=dataset,
+                                      collate_fn=dict(type='pseudo_collate'),
+                                      sampler=dict(type='DefaultSampler',
+                                                   shuffle=False))
+
+# model settings
+minigpt_4_scienceqa_model = dict(
+    type='minigpt-4',
+    low_resource=False,
+    img_size=224,
+    max_length=10,
+    llama_model='/path/to/vicuna-7b/',
+    prompt_constructor=dict(type=MiniGPT4ScienceQAPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=dict(type=MiniGPT4ScienceQAPostProcessor))
+
+# evaluation settings
+minigpt_4_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]
+
+minigpt_4_scienceqa_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
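Judging from the data_root, split_file, ann_file, and data_prefix fields above, the config expects a ScienceQA checkout laid out roughly as:

data/scienceqa/
├── problems.json     # ann_file: questions, choices, hints, answers
├── pid_splits.json   # split_file: problem ids for each split
└── val/              # img_path: images of the 'val' split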
configs/multimodal/minigpt_4/minigpt_4_7b_vqav2.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from opencompass.multimodal.models.minigpt_4 import (
+    MiniGPT4VQAPromptConstructor,
+    MiniGPT4VQAPostProcessor,
+)
+
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.LoadImageFromFile'),
+    dict(type='mmpretrain.ToPIL', to_rgb=True),
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(
+        type='mmpretrain.PackInputs',
+        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+        meta_keys=['question_id', 'image_id'],
+    )
+]
+
+dataset = dict(
+    type='mmpretrain.COCOVQA',
+    data_root='data/coco',
+    data_prefix='images/val2014',
+    question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
+    ann_file='annotations/v2_mscoco_val2014_annotations.json',
+    pipeline=val_pipeline)
+
+minigpt_4_vqav2_dataloader = dict(batch_size=1,
+                                  num_workers=4,
+                                  dataset=dataset,
+                                  collate_fn=dict(type='pseudo_collate'),
+                                  sampler=dict(type='DefaultSampler',
+                                               shuffle=False))
+
+# model settings
+minigpt_4_vqav2_model = dict(
+    type='minigpt-4',
+    low_resource=False,
+    img_size=224,
+    max_length=10,
+    llama_model='/path/to/vicuna-7b/',
+    prompt_constructor=dict(type=MiniGPT4VQAPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=dict(type=MiniGPT4VQAPostProcessor))
+
+# evaluation settings
+minigpt_4_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]
+
+minigpt_4_vqav2_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
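The pipeline packs gt_answer together with gt_answer_weight because VQAv2 is scored with soft accuracy over the ten human answers per question. A hedged sketch of that standard metric (the exact mmpretrain.VQAAcc implementation may additionally normalize answer strings):

def vqa_soft_accuracy(pred: str, gt_answers: list) -> float:
    """Standard VQAv2 scoring: an answer counts as fully correct once at
    least three of the ten annotators gave it."""
    matches = sum(ans == pred for ans in gt_answers)
    return min(1.0, matches / 3.0)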
configs/multimodal/minigpt_4/minigpt_4_7b_vsr.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from opencompass.multimodal.models.minigpt_4 import (
+    MiniGPT4VSRPromptConstructor,
+    MiniGPT4VSRPostProcessor,
+)
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.LoadImageFromFile'),
+    dict(type='mmpretrain.ToPIL', to_rgb=True),
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(
+        type='mmpretrain.PackInputs',
+        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+        meta_keys=['question_id', 'image_id'],
+    )
+]
+
+dataset = dict(type='mmpretrain.VSR',
+               data_root='data/vsr/',
+               data_prefix='images/',
+               ann_file='annotations/test.json',
+               pipeline=val_pipeline)
+
+minigpt_4_vsr_dataloader = dict(batch_size=1,
+                                num_workers=4,
+                                dataset=dataset,
+                                collate_fn=dict(type='pseudo_collate'),
+                                sampler=dict(type='DefaultSampler',
+                                             shuffle=False))
+
+# model settings
+minigpt_4_vsr_model = dict(
+    type='minigpt-4',
+    low_resource=True,
+    img_size=224,
+    max_length=10,
+    llama_model='/path/to/vicuna-7b/',
+    prompt_constructor=dict(type=MiniGPT4VSRPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=dict(type=MiniGPT4VSRPostProcessor))
+
+# evaluation settings
+minigpt_4_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]
+
+minigpt_4_vsr_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
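VSR is a binary task (does the stated spatial relation hold?), which is why generation is capped at max_length=10 and the GQA-style answer-accuracy metric is reused; it is also the only config here that sets low_resource=True, i.e. the low-precision loading mode described in the model docstring further below. The dataset fields above imply a layout along the lines of:

data/vsr/
├── annotations/test.json   # ann_file
└── images/                 # data_prefix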
configs/multimodal/tasks.py
@@ -12,4 +12,4 @@ evaluators = [minigpt_4_evaluator]
 load_froms = [minigpt_4_load_from]
 num_gpus = 8
 num_procs = 8
-launcher = 'pytorch'
\ No newline at end of file
+launcher = 'pytorch'
opencompass/multimodal/models/minigpt_4/__init__.py
@@ -1,8 +1,20 @@
-from .minigpt_4 import MiniGPT4MMBench
-from .post_processor import MiniGPT4PostProcessor
-from .prompt_constructor import MiniGPT4MMBenchPromptConstructor
+from .minigpt_4 import MiniGPT4Inferencer
+from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
+                             MiniGPT4MMBenchPostProcessor,
+                             MiniGPT4ScienceQAPostProcessor,
+                             MiniGPT4VQAPostProcessor,
+                             MiniGPT4VSRPostProcessor)
+from .prompt_constructor import (MiniGPT4COCOCaptionPromptConstructor,
+                                 MiniGPT4MMBenchPromptConstructor,
+                                 MiniGPT4ScienceQAPromptConstructor,
+                                 MiniGPT4VQAPromptConstructor,
+                                 MiniGPT4VSRPromptConstructor)
 
 __all__ = [
-    'MiniGPT4MMBench', 'MiniGPT4PostProcessor',
-    'MiniGPT4MMBenchPromptConstructor'
+    'MiniGPT4Inferencer', 'MiniGPT4MMBenchPostProcessor',
+    'MiniGPT4MMBenchPromptConstructor', 'MiniGPT4COCOCaptionPromptConstructor',
+    'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
+    'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
+    'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
+    'MiniGPT4VSRPromptConstructor'
 ]
opencompass/multimodal/models/minigpt_4/minigpt_4.py
@@ -37,14 +37,17 @@ def load_package():
 MiniGPT4 = load_package()
 
 
-@MM_MODELS.register_module('minigpt-4-mmbench')
-class MiniGPT4MMBench(MiniGPT4):
-    """Inference code of MiniGPT-4 on MMBench.
+@MM_MODELS.register_module('minigpt-4')
+class MiniGPT4Inferencer(MiniGPT4):
+    """Inference code of MiniGPT-4.
 
     Args:
         llama_model (str): The path of the vicuna weights.
         prompt_constructor (dict): The config of prompt constructor.
         post_processor (dict): The config of post processor.
+        do_sample (bool): Whether to use sampling. Defaults to False.
+        max_length (int): The max length of output. Defaults to 30.
+        img_size (int): The size of image. Defaults to 224.
         low_resource (bool): Whether loaded in low precision.
             Defaults to False.
     """
@@ -53,8 +56,13 @@ class MiniGPT4MMBench(MiniGPT4):
                  llama_model: str,
                  prompt_constructor: dict,
                  post_processor: dict,
+                 do_sample: bool = False,
+                 max_length: int = 30,
+                 img_size: int = 224,
                  low_resource: bool = False) -> None:
-        super().__init__(llama_model=llama_model, low_resource=low_resource)
+        super().__init__(llama_model=llama_model,
+                         low_resource=low_resource,
+                         img_size=img_size)
 
         cur_device = get_device()
         stop_words_ids = [
@@ -67,6 +75,8 @@ class MiniGPT4MMBench(MiniGPT4):
             prompt_constructor, MM_MODELS)
         self.post_processor = mmengine.registry.build_from_cfg(
             post_processor, MM_MODELS)
+        self.do_sample = do_sample
+        self.max_length = max_length
 
     def encode_img(self, image):
         device = image.device
@@ -125,9 +135,9 @@ class MiniGPT4MMBench(MiniGPT4):
         # generate output
         outputs = self.llama_model.generate(
             inputs_embeds=prompt_embs,
-            max_new_tokens=20,
+            max_length=self.max_length,
             num_beams=5,
-            do_sample=False,
+            do_sample=self.do_sample,
             min_length=1,
             top_p=0.9,
             repetition_penalty=1.0,
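With the registration name changed from 'minigpt-4-mmbench' to 'minigpt-4', every model dict in the configs above resolves to MiniGPT4Inferencer through the MM_MODELS registry, the same mechanism the class itself uses to build its prompt constructor and post processor. A minimal sketch (the opencompass.registry import path is an assumption):

import mmengine
from opencompass.registry import MM_MODELS  # assumed registry location
from opencompass.multimodal.models.minigpt_4 import (
    MiniGPT4MMBenchPromptConstructor, MiniGPT4MMBenchPostProcessor)

model_cfg = dict(
    type='minigpt-4',
    llama_model='/path/to/vicuna-7b/',  # placeholder, as in the configs
    prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
                            image_prompt='###Human: <Img><ImageHere></Img>',
                            reply_prompt='###Assistant:'),
    post_processor=dict(type=MiniGPT4MMBenchPostProcessor))
model = mmengine.registry.build_from_cfg(model_cfg, MM_MODELS)  # MiniGPT4Inferencer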
opencompass/multimodal/models/minigpt_4/post_processor.py
@@ -1,9 +1,10 @@
+import random
 import re
 
 import torch
 
 
-class MiniGPT4PostProcessor:
+class MiniGPT4MMBenchPostProcessor:
     """Post processor for MiniGPT-4 on MMBench."""
 
     def __init__(self) -> None:
@@ -32,3 +33,89 @@ class MiniGPT4PostProcessor:
         if len(res) > 0:
             output_text = res[0][:-1]
         return output_text
+
+
+class MiniGPT4COCOCaptionPostProcessor:
+    """Post processor for MiniGPT-4 on COCO Caption."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        if output_token[0] == 0:
+            output_token = output_token[1:]
+        if output_token[0] == 1:
+            output_token = output_token[1:]
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.split('. ')[0]
+        output_text = output_text.strip('<Img>')
+        output_text = output_text.strip()
+        return output_text
+
+
+class MiniGPT4ScienceQAPostProcessor:
+    """Post processor for MiniGPT-4 on ScienceQA."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        if output_token[0] == 0:
+            output_token = output_token[1:]
+        if output_token[0] == 1:
+            output_token = output_token[1:]
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        pattern = re.compile(r'\(([A-Z])\)')
+        output_text = pattern.findall(output_text)
+        if len(output_text) == 0:
+            output_text = random.choice(['A', 'B', 'C', 'D'])
+        else:
+            output_text = output_text[0]
+        return output_text
+
+
+class MiniGPT4VQAPostProcessor:
+    """Post processor for MiniGPT-4 on VQA."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        if output_token[0] == 0:
+            output_token = output_token[1:]
+        if output_token[0] == 1:
+            output_token = output_token[1:]
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        return output_text
+
+
+class MiniGPT4VSRPostProcessor:
+    """Post processor for MiniGPT-4 on VSR."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+
+        if output_token[0] == 0:
+            output_token = output_token[1:]
+        if output_token[0] == 1:
+            output_token = output_token[1:]
+        output_text = tokenizer.decode(output_token, add_special_tokens=False)
+        pattern = r'yes|no|Yes|No'
+        output_text = re.findall(pattern, output_text)
+        if len(output_text) > 0:
+            output_text = output_text[0].lower()
+        return output_text
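All five post processors share one skeleton: strip up to two leading special tokens (ids 0 and 1) and decode; most then cut the text at the '###' turn separator and the 'Assistant:' tag before the task-specific normalization. The VSR variant's yes/no extraction, for example, boils down to (the no-match fallback is shown here as an empty string for clarity; the code above would return the empty list itself):

import re

decoded = 'Yes, the cat is to the left of the dog.'
matches = re.findall(r'yes|no|Yes|No', decoded)
answer = matches[0].lower() if matches else ''  # -> 'yes'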
opencompass/multimodal/models/minigpt_4/prompt_constructor.py
@@ -53,3 +53,68 @@ class MiniGPT4MMBenchPromptConstructor:
         else:
             prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
         return prompt
+
+
+class MiniGPT4COCOCaptionPromptConstructor(MiniGPT4MMBenchPromptConstructor):
+    """Prompt constructor for MiniGPT-4 on COCO Caption."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        prompt = self.image_prompt + ' ' + 'a photo of' + self.reply_prompt
+        return prompt
+
+
+class MiniGPT4ScienceQAPromptConstructor(MiniGPT4MMBenchPromptConstructor):
+    """Prompt constructor for MiniGPT-4 on ScienceQA."""
+
+    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            'Question: ' + data_sample.get('question') + '\n'
+            for data_sample in data_samples
+        ]  # noqa
+        choices = [data_sample.get('choices') for data_sample in data_samples]
+        choices = [[
+            f'({self.choice_mapping[i]}) ' + item
+            for i, item in enumerate(choice)
+        ] for choice in choices]
+        choices = [
+            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
+        ]  # noqa
+        contexts = [
+            'Context: ' + data_sample.get('hint') + '\n'
+            for data_sample in data_samples
+        ]  # noqa
+        question = questions[0]
+        choice = choices[0]
+        context = contexts[0]
+        prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + choice + self.reply_prompt + ' ' + 'The answer is'  # noqa
+        return prompt
+
+
+class MiniGPT4VQAPromptConstructor(MiniGPT4MMBenchPromptConstructor):
+    """Prompt constructor for MiniGPT-4 on VQA."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + 'Answer this question in a single word.' + ' ' + self.reply_prompt  # noqa
+        return prompt
+
+
+class MiniGPT4VSRPromptConstructor(MiniGPT4MMBenchPromptConstructor):
+    """Prompt constructor for MiniGPT-4 on VSR."""
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt  # noqa
+        return prompt
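Since each subclass only overrides _process, the assembled prompts can be read off directly. For the VQA constructor, a single sample (question text hypothetical) yields:

image_prompt = '###Human: <Img><ImageHere></Img>'
reply_prompt = '###Assistant:'
question = 'What color is the bus?'  # hypothetical sample

prompt = (image_prompt + ' ' + question + ' ' +
          'Answer this question in a single word.' + ' ' + reply_prompt)
# '###Human: <Img><ImageHere></Img> What color is the bus?
#  Answer this question in a single word. ###Assistant:'

The ScienceQA constructor instead appends 'The answer is' after the reply prompt, which is exactly what the (A)-(F) regex in MiniGPT4ScienceQAPostProcessor then picks out of the generation.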