[Feat] Support public datasets for VisualGLM and LLaVA. (#265)

* [Feat] Add public dataset support for VisualGLM.

* [Feat] Refactor LLaVA.

* [Feat] Add public dataset support for LLaVA.

* [Fix] Add arg.
Yike Yuan 2023-08-25 15:44:32 +08:00 committed by GitHub
parent dc6e54f6f4
commit 3f601f420b
30 changed files with 1313 additions and 45 deletions
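
Editor's note: the per-dataset configs below all follow the same pattern. Each file exposes a dataloader, a model, and an evaluator dict at module level, named model_dataset_role. A minimal sketch of how such a file is typically pulled into a top-level OpenCompass multimodal config via mmengine's read_base follows; the relative import path is hypothetical, since the new file names are not shown in this commit view.

# Hypothetical top-level config snippet (itself loaded via mmengine's
# Config.fromfile); the module path .llava.llava_coco_caption is a placeholder
# for one of the config files added in this commit.
from mmengine.config import read_base

with read_base():
    from .llava.llava_coco_caption import (llava_coco_caption_dataloader,
                                           llava_coco_caption_evaluator,
                                           llava_coco_caption_model)

# Downstream task code then builds the model from llava_coco_caption_model
# (a config dict whose type is a class object), iterates over the dataloader,
# and scores predictions with the evaluator.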

View File

@ -0,0 +1,50 @@
from opencompass.multimodal.models.llava import LLaVABasePromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id']),
]
dataset = dict(type='mmpretrain.COCOCaption',
data_root='data/coco',
data_prefix=dict(img_path='images'),
ann_file='annotations/coco_karpathy_val.json',
pipeline=val_pipeline)
llava_coco_caption_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_coco_caption_model = dict(
type='llava',
model_path='/path/to/llava',
is_caption_task=True,
prompt_constructor=dict(type=LLaVABasePromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_coco_caption_evaluator = [
dict(
type='mmpretrain.COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
) # noqa
]

View File

@ -0,0 +1,52 @@
from opencompass.multimodal.models.llava import LLaVABasePromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id']),
]
dataset = dict(type='mmpretrain.Flickr30kCaption',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='val',
pipeline=val_pipeline)
llava_flickr30k_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_flickr30k_model = dict(
type='llava',
model_path='/path/to/llava',
is_caption_task=True,
prompt_constructor=dict(type=LLaVABasePromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_flickr30k_evaluator = [
dict(
type='mmpretrain.COCOCaption',
ann_file='data/flickr30k/annotations/flickr30k_val_gt.json',
) # noqa
]

View File

@ -0,0 +1,49 @@
from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/testdev_balanced_questions.json',
pipeline=val_pipeline)
llava_gqa_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_gqa_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAVQAPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')]

View File

@ -1,3 +1,5 @@
from opencompass.multimodal.models.llava import LLaVAMMBenchPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
@ -34,6 +36,8 @@ mmbench_dataloader = dict(
llava_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAMMBenchPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings

View File

@ -0,0 +1,49 @@
from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.OCRVQA',
data_root='data/ocrvqa',
ann_file='annotations/dataset.json',
split='test',
data_prefix='images',
pipeline=val_pipeline)
llava_ocrvqa_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_ocrvqa_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAVQAPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_ocrvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,51 @@
from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.COCOVQA',
data_root='data/okvqa',
question_file='annotations/OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/mscoco_val2014_annotations.json',
pipeline=val_pipeline,
data_prefix='images/val2014',
)
llava_okvqa_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_okvqa_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAVQAPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_okvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,50 @@
from opencompass.multimodal.models.llava import LLaVAScienceQAPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
])
]
dataset = dict(type='mmpretrain.ScienceQA',
data_root='./data/scienceqa',
split='val',
split_file='pid_splits.json',
ann_file='problems.json',
image_only=True,
data_prefix=dict(img_path='val'),
pipeline=val_pipeline)
llava_scienceqa_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_scienceqa_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAScienceQAPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]

View File

@ -0,0 +1,50 @@
from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.TextVQA',
data_root='data/textvqa',
ann_file='annotations/TextVQA_0.5.1_val.json',
pipeline=val_pipeline,
data_prefix='images/train_images',
)
llava_textvqa_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_textvqa_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAVQAPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,48 @@
from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.VizWiz',
data_root='data/vizwiz/',
data_prefix='Images/val',
ann_file='Annotations/val.json',
pipeline=val_pipeline)
llava_vizwiz_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_vizwiz_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAVQAPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,50 @@
from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVABasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.COCOVQA',
data_root='data/coco',
data_prefix='images/val2014',
question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/v2_mscoco_val2014_annotations.json',
pipeline=val_pipeline)
llava_vqav2_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_vqav2_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAVQAPromptConstructor),
post_processor=dict(type=LLaVABasePostProcessor)
) # noqa
# evaluation settings
llava_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,48 @@
from opencompass.multimodal.models.llava import LLaVAVQAPromptConstructor, LLaVAVSRPostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.VSR',
data_root='data/vsr/',
data_prefix='images/',
ann_file='annotations/test.json',
pipeline=val_pipeline)
llava_vsr_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_vsr_model = dict(
type='llava',
model_path='/path/to/llava',
prompt_constructor=dict(type=LLaVAVQAPromptConstructor),
post_processor=dict(type=LLaVAVSRPostProcessor)
) # noqa
# evaluation settings
llava_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]

View File

@ -0,0 +1,45 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMBasePromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]
dataset = dict(type='mmpretrain.COCOCaption',
data_root='data/coco',
data_prefix=dict(img_path='images'),
ann_file='annotations/coco_karpathy_val.json',
pipeline=val_pipeline)
visualglm_coco_caption_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_coco_caption_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
is_caption_task=True,
prompt_constructor=dict(type=VisualGLMBasePromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_coco_caption_evaluator = [
dict(
type='mmpretrain.COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
) # noqa
]

View File

@ -0,0 +1,46 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMBasePromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs', algorithm_keys=['image_id'])
]
dataset = dict(type='mmpretrain.Flickr30kCaption',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='val',
pipeline=val_pipeline)
visualglm_flickr30k_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_flickr30k_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
is_caption_task=True,
prompt_constructor=dict(type=VisualGLMBasePromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_flickr30k_evaluator = [
dict(
type='mmpretrain.COCOCaption',
ann_file='data/flickr30k/annotations/flickr30k_val_gt.json',
) # noqa
]

View File

@ -0,0 +1,42 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/testdev_balanced_questions.json',
pipeline=val_pipeline)
visualglm_gqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_gqa_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_gqa_evaluator = [dict(type='mmpretrain.GQAAcc')]

View File

@ -1,4 +1,4 @@
from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor, VisualGLMPromptConstructor)
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMMMBenchPromptConstructor)
# dataloader settings
val_pipeline = [
@ -30,8 +30,8 @@ mmbench_dataloader = dict(batch_size=1,
visualglm_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMPromptConstructor),
post_processor=dict(type=VisualGLMPostProcessor)
prompt_constructor=dict(type=VisualGLMMMBenchPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings

View File

@ -0,0 +1,43 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.OCRVQA',
data_root='data/ocrvqa',
ann_file='annotations/dataset.json',
split='test',
data_prefix='images',
pipeline=val_pipeline)
visualglm_ocrvqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_ocrvqa_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_ocrvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,45 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.COCOVQA',
data_root='data/okvqa',
question_file='annotations/OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/mscoco_val2014_annotations.json',
pipeline=val_pipeline,
data_prefix='images/val2014',
)
visualglm_okvqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_okvqa_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_okvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,44 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMScienceQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
])
]
dataset = dict(type='mmpretrain.ScienceQA',
data_root='./data/scienceqa',
split='val',
split_file='pid_splits.json',
ann_file='problems.json',
image_only=True,
data_prefix=dict(img_path='val'),
pipeline=val_pipeline)
visualglm_scienceqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_scienceqa_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMScienceQAPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]

View File

@ -0,0 +1,44 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.TextVQA',
data_root='data/textvqa',
ann_file='annotations/TextVQA_0.5.1_val.json',
pipeline=val_pipeline,
data_prefix='images/train_images',
)
visualglm_textvqa_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_textvqa_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,42 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.VizWiz',
data_root='data/vizwiz/',
data_prefix='Images/val',
ann_file='Annotations/val.json',
pipeline=val_pipeline)
visualglm_vizwiz_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_vizwiz_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,44 @@
from opencompass.multimodal.models.visualglm import (VisualGLMBasePostProcessor, VisualGLMVQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(
type='mmpretrain.COCOVQA',
data_root='data/coco',
data_prefix='images/val2014',
question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/v2_mscoco_val2014_annotations.json',
pipeline=val_pipeline)
visualglm_vqav2_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_vqav2_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
post_processor=dict(type=VisualGLMBasePostProcessor)
)
# evaluation settings
visualglm_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]

View File

@ -0,0 +1,43 @@
from opencompass.multimodal.models.visualglm import (VisualGLMVSRPostProcessor, VisualGLMVQAPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
)
]
dataset = dict(type='mmpretrain.VSR',
data_root='data/vsr/',
data_prefix='images/',
ann_file='annotations/test.json',
pipeline=val_pipeline)
visualglm_vsr_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_vsr_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
post_processor=dict(type=VisualGLMVSRPostProcessor)
)
# evaluation settings
visualglm_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]

View File

@ -1,3 +1,12 @@
from .llava import LLaVA
from .post_processor import LLaVABasePostProcessor, LLaVAVSRPostProcessor
from .prompt_constructor import (LLaVABasePromptConstructor,
LLaVAMMBenchPromptConstructor,
LLaVAScienceQAPromptConstructor,
LLaVAVQAPromptConstructor)
__all__ = ['LLaVA']
__all__ = [
'LLaVA', 'LLaVABasePromptConstructor', 'LLaVAMMBenchPromptConstructor',
'LLaVABasePostProcessor', 'LLaVAVQAPromptConstructor',
'LLaVAScienceQAPromptConstructor', 'LLaVAVSRPostProcessor'
]

View File

@ -2,6 +2,7 @@ import importlib
import os
import sys
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
@ -9,8 +10,6 @@ from transformers import StoppingCriteria
from opencompass.registry import MM_MODELS
from .prompt_constructor import LLaVAMMBenchPromptConstructor
IMAGE_TOKEN_INDEX = -200
@ -53,19 +52,27 @@ class LLaVA(nn.Module):
Args:
model_path (str): The path of llava checkpoint.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
is_caption_task (bool): Whether the task is caption task.
Defaults to False.
"""
def __init__(self, model_path: str) -> None:
def __init__(
self,
model_path: str,
prompt_constructor: dict,
post_processor: dict,
is_caption_task: bool = False,
) -> None:
super().__init__()
self.dtype = torch.float16
self.is_caption_task = is_caption_task
# load LLaVA modules
load_package()
mm_utils = importlib.import_module('llava.mm_utils')
builder = importlib.import_module('llava.model.builder')
conversation = importlib.import_module('llava.conversation')
self.SeparatorStyle = conversation.SeparatorStyle
self.conv_templates = conversation.conv_templates
# load pretrained LLaVA
# Note: When encounters with device related errors,
@ -86,13 +93,16 @@ class LLaVA(nn.Module):
conv_mode = 'multimodal'
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
False)
prompt_constructor.update({
'conv_mode': conv_mode,
'mm_use_im_start_end': mm_use_im_start_end
})
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
self.model = model
self.tokenizer = tokenizer
self.prompt_constructor = LLaVAMMBenchPromptConstructor(
conv_templates=conversation.conv_templates,
conv_mode=conv_mode,
mm_use_im_start_end=mm_use_im_start_end)
def generate(self, batch):
@ -133,11 +143,12 @@ class LLaVA(nn.Module):
)
outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
output_text = outputs.strip()
output_text = self.post_processor(outputs, stop_str)
if self.is_caption_task:
data_sample.pred_caption = output_text
else:
data_sample.pred_answer = output_text
return data_sample

View File

@ -0,0 +1,28 @@
class LLaVABasePostProcessor:
"""Base post processor for LLaVA on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, outputs: str, stop_str: str) -> str:
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
output_text = outputs.strip()
return output_text
class LLaVAVSRPostProcessor(LLaVABasePostProcessor):
"""VSR post processor for LLaVA on MMBench."""
def __init__(self) -> None:
super().__init__()
def __call__(self, outputs: str, stop_str: str) -> str:
output_text = super().__call__(outputs, stop_str)
if 'yes' in output_text.lower():
return 'yes'
elif 'no' in output_text.lower():
return 'no'
else:
return 'unknown'
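
The post processors added above are plain callables. A quick usage sketch (the strings are made-up model outputs, not real results) showing the trailing stop string being stripped and the VSR variant mapping free-form text to yes/no/unknown:

# Usage sketch for the LLaVA post processors; example strings are invented.
from opencompass.multimodal.models.llava import (LLaVABasePostProcessor,
                                                 LLaVAVSRPostProcessor)

base_pp = LLaVABasePostProcessor()
print(base_pp('A man riding a horse.###', stop_str='###'))
# prints: A man riding a horse.

vsr_pp = LLaVAVSRPostProcessor()
print(vsr_pp('Yes, the cat is on the table.###', stop_str='###'))
# prints: yes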

View File

@ -1,5 +1,4 @@
import importlib
from typing import Any
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
@ -7,23 +6,26 @@ DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'
class LLaVAMMBenchPromptConstructor:
"""Prompt constructor for LLaVA on MMBench.
class LLaVABasePromptConstructor:
"""Base prompt constructor for LLaVA.
Args:
conv_templates (Any): Conversation class to build prompt.
conv_mode (str): Version control args for different version of LLaVA.
mm_use_im_start_end (bool):
            Config arg. Whether to use image start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
def __init__(self, conv_templates: Any, conv_mode: str,
mm_use_im_start_end: bool) -> None:
self.conv_templates = conv_templates
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
conversation = importlib.import_module('llava.conversation')
self.conv_templates = conversation.conv_templates
self.conv_mode = conv_mode
self.mm_use_im_start_end = mm_use_im_start_end
conversation = importlib.import_module('llava.conversation')
self.SeparatorStyle = conversation.SeparatorStyle
self.reply_prompt = reply_prompt
def __call__(self, inputs: dict) -> tuple:
"""Construct prompt.
@ -36,13 +38,7 @@ class LLaVAMMBenchPromptConstructor:
"""
data_samples = inputs['data_samples']
assert len(data_samples) == 1
question = data_samples[0].get('question')
options = data_samples[0].get('options')
context = data_samples[0].get('context')
if context is not None:
prompt = context + ' ' + question + ' ' + options
else:
prompt = question + ' ' + options
prompt = self._build_prompt(data_samples[0])
if self.mm_use_im_start_end:
prompt = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
DEFAULT_IM_END_TOKEN + '\n' + prompt)
@ -57,3 +53,87 @@ class LLaVAMMBenchPromptConstructor:
stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2 # noqa
return output_prompt, stop_str
def _build_prompt(self, data_sample):
return self.reply_prompt
class LLaVAMMBenchPromptConstructor(LLaVABasePromptConstructor):
"""MMBench prompt constructor for LLaVA.
Args:
conv_mode (str): Version control args for different version of LLaVA.
mm_use_im_start_end (bool):
            Config arg. Whether to use image start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
super().__init__(conv_mode, mm_use_im_start_end, reply_prompt)
def _build_prompt(self, data_sample):
question = data_sample.get('question')
options = data_sample.get('options')
context = data_sample.get('context')
if context is not None:
prompt = context + ' ' + question + ' ' + options
else:
prompt = question + ' ' + options
prompt += self.reply_prompt
return prompt
class LLaVAVQAPromptConstructor(LLaVABasePromptConstructor):
"""VQA prompt constructor for LLaVA.
Args:
conv_mode (str): Version control args for different version of LLaVA.
mm_use_im_start_end (bool):
            Config arg. Whether to use image start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
super().__init__(conv_mode, mm_use_im_start_end, reply_prompt)
def _build_prompt(self, data_sample):
prompt = data_sample.get('question')
prompt += self.reply_prompt
return prompt
class LLaVAScienceQAPromptConstructor(LLaVABasePromptConstructor):
"""ScienceQA prompt constructor for LLaVA.
Args:
conv_mode (str): Version control args for different version of LLaVA.
mm_use_im_start_end (bool):
            Config arg. Whether to use image start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
super().__init__(conv_mode, mm_use_im_start_end, reply_prompt)
def _build_prompt(self, data_sample):
question = data_sample.get('question')
choices = data_sample.get('choices')
choices = [
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choices)
]
choices = 'Choices: ' + ' '.join(choices) + '\n'
context = 'Context: ' + data_sample.get('hint') + '\n'
prompt = context + question + choices + self.reply_prompt
return prompt
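
With this refactor, supporting a new dataset only requires overriding _build_prompt on LLaVABasePromptConstructor; conv_mode and mm_use_im_start_end are injected by the LLaVA wrapper when it builds the config (see llava.py above). A hypothetical sketch follows: the class and dataset names are made up, and instantiating the constructor requires the llava package to be installed, since the base class imports llava.conversation.

# Hypothetical example: a prompt constructor for a new QA-style dataset.
from opencompass.multimodal.models.llava import (LLaVABasePostProcessor,
                                                 LLaVABasePromptConstructor)


class LLaVAMyDatasetPromptConstructor(LLaVABasePromptConstructor):
    """Prompt constructor for a hypothetical dataset exposing a 'question'."""

    def _build_prompt(self, data_sample):
        # Assumes the packed data sample carries a 'question' field,
        # as in the VQA configs above.
        return data_sample.get('question') + self.reply_prompt


# Matching model config; conv_mode / mm_use_im_start_end are filled in by
# the LLaVA wrapper before build_from_cfg is called.
llava_mydataset_model = dict(
    type='llava',
    model_path='/path/to/llava',
    prompt_constructor=dict(type=LLaVAMyDatasetPromptConstructor),
    post_processor=dict(type=LLaVABasePostProcessor),
)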

View File

@ -1,5 +1,15 @@
from .post_processor import VisualGLMPostProcessor
from .prompt_constructor import VisualGLMPromptConstructor
from .post_processor import (VisualGLMBasePostProcessor,
VisualGLMVSRPostProcessor)
from .prompt_constructor import (VisualGLMBasePromptConstructor,
VisualGLMIconQAPromptConstructor,
VisualGLMMMBenchPromptConstructor,
VisualGLMScienceQAPromptConstructor,
VisualGLMVQAPromptConstructor)
from .visualglm import VisualGLM
__all__ = ['VisualGLM', 'VisualGLMPostProcessor', 'VisualGLMPromptConstructor']
__all__ = [
'VisualGLM', 'VisualGLMBasePostProcessor', 'VisualGLMVSRPostProcessor',
'VisualGLMBasePromptConstructor', 'VisualGLMMMBenchPromptConstructor',
'VisualGLMVQAPromptConstructor', 'VisualGLMScienceQAPromptConstructor',
'VisualGLMIconQAPromptConstructor'
]

View File

@ -3,8 +3,8 @@ from typing import Any
import torch
class VisualGLMPostProcessor:
""""Post processor for VisualGLM on MMBench."""
class VisualGLMBasePostProcessor:
"""Base post processor for VisualGLM."""
def __init__(self) -> None:
pass
@ -12,3 +12,20 @@ class VisualGLMPostProcessor:
def __call__(self, output_token: torch.tensor, tokenizer: Any,
input_len: int) -> str:
return tokenizer.decode(output_token[input_len:])
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
"""VSR post processor for VisualGLM."""
def __init__(self) -> None:
super().__init__()
def __call__(self, output_token: torch.tensor, tokenizer: Any,
input_len: int) -> str:
output_text = tokenizer.decode(output_token[input_len:])
if 'yes' in output_text.lower():
return 'yes'
elif 'no' in output_text.lower():
return 'no'
else:
return 'unknown'
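
The VisualGLM post processors operate on raw token ids plus the tokenizer. A toy sketch with a stub tokenizer (a made-up stand-in for the real ChatGLM tokenizer) shows the base decoding and the VSR yes/no mapping without loading any weights:

# Toy sketch of the VisualGLM post processors; StubTokenizer is invented
# purely so that decode() can run without model weights.
import torch

from opencompass.multimodal.models.visualglm import (
    VisualGLMBasePostProcessor, VisualGLMVSRPostProcessor)


class StubTokenizer:

    vocab = {0: 'Yes', 1: ',', 2: ' it', 3: ' is', 4: '.'}

    def decode(self, token_ids):
        return ''.join(self.vocab[int(t)] for t in token_ids)


output_token = torch.tensor([7, 8, 0, 1, 2, 3, 4])  # first two ids are prompt tokens
tokenizer = StubTokenizer()

print(VisualGLMBasePostProcessor()(output_token, tokenizer, input_len=2))
# prints: Yes, it is.
print(VisualGLMVSRPostProcessor()(output_token, tokenizer, input_len=2))
# prints: yes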

View File

@ -1,8 +1,8 @@
import torch
class VisualGLMPromptConstructor:
"""Prompt constructor for VisualGLM.
class VisualGLMMMBenchPromptConstructor:
"""MMBench prompt constructor for VisualGLM.
The overall prompt will be formulated as
"system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt".
@ -30,7 +30,7 @@ class VisualGLMPromptConstructor:
batch (dict): Input data containing image and data_samples.
Returns:
tuple: A tuple containing prompt, images and data_samples.
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
@ -53,3 +53,168 @@ class VisualGLMPromptConstructor:
image_position = 5
return images, prompt, data_samples, image_position
class VisualGLMBasePromptConstructor:
"""Base prompt constructor for VisualGLM.
The prompt will concat <img> and the given system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
"""
def __init__(self, system_prompt='') -> None:
self.prompt = system_prompt
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
Args:
batch (dict): Input data containing image and data_samples.
Returns:
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
# generate text prompt
img_prompt = '<img></img>'
prompt = img_prompt + self.prompt
image_position = prompt.rfind('<img>') + 5
image_position = 5
return images, prompt, data_samples, image_position
class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
"""VQA prompt constructor for VisualGLM.
The prompt will concat <img>, the question and the system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
"""
def __init__(self, system_prompt='') -> None:
super().__init__(system_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
Args:
batch (dict): Input data containing image and data_samples.
Returns:
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
questions = [sample.get('question') for sample in data_samples]
# generate text prompt
prompt = [
'<img></img>Q:{} {}\nA:'.format(question, self.prompt)
for question in questions
]
image_position = 5
return images, prompt, data_samples, image_position
class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
"""ScienceQA prompt constructor for VisualGLM.
The prompt will concat image and all terms in a question.
Args:
system_prompt (str): System prompt. (Default: '')
"""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def __init__(self, system_prompt='') -> None:
super().__init__(system_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
Args:
batch (dict): Input data containing image and data_samples.
Returns:
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
questions = [
'Q: ' + sample.get('question') + '\n' for sample in data_samples
]
choices = [sample.get('choices') for sample in data_samples]
choices = [[
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choice)
] for choice in choices]
choices = [
'Choices: ' + ' '.join(choice) + '\n' for choice in choices
] # noqa
contexts = [
'Context: ' + data_sample.get('hint') + '\n'
for data_sample in data_samples
] # noqa
# generate text prompt
prompt = [
'<img></img>' + context + question + choice + self.prompt
for context, question, choice in zip(contexts, questions, choices)
]
image_position = 5
return images, prompt, data_samples, image_position
class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
"""IconQA prompt constructor for VisualGLM.
The prompt will concat <img>, the question and the system prompt.
Args:
system_prompt (str): System prompt. (Default: '')
"""
def __init__(self, system_prompt='') -> None:
super().__init__(system_prompt)
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
Args:
batch (dict): Input data containing image and data_samples.
Returns:
A tuple containing images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
questions = [
'Q: ' + sample.get('question') + '\n' for sample in data_samples
]
choices = [sample.get('choices') for sample in data_samples]
choices = [
'Options: ' + ', '.join(choice) + '.\n' for choice in choices
] # noqa
# generate text prompt
prompt = [
'<img></img>' + question + choice + self.prompt
for question, choice in zip(questions, choices)
]
image_position = 5
return images, prompt, data_samples, image_position
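
The new prompt constructors only need a batch dict whose 'inputs' is a list of image tensors and whose 'data_samples' expose .get(). A toy sketch using plain dicts in place of mmengine data samples (both support .get), assuming OpenCompass and its multimodal dependencies are installed:

# Toy sketch of the VQA prompt constructor; the question and image are made up.
import torch

from opencompass.multimodal.models.visualglm import VisualGLMVQAPromptConstructor

batch = {
    'inputs': [torch.rand(3, 224, 224)],
    'data_samples': [{'question': 'What color is the bus?'}],
}

constructor = VisualGLMVQAPromptConstructor(system_prompt='Answer briefly.')
images, prompt, data_samples, image_position = constructor(batch)

print(prompt[0])
# prints: <img></img>Q:What color is the bus? Answer briefly.
#         A:
print(image_position)  # 5, i.e. right after the leading '<img>' tag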

View File

@ -18,13 +18,17 @@ class VisualGLM(nn.Module):
pretrained_path (str): Path to visualGLM checkpoint or repo id.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
is_caption_task (bool): Whether the task is caption task.
Defaults to False.
gen_kwargs (dict): Customize generate function arguments.
Defaults to None.
"""
def __init__(self,
pretrained_path: str,
prompt_constructor: dict,
post_processor: dict,
is_caption_task: bool = False,
gen_kwargs: Optional[dict] = None) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
@ -40,6 +44,7 @@ class VisualGLM(nn.Module):
self.gen_kwargs = gen_kwargs
else:
self.gen_kwargs = dict()
self.is_caption_task = is_caption_task
def encode_by_tokenizer(self, multi_prompts, image_position):
input_ids = []
@ -89,8 +94,12 @@ class VisualGLM(nn.Module):
# format output
outputs = outputs.tolist()
for i, sample in enumerate(data_sample):
data_sample[i].pred_answer = self.post_processor(
outputs[i], self.tokenizer, input_all.shape[1])
answer = self.post_processor(outputs[i], self.tokenizer,
input_all.shape[1])
if self.is_caption_task:
data_sample[i].pred_caption = answer
else:
data_sample[i].pred_answer = answer
return data_sample