diff --git a/configs/multimodal/minigpt_4/minigpt_4_7b_seedbench.py b/configs/multimodal/minigpt_4/minigpt_4_7b_seedbench.py
new file mode 100644
index 00000000..051bcb73
--- /dev/null
+++ b/configs/multimodal/minigpt_4/minigpt_4_7b_seedbench.py
@@ -0,0 +1,63 @@
+from opencompass.multimodal.models.minigpt_4 import MiniGPT4SEEDBenchPromptConstructor # noqa
+
+# dataloader settings
+image_pipeline = [
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs',
+ algorithm_keys=[
+ 'question', 'answer', 'choices', 'data_type', 'question_type_id',
+ 'index', 'data_path', 'question_id'
+ ])
+]
+video_pipeline = [
+ dict(type='mmaction.Resize', scale=(224, 224), interpolation='bicubic'),
+ dict(type='mmaction.CenterCrop', crop_size=224),
+ dict(type='Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs',
+ algorithm_keys=[
+ 'question', 'answer', 'choices', 'data_type', 'question_type_id',
+ 'index', 'data_path', 'question_id'
+ ])
+]
+
+dataset = dict(
+ type='opencompass.SEEDBenchDataset',
+ ann_file='data/seedbench/SEED-Bench.json',
+ cc3m_path='data/seedbench/SEED-Bench-image',
+ sthv2_path='data/seedbench/sthv2/videos',
+ epic_kitchens_path='data/seedbench/3h91syskeag572hl6tvuovwv4d/videos/test',
+ breakfast_path='data/seedbench/BreakfastII_15fps_qvga_sync',
+ image_pipeline=image_pipeline,
+ video_pipeline=video_pipeline,
+ only_image=True)
+
+minigpt_4_seedbench_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
+
+# model settings
+minigpt_4_seedbench_model = dict(
+ type='minigpt-4',
+ low_resource=False,
+ llama_model='/path/to/vicuna/',
+ prompt_constructor=dict(type=MiniGPT4SEEDBenchPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+ reply_prompt='###Assistant:'),
+ post_processor=None,
+ mode='loss')
+
+# evaluation settings
+minigpt_4_seedbench_evaluator = [dict(type='opencompass.SEEDBenchAcc')]
+
+minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth'
diff --git a/configs/multimodal/seedbench.py b/configs/multimodal/seedbench.py
new file mode 100644
index 00000000..efd45619
--- /dev/null
+++ b/configs/multimodal/seedbench.py
@@ -0,0 +1,14 @@
+from mmengine.config import read_base
+
+with read_base():
+ from .minigpt_4.minigpt_4_7b_seedbench import (
+ minigpt_4_seedbench_dataloader, minigpt_4_seedbench_evaluator,
+ minigpt_4_load_from, minigpt_4_seedbench_model)
+
+models = [minigpt_4_seedbench_model]
+datasets = [minigpt_4_seedbench_dataloader]
+evaluators = [minigpt_4_seedbench_evaluator]
+load_froms = [minigpt_4_load_from]
+num_gpus = 1
+num_procs = 1
+launcher = 'slurm'
diff --git a/opencompass/metrics/__init__.py b/opencompass/metrics/__init__.py
index 09333edf..68c04467 100644
--- a/opencompass/metrics/__init__.py
+++ b/opencompass/metrics/__init__.py
@@ -1,3 +1,4 @@
from .dump_results import DumpResults
+from .seedbench import SEEDBenchAcc
-__all__ = ['DumpResults']
+__all__ = ['DumpResults', 'SEEDBenchAcc']
diff --git a/opencompass/metrics/seedbench.py b/opencompass/metrics/seedbench.py
new file mode 100644
index 00000000..e06de4c1
--- /dev/null
+++ b/opencompass/metrics/seedbench.py
@@ -0,0 +1,67 @@
+import torch
+from mmengine.evaluator import BaseMetric
+
+from opencompass.registry import METRICS
+
+EVAL_DIM_MAPPING = {
+ 1: 'Scene Understanding',
+ 2: 'Instance Identity',
+ 3: 'Instance Attributes',
+ 4: 'Instance Location',
+ 5: 'Instance Counting',
+ 6: 'Spatial Relations',
+ 7: 'Instance Interaction',
+ 8: 'Visual Reasoning',
+ 9: 'Text Recognition',
+ 10: 'Action Recognition',
+ 11: 'Action Prediction',
+ 12: 'Procedure Understanding',
+}
+
+
+@METRICS.register_module()
+class SEEDBenchAcc(BaseMetric):
+ """Compute results for SEED-Bench."""
+
+ def process(self, data_batch, data_samples) -> None:
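+        """Record per-sample answer predictions from the per-choice losses.
+
+        The choices are ranked by ascending loss and the lowest-loss choice
+        ('A'-'D') is stored as the prediction for later accuracy
+        computation.
+        """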
+ for data_sample in data_samples:
+ losses = data_sample['losses']
+ class_ranks = torch.argsort(losses, dim=-1).cpu()
+ pred_id = ['A', 'B', 'C', 'D'][class_ranks[0]]
+ answer_record = {
+ 'q_id': data_sample['question_id'],
+ 'prediction': pred_id,
+ 'gt': data_sample['answer'],
+ 'q_type_id': data_sample['question_type_id'],
+ 'losses': [str(num) for num in list(losses.cpu().numpy())],
+ }
+ self.results.append(answer_record)
+
+ def compute_metrics(self, results: list) -> dict:
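+        """Aggregate answer records into per-dimension and total accuracy.
+
+        Accuracies are keyed by the evaluation dimension names in
+        ``EVAL_DIM_MAPPING``; the raw records are kept under
+        ``answer_records``.
+        """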
+ type_counts = {}
+ correct_counts = {}
+ out = {}
+ out['answer_records'] = results
+ for item in results:
+ pred, gt = item['prediction'], item['gt']
+ data_type = item['q_type_id']
+
+ type_counts[data_type] = type_counts.get(data_type, 0) + 1
+ if pred == gt:
+ correct_counts[data_type] = correct_counts.get(data_type,
+ 0) + 1
+
+ total_count = 0
+ total_correct = 0
+ for data_type in type_counts.keys():
+ accuracy = correct_counts.get(data_type,
+ 0) / type_counts[data_type] * 100
+ category = EVAL_DIM_MAPPING[data_type]
+ out[f'Data type {data_type} - {category}'] = accuracy
+
+ total_count += type_counts[data_type]
+ total_correct += correct_counts.get(data_type, 0)
+
+ total_accuracy = total_correct / total_count * 100
+ out['Total accuracy'] = total_accuracy
+ return out
diff --git a/opencompass/multimodal/datasets/__init__.py b/opencompass/multimodal/datasets/__init__.py
index 47fa24ad..c9374c47 100644
--- a/opencompass/multimodal/datasets/__init__.py
+++ b/opencompass/multimodal/datasets/__init__.py
@@ -1,3 +1,4 @@
from .mmbench import MMBenchDataset
+from .seedbench import SEEDBenchDataset
-__all__ = ['MMBenchDataset']
+__all__ = ['MMBenchDataset', 'SEEDBenchDataset']
diff --git a/opencompass/multimodal/datasets/seedbench.py b/opencompass/multimodal/datasets/seedbench.py
new file mode 100644
index 00000000..1e03c9e5
--- /dev/null
+++ b/opencompass/multimodal/datasets/seedbench.py
@@ -0,0 +1,173 @@
+import json
+import os.path as osp
+from typing import List
+
+import av
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from mmengine.dataset import Compose
+from PIL import Image
+from torch.utils.data import Dataset
+
+from opencompass.registry import DATASETS
+
+
+@DATASETS.register_module()
+class SEEDBenchDataset(Dataset):
+ """Dataset to load SEED-Bench dataset.
+
+ Args:
+ ann_file (str): The path of the annotation file.
+        cc3m_path (str): The data path of the image dimensions (1-9).
+        sthv2_path (str): The data path of dimension 10
+            (action recognition).
+        epic_kitchens_path (str): The data path of dimension 11
+            (action prediction).
+        breakfast_path (str): The data path of dimension 12
+            (procedure understanding).
+ image_pipeline (List[dict]): The data transforms for image.
+ video_pipeline (List[dict]): The data transforms for video.
+        only_image (bool): Whether to run SEED-Bench with image data only.
+            Defaults to True.
+ """
+
+ def __init__(
+ self,
+ ann_file: str,
+ cc3m_path: str,
+ sthv2_path: str,
+ epic_kitchens_path: str,
+ breakfast_path: str,
+ image_pipeline: List[dict],
+ video_pipeline: List[dict],
+ only_image: bool = True,
+ ) -> None:
+ ann_file = json.load(open(ann_file, 'rb'))
+ if 'questions' in ann_file.keys():
+ self.ann_file = ann_file['questions']
+ self.cc3m_path = cc3m_path
+ self.sthv2_path = sthv2_path
+ self.epic_kitchens_path = epic_kitchens_path
+ self.breakfast_path = breakfast_path
+ self.image_pipeline = Compose(image_pipeline)
+ if only_image:
+ image_ann_file = [
+ ann for ann in self.ann_file if ann['data_type'] == 'image'
+ ]
+ self.ann_file = image_ann_file
+        if not only_image:
+            raise NotImplementedError(
+                'Video evaluation on SEED-Bench is not supported yet.')
+        self.video_pipeline = Compose(video_pipeline)
+
+    def __len__(self) -> int:
+ return len(self.ann_file)
+
+    def __getitem__(self, idx: int) -> dict:
+ item = self.ann_file[idx]
+ data = {
+ 'question':
+ item['question'],
+ 'answer':
+ item['answer'],
+ 'choices': [
+ item['choice_a'], item['choice_b'], item['choice_c'],
+ item['choice_d']
+ ],
+ 'data_type':
+ item['data_type'],
+ 'question_id':
+ item['question_id'],
+ 'question_type_id':
+ item['question_type_id'],
+ 'index':
+ idx,
+ }
+
+ if item['data_type'] == 'image':
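+            # evaluation dimensions 1-9 use images from cc3m_path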
+ data_path = osp.join(self.cc3m_path, item['data_id'])
+ raw_image = Image.open(open(data_path, 'rb')).convert('RGB')
+ data['data_path'] = data_path
+ data['img'] = raw_image
+ data = self.image_pipeline(data)
+ elif item['data_type'] == 'video':
+ if item['question_type_id'] == 10:
+ data_path = osp.join(self.sthv2_path, item['data_id'])
+ data['data_path'] = data_path
+ elif item['question_type_id'] == 11:
+ data_path = osp.join(self.epic_kitchens_path, item['data_id'])
+ data['data_path'] = data_path
+ data['segment'] = item['segment']
+ elif item['question_type_id'] == 12:
+ data_path = osp.join(self.breakfast_path, item['data_id'])
+ data['data_path'] = data_path
+ data['segment'] = item['segment']
+ else:
+ raise ValueError('The question type id is not valid.')
+
+            # preprocess videos for evaluation dimensions 10-12
+ use_pyav = False
+ if 'segment' in data.keys():
+ segment = data['segment']
+ if isinstance(segment[0], int):
+ # using pyav for decoding videos in evaluation dimension 12
+ use_pyav = True
+ start, end = segment[0], segment[1]
+ else:
+ start = 0.0
+ end = 0.0
+
+ if use_pyav:
+ # using pyav for videos in evaluation dimension 12
+ reader = av.open(data_path)
+ frames = [
+ torch.from_numpy(f.to_rgb().to_ndarray())
+ for f in reader.decode(video=0)
+ ]
+ video_len = len(frames)
+ start_frame, end_frame = start, end
+ end_frame = min(end_frame, video_len)
+ offset = self.get_index(end_frame - start_frame, 8)
+ frame_indices = offset + start_frame
+ buffer = torch.stack([frames[idx] for idx in frame_indices])
+ buffer = buffer.numpy()
+ else:
+                # using decord for videos in evaluation dimensions 10-11
+ import io
+
+ import mmengine.fileio as fileio
+ file_obj = io.BytesIO(fileio.get(data_path))
+ vr = VideoReader(file_obj, num_threads=1, ctx=cpu(0))
+ video_len = len(vr)
+ fps = vr.get_avg_fps()
+ if 'segment' in data.keys():
+ # obtain start and end frame for the video segment
+ # in evaluation dimension 11
+ start_frame = int(min(max(start * fps, 0), video_len - 1))
+ end_frame = int(min(max(end * fps, 0), video_len - 1))
+ tot_frames = int(end_frame - start_frame)
+ offset = self.get_index(tot_frames, 8)
+ frame_indices = offset + start_frame
+ else:
+ # sample frames of the video in evaluation dimension 10
+ frame_indices = self.get_index(video_len - 1, 8)
+ vr.seek(0)
+ buffer = vr.get_batch(frame_indices)
+ buffer = buffer.asnumpy()
+ data['imgs'] = buffer
+ data = self.video_pipeline(data)
+
+ else:
+ raise ValueError('The data type is not valid.')
+
+ return data
+
+ def get_index(self, num_frames, num_segments):
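+        """Uniformly sample ``num_segments`` frame indices.
+
+        If the clip has fewer frames than requested, all frame indices are
+        returned instead.
+        """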
+ if num_segments > num_frames:
+ offsets = np.array([idx for idx in range(num_frames)])
+ else:
+ # uniform sampling
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx))
+ for idx in range(num_segments)
+ ])
+ return offsets
diff --git a/opencompass/multimodal/models/minigpt_4/__init__.py b/opencompass/multimodal/models/minigpt_4/__init__.py
index 56e2cc69..715eb01e 100644
--- a/opencompass/multimodal/models/minigpt_4/__init__.py
+++ b/opencompass/multimodal/models/minigpt_4/__init__.py
@@ -7,6 +7,7 @@ from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
from .prompt_constructor import (MiniGPT4COCOCaotionPromptConstructor,
MiniGPT4MMBenchPromptConstructor,
MiniGPT4ScienceQAPromptConstructor,
+ MiniGPT4SEEDBenchPromptConstructor,
MiniGPT4VQAPromptConstructor,
MiniGPT4VSRPromptConstructor)
@@ -16,5 +17,5 @@ __all__ = [
'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
- 'MiniGPT4VSRPromptConstructor'
+ 'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor'
]
diff --git a/opencompass/multimodal/models/minigpt_4/minigpt_4.py b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
index eee0e3dc..d7c1e36c 100644
--- a/opencompass/multimodal/models/minigpt_4/minigpt_4.py
+++ b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
@@ -59,10 +59,14 @@ class MiniGPT4Inferencer(MiniGPT4):
do_sample: bool = False,
max_length: int = 30,
img_size: int = 224,
- low_resource: bool = False) -> None:
+ low_resource: bool = False,
+ mode: str = 'generation',
+ n_segments: int = 1) -> None:
super().__init__(llama_model=llama_model,
low_resource=low_resource,
img_size=img_size)
+ self.mode = mode
+ self.n_segments = n_segments
cur_device = get_device()
stop_words_ids = [
@@ -71,34 +75,73 @@ class MiniGPT4Inferencer(MiniGPT4):
]
self.stopping_criteria = StoppingCriteriaList(
[StoppingCriteriaSub(stops=stop_words_ids)])
+
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
- self.post_processor = mmengine.registry.build_from_cfg(
- post_processor, MM_MODELS)
+ if post_processor is not None:
+ self.post_processor = mmengine.registry.build_from_cfg(
+ post_processor, MM_MODELS)
self.do_sample = do_sample
self.max_length = max_length
+ def forward(self, batch):
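+        """Dispatch to ``generate`` or ``loss`` based on ``self.mode``."""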
+ if self.mode == 'generation':
+ return self.generate(batch)
+ elif self.mode == 'loss':
+ return self.loss(batch)
+ else:
+ raise RuntimeError(f'Invalid mode "{self.mode}".')
+
def encode_img(self, image):
device = image.device
with self.maybe_autocast():
- image_embeds = self.ln_vision(
- self.visual_encoder(image)).to(device)
- image_atts = torch.ones(image_embeds.size()[:-1],
- dtype=torch.long).to(device)
+ if image.dim() == 5:
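+                # 5-D video input (B, C, T, H, W): encode frame by frame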
+ inputs_llama, atts_llama = [], []
+ for j in range(image.size(2)):
+ this_frame = image[:, :, j, :, :]
+ frame_embeds = self.ln_vision(
+ self.visual_encoder(this_frame))
+ frame_atts = torch.ones(frame_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
- query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
- -1)
- query_output = self.Qformer.bert(
- query_embeds=query_tokens,
- encoder_hidden_states=image_embeds,
- encoder_attention_mask=image_atts,
- return_dict=True,
- )
+ query_tokens = self.query_tokens.expand(
+ frame_embeds.shape[0], -1, -1)
+ frame_query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=frame_embeds,
+ encoder_attention_mask=frame_atts,
+ return_dict=True,
+ )
- inputs_llama = self.llama_proj(query_output.last_hidden_state)
- atts_llama = torch.ones(inputs_llama.size()[:-1],
- dtype=torch.long).to(image.device)
+ frame_inputs_llama = self.llama_proj(
+ frame_query_output.last_hidden_state[:, :query_tokens.
+ size(1), :])
+ frame_atts_llama = torch.ones(
+ frame_inputs_llama.size()[:-1],
+ dtype=torch.long).to(image.device)
+ inputs_llama.append(frame_inputs_llama)
+ atts_llama.append(frame_atts_llama)
+ inputs_llama = torch.cat(inputs_llama, dim=1)
+ atts_llama = torch.cat(atts_llama, dim=1)
+ else:
+ image_embeds = self.ln_vision(
+ self.visual_encoder(image)).to(device)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(device)
+
+ query_tokens = self.query_tokens.expand(
+ image_embeds.shape[0], -1, -1)
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
+ atts_llama = torch.ones(inputs_llama.size()[:-1],
+ dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
def pack_inputs(self, batch):
@@ -153,3 +196,87 @@ class MiniGPT4Inferencer(MiniGPT4):
data_sample.pred_answer = output_text
data_samples[i] = data_sample
return data_samples
+
+ def loss(self, batch):
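+        """Compute the language-modeling loss of each answer choice.
+
+        The losses over the candidate choices are attached to the data
+        samples so that the evaluator can rank the choices.
+        """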
+ inputs = self.pack_inputs(batch)
+ inputs = self.prompt_constructor(inputs)
+ image = inputs['image']
+ batch_size = image.size(0)
+ prompt = inputs['prompt']
+ data_samples = inputs['data_samples']
+ choices = data_samples[0].choices
+
+ with torch.no_grad():
+ img_embeds, atts_img = self.encode_img(image)
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
+ prompt)
+
+ self.llama_tokenizer.padding_side = 'right'
+
+ n_cands = len(choices)
+ losses = []
+ for n in range(self.n_segments):
+ seg_len = n_cands // self.n_segments
+ if n == (self.n_segments - 1):
+ seg_len = n_cands - seg_len * (self.n_segments - 1)
+
+ to_regress_tokens = self.llama_tokenizer(
+ choices,
+ return_tensors='pt',
+ padding='longest',
+ truncation=True,
+ max_length=self.max_txt_len,
+ add_special_tokens=False).to(image.device)
+
+ targets = to_regress_tokens.input_ids.masked_fill(
+ to_regress_tokens.input_ids ==
+ self.llama_tokenizer.pad_token_id, -100)
+
+ empty_targets = (
+ torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
+ dtype=torch.long).to(image.device).fill_(
+ -100) # plus one for bos
+ )
+ empty_targets = empty_targets.repeat_interleave(seg_len, dim=0)
+ targets = torch.cat([empty_targets, targets], dim=1)
+
+ bos = torch.ones([batch_size, 1],
+ dtype=to_regress_tokens.input_ids.dtype,
+ device=to_regress_tokens.input_ids.device
+ ) * self.llama_tokenizer.bos_token_id
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
+ bos_embeds = bos_embeds.repeat_interleave(seg_len, dim=0)
+ img_embeds = img_embeds.repeat_interleave(seg_len, dim=0)
+
+ atts_bos = atts_img[:, :1]
+ atts_bos = atts_bos.repeat_interleave(seg_len, dim=0)
+ atts_img = atts_img.repeat_interleave(seg_len, dim=0)
+
+ to_regress_embeds = self.llama_model.model.embed_tokens(
+ to_regress_tokens.input_ids)
+
+ inputs_embeds = torch.cat(
+ [bos_embeds, img_embeds, to_regress_embeds], dim=1)
+ attention_mask = torch.cat(
+ [atts_bos, atts_img, to_regress_tokens.attention_mask],
+ dim=1)
+
+ with self.maybe_autocast():
+ outputs = self.llama_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ reduction='none',
+ )
+ loss = outputs.loss
+ loss = loss.view(targets.size(0), -1).sum(1)
+ loss = loss.reshape(batch_size, seg_len)
+ losses.append(loss)
+ # losses of 4 choices
+ losses = torch.cat(losses, dim=-1)[0]
+
+ for i, data_sample in enumerate(data_samples):
+ data_sample.losses = losses
+ data_samples[i] = data_sample
+ return data_samples
diff --git a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
index aec42b95..55c8300e 100644
--- a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
+++ b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
@@ -118,3 +118,23 @@ class MiniGPT4VSRPromptConstructor(MiniGPT4MMBenchPromptConstructor):
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt # noqa
return prompt
+
+
+class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
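+    """Prompt constructor for MiniGPT-4 on SEED-Bench.
+
+    The prompt is ``image_prompt + question + reply_prompt``; the answer
+    choices are scored separately through the model's ``loss`` mode.
+    """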
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ """Process data sample to prompt.
+
+ Args:
+ data_samples (List[DataSample]): A list of data_samples.
+
+ Returns:
+ str: Prompt.
+ """
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ questions = [
+ data_sample.get('question') for data_sample in data_samples
+ ]
+ question = questions[0]
+ prompt = self.image_prompt + ' ' + question + ' ' + self.reply_prompt
+ return prompt
diff --git a/opencompass/tasks/mm_infer.py b/opencompass/tasks/mm_infer.py
index ce269284..f8ed1e45 100644
--- a/opencompass/tasks/mm_infer.py
+++ b/opencompass/tasks/mm_infer.py
@@ -127,9 +127,9 @@ class MultimodalInferTask:
for batch in track_iter_progress(dataloader):
if dist.is_initialized():
- data_samples = model.module.generate(batch)
+ data_samples = model.module.forward(batch)
else:
- data_samples = model.generate(batch)
+ data_samples = model.forward(batch)
if not isinstance(data_samples, Sequence):
data_samples = [data_samples]
evaluator.process(data_samples)