diff --git a/configs/multimodal/minigpt_4/minigpt_4_7b_seedbench.py b/configs/multimodal/minigpt_4/minigpt_4_7b_seedbench.py
new file mode 100644
index 00000000..051bcb73
--- /dev/null
+++ b/configs/multimodal/minigpt_4/minigpt_4_7b_seedbench.py
@@ -0,0 +1,63 @@
+from opencompass.multimodal.models.minigpt_4 import MiniGPT4SEEDBenchPromptConstructor  # noqa
+
+# dataloader settings
+image_pipeline = [
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'answer', 'choices', 'data_type', 'question_type_id',
+             'index', 'data_path', 'question_id'
+         ])
+]
+video_pipeline = [
+    dict(type='mmaction.Resize', scale=(224, 224), interpolation='bicubic'),
+    dict(type='mmaction.CenterCrop', crop_size=224),
+    dict(type='Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'answer', 'choices', 'data_type', 'question_type_id',
+             'index', 'data_path', 'question_id'
+         ])
+]
+
+dataset = dict(
+    type='opencompass.SEEDBenchDataset',
+    ann_file='data/seedbench/SEED-Bench.json',
+    cc3m_path='data/seedbench/SEED-Bench-image',
+    sthv2_path='data/seedbench/sthv2/videos',
+    epic_kitchens_path='data/seedbench/3h91syskeag572hl6tvuovwv4d/videos/test',
+    breakfast_path='data/seedbench/BreakfastII_15fps_qvga_sync',
+    image_pipeline=image_pipeline,
+    video_pipeline=video_pipeline,
+    only_image=True)
+
+minigpt_4_seedbench_dataloader = dict(batch_size=1,
+                                      num_workers=4,
+                                      dataset=dataset,
+                                      collate_fn=dict(type='pseudo_collate'),
+                                      sampler=dict(type='DefaultSampler',
+                                                   shuffle=False))
+
+# model settings
+minigpt_4_seedbench_model = dict(
+    type='minigpt-4',
+    low_resource=False,
+    llama_model='/path/to/vicuna/',
+    prompt_constructor=dict(type=MiniGPT4SEEDBenchPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=None,
+    mode='loss')
+
+# evaluation settings
+minigpt_4_seedbench_evaluator = [dict(type='opencompass.SEEDBenchAcc')]
+
+minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth'
diff --git a/configs/multimodal/seedbench.py b/configs/multimodal/seedbench.py
new file mode 100644
index 00000000..efd45619
--- /dev/null
+++ b/configs/multimodal/seedbench.py
@@ -0,0 +1,14 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .minigpt_4.minigpt_4_7b_seedbench import (
+        minigpt_4_seedbench_dataloader, minigpt_4_seedbench_evaluator,
+        minigpt_4_load_from, minigpt_4_seedbench_model)
+
+models = [minigpt_4_seedbench_model]
+datasets = [minigpt_4_seedbench_dataloader]
+evaluators = [minigpt_4_seedbench_evaluator]
+load_froms = [minigpt_4_load_from]
+num_gpus = 1
+num_procs = 1
+launcher = 'slurm'
diff --git a/opencompass/metrics/__init__.py b/opencompass/metrics/__init__.py
index 09333edf..68c04467 100644
--- a/opencompass/metrics/__init__.py
+++ b/opencompass/metrics/__init__.py
@@ -1,3 +1,4 @@
 from .dump_results import DumpResults
+from .seedbench import SEEDBenchAcc
 
-__all__ = ['DumpResults']
+__all__ = ['DumpResults', 'SEEDBenchAcc']
diff --git a/opencompass/metrics/seedbench.py b/opencompass/metrics/seedbench.py
new file mode 100644
index 00000000..e06de4c1
--- /dev/null
+++ b/opencompass/metrics/seedbench.py
@@ -0,0 +1,67 @@
+import torch
+from mmengine.evaluator import BaseMetric
+
+from opencompass.registry import METRICS
+
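+# SEED-Bench question_type_id -> evaluation dimension name; ids 1-9 are the
+# image dimensions and 10-12 are the video dimensions.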
+EVAL_DIM_MAPPING = {
+    1: 'Scene Understanding',
+    2: 'Instance Identity',
+    3: 'Instance Attributes',
+    4: 'Instance Location',
+    5: 'Instance Counting',
+    6: 'Spatial Relations',
+    7: 'Instance Interaction',
+    8: 'Visual Reasoning',
+    9: 'Text Recognition',
+    10: 'Action Recognition',
+    11: 'Action Prediction',
+    12: 'Procedure Understanding',
+}
+
+
+@METRICS.register_module()
+class SEEDBenchAcc(BaseMetric):
+    """Compute results for SEED-Bench."""
+
+    def process(self, data_batch, data_samples) -> None:
+        for data_sample in data_samples:
+            losses = data_sample['losses']
+            class_ranks = torch.argsort(losses, dim=-1).cpu()
+            pred_id = ['A', 'B', 'C', 'D'][class_ranks[0]]
+            answer_record = {
+                'q_id': data_sample['question_id'],
+                'prediction': pred_id,
+                'gt': data_sample['answer'],
+                'q_type_id': data_sample['question_type_id'],
+                'losses': [str(num) for num in list(losses.cpu().numpy())],
+            }
+            self.results.append(answer_record)
+
+    def compute_metrics(self, results: list) -> dict:
+        type_counts = {}
+        correct_counts = {}
+        out = {}
+        out['answer_records'] = results
+        for item in results:
+            pred, gt = item['prediction'], item['gt']
+            data_type = item['q_type_id']
+
+            type_counts[data_type] = type_counts.get(data_type, 0) + 1
+            if pred == gt:
+                correct_counts[data_type] = correct_counts.get(data_type,
+                                                               0) + 1
+
+        total_count = 0
+        total_correct = 0
+        for data_type in type_counts.keys():
+            accuracy = correct_counts.get(data_type,
+                                          0) / type_counts[data_type] * 100
+            category = EVAL_DIM_MAPPING[data_type]
+            out[f'Data type {data_type} - {category}'] = accuracy
+
+            total_count += type_counts[data_type]
+            total_correct += correct_counts.get(data_type, 0)
+
+        total_accuracy = total_correct / total_count * 100
+        out['Total accuracy'] = total_accuracy
+        return out
diff --git a/opencompass/multimodal/datasets/__init__.py b/opencompass/multimodal/datasets/__init__.py
index 47fa24ad..c9374c47 100644
--- a/opencompass/multimodal/datasets/__init__.py
+++ b/opencompass/multimodal/datasets/__init__.py
@@ -1,3 +1,4 @@
 from .mmbench import MMBenchDataset
+from .seedbench import SEEDBenchDataset
 
-__all__ = ['MMBenchDataset']
+__all__ = ['MMBenchDataset', 'SEEDBenchDataset']
diff --git a/opencompass/multimodal/datasets/seedbench.py b/opencompass/multimodal/datasets/seedbench.py
new file mode 100644
index 00000000..1e03c9e5
--- /dev/null
+++ b/opencompass/multimodal/datasets/seedbench.py
@@ -0,0 +1,173 @@
+import json
+import os.path as osp
+from typing import List
+
+import av
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from mmengine.dataset import Compose
+from PIL import Image
+from torch.utils.data import Dataset
+
+from opencompass.registry import DATASETS
+
+
+@DATASETS.register_module()
+class SEEDBenchDataset(Dataset):
+    """Dataset to load SEED-Bench dataset.
+
+    Args:
+        ann_file (str): The path of the annotation file.
+        cc3m_path (str): The data path of the image dimensions (1-9).
+        sthv2_path (str): The data path of dimension 10.
+        epic_kitchens_path (str): The data path of dimension 11.
+        breakfast_path (str): The data path of dimension 12.
+        image_pipeline (List[dict]): The data transforms for images.
+        video_pipeline (List[dict]): The data transforms for videos.
+        only_image (bool): Whether to run SEED-Bench only with image data.
+            Defaults to True.
+    """
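+
+    # Only the image subset (dimensions 1-9) is supported so far; passing
+    # only_image=False raises NotImplementedError in __init__.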
+ """ + + def __init__( + self, + ann_file: str, + cc3m_path: str, + sthv2_path: str, + epic_kitchens_path: str, + breakfast_path: str, + image_pipeline: List[dict], + video_pipeline: List[dict], + only_image: bool = True, + ) -> None: + ann_file = json.load(open(ann_file, 'rb')) + if 'questions' in ann_file.keys(): + self.ann_file = ann_file['questions'] + self.cc3m_path = cc3m_path + self.sthv2_path = sthv2_path + self.epic_kitchens_path = epic_kitchens_path + self.breakfast_path = breakfast_path + self.image_pipeline = Compose(image_pipeline) + if only_image: + image_ann_file = [ + ann for ann in self.ann_file if ann['data_type'] == 'image' + ] + self.ann_file = image_ann_file + if not only_image: + raise NotImplementedError + self.video_pipeline = Compose(video_pipeline) + + def __len__(self) -> None: + return len(self.ann_file) + + def __getitem__(self, idx: str) -> dict: + item = self.ann_file[idx] + data = { + 'question': + item['question'], + 'answer': + item['answer'], + 'choices': [ + item['choice_a'], item['choice_b'], item['choice_c'], + item['choice_d'] + ], + 'data_type': + item['data_type'], + 'question_id': + item['question_id'], + 'question_type_id': + item['question_type_id'], + 'index': + idx, + } + + if item['data_type'] == 'image': + data_path = osp.join(self.cc3m_path, item['data_id']) + raw_image = Image.open(open(data_path, 'rb')).convert('RGB') + data['data_path'] = data_path + data['img'] = raw_image + data = self.image_pipeline(data) + elif item['data_type'] == 'video': + if item['question_type_id'] == 10: + data_path = osp.join(self.sthv2_path, item['data_id']) + data['data_path'] = data_path + elif item['question_type_id'] == 11: + data_path = osp.join(self.epic_kitchens_path, item['data_id']) + data['data_path'] = data_path + data['segment'] = item['segment'] + elif item['question_type_id'] == 12: + data_path = osp.join(self.breakfast_path, item['data_id']) + data['data_path'] = data_path + data['segment'] = item['segment'] + else: + raise ValueError('The question type id is not valid.') + + # preprocessing videos in evaluation dimension 10-12 + use_pyav = False + if 'segment' in data.keys(): + segment = data['segment'] + if isinstance(segment[0], int): + # using pyav for decoding videos in evaluation dimension 12 + use_pyav = True + start, end = segment[0], segment[1] + else: + start = 0.0 + end = 0.0 + + if use_pyav: + # using pyav for videos in evaluation dimension 12 + reader = av.open(data_path) + frames = [ + torch.from_numpy(f.to_rgb().to_ndarray()) + for f in reader.decode(video=0) + ] + video_len = len(frames) + start_frame, end_frame = start, end + end_frame = min(end_frame, video_len) + offset = self.get_index(end_frame - start_frame, 8) + frame_indices = offset + start_frame + buffer = torch.stack([frames[idx] for idx in frame_indices]) + buffer = buffer.numpy() + else: + # using decord for videos in evaluating dimension 10-11 + import io + + import mmengine.fileio as fileio + file_obj = io.BytesIO(fileio.get(data_path)) + vr = VideoReader(file_obj, num_threads=1, ctx=cpu(0)) + video_len = len(vr) + fps = vr.get_avg_fps() + if 'segment' in data.keys(): + # obtain start and end frame for the video segment + # in evaluation dimension 11 + start_frame = int(min(max(start * fps, 0), video_len - 1)) + end_frame = int(min(max(end * fps, 0), video_len - 1)) + tot_frames = int(end_frame - start_frame) + offset = self.get_index(tot_frames, 8) + frame_indices = offset + start_frame + else: + # sample frames of the video in evaluation dimension 10 + 
+            data['imgs'] = buffer
+            data = self.video_pipeline(data)
+
+        else:
+            raise ValueError('The data type is not valid.')
+
+        return data
+
+    def get_index(self, num_frames, num_segments):
+        if num_segments > num_frames:
+            offsets = np.array([idx for idx in range(num_frames)])
+        else:
+            # uniform sampling
+            seg_size = float(num_frames - 1) / num_segments
+            start = int(seg_size / 2)
+            offsets = np.array([
+                start + int(np.round(seg_size * idx))
+                for idx in range(num_segments)
+            ])
+        return offsets
diff --git a/opencompass/multimodal/models/minigpt_4/__init__.py b/opencompass/multimodal/models/minigpt_4/__init__.py
index 56e2cc69..715eb01e 100644
--- a/opencompass/multimodal/models/minigpt_4/__init__.py
+++ b/opencompass/multimodal/models/minigpt_4/__init__.py
@@ -7,6 +7,7 @@ from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
 from .prompt_constructor import (MiniGPT4COCOCaotionPromptConstructor,
                                  MiniGPT4MMBenchPromptConstructor,
                                  MiniGPT4ScienceQAPromptConstructor,
+                                 MiniGPT4SEEDBenchPromptConstructor,
                                  MiniGPT4VQAPromptConstructor,
                                  MiniGPT4VSRPromptConstructor)
 
@@ -16,5 +17,5 @@ __all__ = [
     'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
     'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
     'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
-    'MiniGPT4VSRPromptConstructor'
+    'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor'
 ]
diff --git a/opencompass/multimodal/models/minigpt_4/minigpt_4.py b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
index eee0e3dc..d7c1e36c 100644
--- a/opencompass/multimodal/models/minigpt_4/minigpt_4.py
+++ b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
@@ -59,10 +59,14 @@ class MiniGPT4Inferencer(MiniGPT4):
                  do_sample: bool = False,
                  max_length: int = 30,
                  img_size: int = 224,
-                 low_resource: bool = False) -> None:
+                 low_resource: bool = False,
+                 mode: str = 'generation',
+                 n_segments: int = 1) -> None:
         super().__init__(llama_model=llama_model,
                          low_resource=low_resource,
                          img_size=img_size)
+        self.mode = mode
+        self.n_segments = n_segments
 
         cur_device = get_device()
         stop_words_ids = [
@@ -71,34 +75,73 @@ class MiniGPT4Inferencer(MiniGPT4):
         ]
         self.stopping_criteria = StoppingCriteriaList(
             [StoppingCriteriaSub(stops=stop_words_ids)])
+
         self.prompt_constructor = mmengine.registry.build_from_cfg(
             prompt_constructor, MM_MODELS)
-        self.post_processor = mmengine.registry.build_from_cfg(
-            post_processor, MM_MODELS)
+        if post_processor is not None:
+            self.post_processor = mmengine.registry.build_from_cfg(
+                post_processor, MM_MODELS)
         self.do_sample = do_sample
         self.max_length = max_length
 
+    def forward(self, batch):
+        if self.mode == 'generation':
+            return self.generate(batch)
+        elif self.mode == 'loss':
+            return self.loss(batch)
+        else:
+            raise RuntimeError(f'Invalid mode "{self.mode}".')
+
     def encode_img(self, image):
         device = image.device
         with self.maybe_autocast():
-            image_embeds = self.ln_vision(
-                self.visual_encoder(image)).to(device)
-            image_atts = torch.ones(image_embeds.size()[:-1],
-                                    dtype=torch.long).to(device)
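+            # a 5-D input (B, C, T, H, W) is a video clip: every frame is
+            # passed through the Q-Former and the resulting frame tokens are
+            # concatenated along the sequence dimension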
+            if image.dim() == 5:
+                inputs_llama, atts_llama = [], []
+                for j in range(image.size(2)):
+                    this_frame = image[:, :, j, :, :]
+                    frame_embeds = self.ln_vision(
+                        self.visual_encoder(this_frame))
+                    frame_atts = torch.ones(frame_embeds.size()[:-1],
+                                            dtype=torch.long).to(image.device)
 
-            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
-                                                    -1)
-            query_output = self.Qformer.bert(
-                query_embeds=query_tokens,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=True,
-            )
+                    query_tokens = self.query_tokens.expand(
+                        frame_embeds.shape[0], -1, -1)
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
 
-            inputs_llama = self.llama_proj(query_output.last_hidden_state)
-            atts_llama = torch.ones(inputs_llama.size()[:-1],
-                                    dtype=torch.long).to(image.device)
+                    frame_inputs_llama = self.llama_proj(
+                        frame_query_output.last_hidden_state[:, :query_tokens.
+                                                             size(1), :])
+                    frame_atts_llama = torch.ones(
+                        frame_inputs_llama.size()[:-1],
+                        dtype=torch.long).to(image.device)
+                    inputs_llama.append(frame_inputs_llama)
+                    atts_llama.append(frame_atts_llama)
+                inputs_llama = torch.cat(inputs_llama, dim=1)
+                atts_llama = torch.cat(atts_llama, dim=1)
+            else:
+                image_embeds = self.ln_vision(
+                    self.visual_encoder(image)).to(device)
+                image_atts = torch.ones(image_embeds.size()[:-1],
+                                        dtype=torch.long).to(device)
+
+                query_tokens = self.query_tokens.expand(
+                    image_embeds.shape[0], -1, -1)
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+                inputs_llama = self.llama_proj(query_output.last_hidden_state)
+                atts_llama = torch.ones(inputs_llama.size()[:-1],
+                                        dtype=torch.long).to(image.device)
         return inputs_llama, atts_llama
 
     def pack_inputs(self, batch):
@@ -153,3 +196,87 @@
             data_sample.pred_answer = output_text
             data_samples[i] = data_sample
         return data_samples
+
+    def loss(self, batch):
+        inputs = self.pack_inputs(batch)
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        batch_size = image.size(0)
+        prompt = inputs['prompt']
+        data_samples = inputs['data_samples']
+        choices = data_samples[0].choices
+
+        with torch.no_grad():
+            img_embeds, atts_img = self.encode_img(image)
+            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
+                                                    prompt)
+
+        self.llama_tokenizer.padding_side = 'right'
+
+        n_cands = len(choices)
+        losses = []
+        for n in range(self.n_segments):
+            seg_len = n_cands // self.n_segments
+            if n == (self.n_segments - 1):
+                seg_len = n_cands - seg_len * (self.n_segments - 1)
+
+            to_regress_tokens = self.llama_tokenizer(
+                choices,
+                return_tensors='pt',
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                add_special_tokens=False).to(image.device)
+
+            targets = to_regress_tokens.input_ids.masked_fill(
+                to_regress_tokens.input_ids ==
+                self.llama_tokenizer.pad_token_id, -100)
+
+            empty_targets = (
+                torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
+                           dtype=torch.long).to(image.device).fill_(
+                               -100)  # plus one for bos
+            )
+            empty_targets = empty_targets.repeat_interleave(seg_len, dim=0)
+            targets = torch.cat([empty_targets, targets], dim=1)
+
+            bos = torch.ones([batch_size, 1],
+                             dtype=to_regress_tokens.input_ids.dtype,
+                             device=to_regress_tokens.input_ids.device
+                             ) * self.llama_tokenizer.bos_token_id
+            bos_embeds = self.llama_model.model.embed_tokens(bos)
+            bos_embeds = bos_embeds.repeat_interleave(seg_len, dim=0)
+            img_embeds = img_embeds.repeat_interleave(seg_len, dim=0)
+
+            atts_bos = atts_img[:, :1]
+            atts_bos = atts_bos.repeat_interleave(seg_len, dim=0)
+            atts_img = atts_img.repeat_interleave(seg_len, dim=0)
+
+            to_regress_embeds = self.llama_model.model.embed_tokens(
+                to_regress_tokens.input_ids)
+
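+            # prepend the BOS token and the image tokens to every candidate
+            # answer so that each row of the batch scores one choice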
+            inputs_embeds = torch.cat(
+                [bos_embeds, img_embeds, to_regress_embeds], dim=1)
+            attention_mask = torch.cat(
+                [atts_bos, atts_img, to_regress_tokens.attention_mask],
+                dim=1)
+
+            with self.maybe_autocast():
+                outputs = self.llama_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    labels=targets,
+                    reduction='none',
+                )
+            loss = outputs.loss
+            loss = loss.view(targets.size(0), -1).sum(1)
+            loss = loss.reshape(batch_size, seg_len)
+            losses.append(loss)
+        # losses of 4 choices
+        losses = torch.cat(losses, dim=-1)[0]
+
+        for i, data_sample in enumerate(data_samples):
+            data_sample.losses = losses
+            data_samples[i] = data_sample
+        return data_samples
diff --git a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
index aec42b95..55c8300e 100644
--- a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
+++ b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
@@ -118,3 +118,23 @@ class MiniGPT4VSRPromptConstructor(MiniGPT4MMBenchPromptConstructor):
         question = questions[0]
         prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt  # noqa
         return prompt
+
+
+class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        question = questions[0]
+        prompt = self.image_prompt + ' ' + question + ' ' + self.reply_prompt
+        return prompt
diff --git a/opencompass/tasks/mm_infer.py b/opencompass/tasks/mm_infer.py
index ce269284..f8ed1e45 100644
--- a/opencompass/tasks/mm_infer.py
+++ b/opencompass/tasks/mm_infer.py
@@ -127,9 +127,9 @@ class MultimodalInferTask:
 
         for batch in track_iter_progress(dataloader):
             if dist.is_initialized():
-                data_samples = model.module.generate(batch)
+                data_samples = model.module.forward(batch)
             else:
-                data_samples = model.generate(batch)
+                data_samples = model.forward(batch)
             if not isinstance(data_samples, Sequence):
                 data_samples = [data_samples]
             evaluator.process(data_samples)