From 8d368d1cd603067b9e21520bc5c8d92138b64986 Mon Sep 17 00:00:00 2001
From: Yike Yuan <32432002+yyk-wew@users.noreply.github.com>
Date: Mon, 21 Aug 2023 15:57:30 +0800
Subject: [PATCH] [Feat] Support visualglm and llava for MMBench evaluation.
 (#211)

* [Feat] Support visualglm inference on MMBench.
* [Feat] Support llava inference on MMBench.
* [Fix] Fix pre-commit format.
* [Fix] Add docstring for llava
* [Fix] Fix multi-process inference error of LlaVA and add comments.

1. Set `low_cpu_mem_usage` to False to address device issue.
2. Add docstring and type hints.
3. Rename class and remove registry.

* [Fix] Pre-commit fix.
* [Fix] add forward entry, add dynamic import to seedbench
* [Fix] Fix pre-commit.
* [Fix] Fix missing context.
* [Fix] Fix docstring.
---
 configs/multimodal/llava/README.md            |  10 ++
 configs/multimodal/llava/llava_7b_mmbench.py  |  43 ++++++
 .../visualglm/visualglm_6b_mmbench.py         |  41 +++++
 opencompass/multimodal/datasets/seedbench.py  |   3 +-
 opencompass/multimodal/models/__init__.py     |   3 +
 .../multimodal/models/llava/__init__.py       |   3 +
 opencompass/multimodal/models/llava/llava.py  | 145 ++++++++++++++++++
 .../models/llava/prompt_constructor.py        |  59 +++++++
 .../multimodal/models/visualglm/__init__.py   |   5 +
 .../models/visualglm/post_processor.py        |  14 ++
 .../models/visualglm/prompt_constructor.py    |  55 +++++++
 .../multimodal/models/visualglm/visualglm.py  |  98 ++++++++++++
 12 files changed, 478 insertions(+), 1 deletion(-)
 create mode 100644 configs/multimodal/llava/README.md
 create mode 100644 configs/multimodal/llava/llava_7b_mmbench.py
 create mode 100644 configs/multimodal/visualglm/visualglm_6b_mmbench.py
 create mode 100644 opencompass/multimodal/models/llava/__init__.py
 create mode 100644 opencompass/multimodal/models/llava/llava.py
 create mode 100644 opencompass/multimodal/models/llava/prompt_constructor.py
 create mode 100644 opencompass/multimodal/models/visualglm/__init__.py
 create mode 100644 opencompass/multimodal/models/visualglm/post_processor.py
 create mode 100644 opencompass/multimodal/models/visualglm/prompt_constructor.py
 create mode 100644 opencompass/multimodal/models/visualglm/visualglm.py

diff --git a/configs/multimodal/llava/README.md b/configs/multimodal/llava/README.md
new file mode 100644
index 00000000..8c2d7fac
--- /dev/null
+++ b/configs/multimodal/llava/README.md
@@ -0,0 +1,10 @@
+# LLaVA
+
+### Prepare the environment
+
+```sh
+cd opencompass/multimodal/models/llava
+git clone https://github.com/haotian-liu/LLaVA.git
+```
+
+Then prepare the environment according to the [installation instructions](https://github.com/haotian-liu/LLaVA/tree/main#install).
\ No newline at end of file
diff --git a/configs/multimodal/llava/llava_7b_mmbench.py b/configs/multimodal/llava/llava_7b_mmbench.py
new file mode 100644
index 00000000..9bef7e8f
--- /dev/null
+++ b/configs/multimodal/llava/llava_7b_mmbench.py
@@ -0,0 +1,43 @@
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(
+        type='mmpretrain.torchvision/Normalize',
+        mean=(0.48145466, 0.4578275, 0.40821073),
+        std=(0.26862954, 0.26130258, 0.27577711),
+    ),
+    dict(
+        type='mmpretrain.PackInputs',
+        algorithm_keys=[
+            'question', 'category', 'l2-category', 'context', 'index',
+            'options_dict', 'options', 'split'
+        ],
+    ),
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+               data_file='data/mmbench/mmbench_test_20230712.tsv',
+               pipeline=val_pipeline)
+
+mmbench_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    dataset=dataset,
+    collate_fn=dict(type='pseudo_collate'),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
+
+# model settings
+llava_model = dict(
+    type='llava',
+    model_path='/path/to/llava',
+)  # noqa
+
+# evaluation settings
+mmbench_evaluator = [
+    dict(type='opencompass.DumpResults',
+         save_path='work_dirs/llava-7b-mmbench.xlsx')
+]
diff --git a/configs/multimodal/visualglm/visualglm_6b_mmbench.py b/configs/multimodal/visualglm/visualglm_6b_mmbench.py
new file mode 100644
index 00000000..bd50b5b0
--- /dev/null
+++ b/configs/multimodal/visualglm/visualglm_6b_mmbench.py
@@ -0,0 +1,41 @@
+from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor,
+                                                     VisualGLMPromptConstructor)
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'options', 'category', 'l2-category', 'context',
+             'index', 'options_dict'
+         ])
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+               data_file='data/mmbench/mmbench_test_20230712.tsv',
+               pipeline=val_pipeline)
+
+mmbench_dataloader = dict(batch_size=1,
+                          num_workers=4,
+                          dataset=dataset,
+                          collate_fn=dict(type='pseudo_collate'),
+                          sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+visualglm_model = dict(
+    type='visualglm',
+    pretrained_path='/path/to/visualglm',  # or Huggingface repo id
+    prompt_constructor=dict(type=VisualGLMPromptConstructor),
+    post_processor=dict(type=VisualGLMPostProcessor)
+)
+
+# evaluation settings
+mmbench_evaluator = [
+    dict(type='opencompass.DumpResults',
+         save_path='work_dirs/visualglm-6b-mmbench.xlsx')
+]
diff --git a/opencompass/multimodal/datasets/seedbench.py b/opencompass/multimodal/datasets/seedbench.py
index 1e03c9e5..068d2bca 100644
--- a/opencompass/multimodal/datasets/seedbench.py
+++ b/opencompass/multimodal/datasets/seedbench.py
@@ -1,8 +1,8 @@
+import importlib
 import json
 import os.path as osp
 from typing import List
 
-import av
 import numpy as np
 import torch
 from decord import VideoReader, cpu
@@ -116,6 +116,7 @@ class SEEDBenchDataset(Dataset):
         if use_pyav:
             # using pyav for videos in evaluation dimension 12
+            av = importlib.import_module('av')
             reader = av.open(data_path)
             frames = [
                 torch.from_numpy(f.to_rgb().to_ndarray())
diff --git a/opencompass/multimodal/models/__init__.py b/opencompass/multimodal/models/__init__.py
index 3747a125..72465706 100644
--- a/opencompass/multimodal/models/__init__.py
+++ b/opencompass/multimodal/models/__init__.py
@@ -2,4 +2,7 @@ from opencompass.utils import satisfy_requirement
 
 if satisfy_requirement('salesforce-lavis'):
     from .instructblip import *  # noqa: F401, F403
+
+from .llava import *  # noqa: F401, F403
 from .minigpt_4 import *  # noqa: F401, F403
+from .visualglm import *  # noqa: F401, F403
diff --git a/opencompass/multimodal/models/llava/__init__.py b/opencompass/multimodal/models/llava/__init__.py
new file mode 100644
index 00000000..5c367473
--- /dev/null
+++ b/opencompass/multimodal/models/llava/__init__.py
@@ -0,0 +1,3 @@
+from .llava import LLaVA
+
+__all__ = ['LLaVA']
diff --git a/opencompass/multimodal/models/llava/llava.py b/opencompass/multimodal/models/llava/llava.py
new file mode 100644
index 00000000..046fbad8
--- /dev/null
+++ b/opencompass/multimodal/models/llava/llava.py
@@ -0,0 +1,145 @@
+import importlib
+import os
+import sys
+
+import torch
+import torch.nn as nn
+from mmengine.device import get_device
+from transformers import StoppingCriteria
+
+from opencompass.registry import MM_MODELS
+
+from .prompt_constructor import LLaVAMMBenchPromptConstructor
+
+IMAGE_TOKEN_INDEX = -200
+
+
+def load_package():
+    """Load required packages from LLaVA."""
+    current_file_path = os.path.abspath(__file__)
+    current_folder_path = os.path.dirname(current_file_path)
+
+    sys.path.append(os.path.join(current_folder_path, 'LLaVA'))  # noqa
+    return
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+    """Keyword stopping criteria implemented for LLaVA."""
+
+    def __init__(self, keywords, tokenizer, input_ids):
+        self.keywords = keywords
+        self.tokenizer = tokenizer
+        self.start_len = None
+        self.input_ids = input_ids
+
+    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
+                 **kwargs) -> bool:
+        if self.start_len is None:
+            self.start_len = self.input_ids.shape[1]
+        else:
+            outputs = self.tokenizer.batch_decode(output_ids[:,
+                                                             self.start_len:],
+                                                  skip_special_tokens=True)[0]
+            for keyword in self.keywords:
+                if keyword in outputs:
+                    return True
+        return False
+
+
+@MM_MODELS.register_module('llava')
+class LLaVA(nn.Module):
+    """Inference code of LLaVA. The official LLaVA repo needs to be cloned
+    first; see the README under configs/multimodal/llava.
+
+    Args:
+        model_path (str): The path of the LLaVA checkpoint.
+    """
+
+    def __init__(self, model_path: str) -> None:
+        super().__init__()
+        self.dtype = torch.float16
+
+        # load LLaVA modules
+        load_package()
+        mm_utils = importlib.import_module('llava.mm_utils')
+        builder = importlib.import_module('llava.model.builder')
+        conversation = importlib.import_module('llava.conversation')
+        self.SeparatorStyle = conversation.SeparatorStyle
+        self.conv_templates = conversation.conv_templates
+
+        # load pretrained LLaVA
+        # Note: when running into device-related errors, try setting
+        # `low_cpu_mem_usage` to False in `load_pretrained_model`.
+        model_name = mm_utils.get_model_name_from_path(model_path)
+        tokenizer, model, _, _ = builder.load_pretrained_model(
+            model_path, None, model_name)
+        vision_tower = model.get_vision_tower()
+        vision_tower.to(device=get_device(), dtype=self.dtype)
+        model.to(device=get_device(), dtype=self.dtype)
+
+        # load prompt constructor and post processor
+        if 'v1' in model_path.lower():
+            conv_mode = 'llava_v1'
+        elif 'mpt' in model_path.lower():
+            conv_mode = 'mpt_multimodal'
+        else:
+            conv_mode = 'multimodal'
+        mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
+                                      False)
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self.prompt_constructor = LLaVAMMBenchPromptConstructor(
+            conv_templates=conversation.conv_templates,
+            conv_mode=conv_mode,
+            mm_use_im_start_end=mm_use_im_start_end)
+
+    def generate(self, batch):
+
+        prompt, stop_str = self.prompt_constructor(batch)
+        keywords = [stop_str]
+        data_sample = batch['data_samples'][0]
+
+        image = batch['inputs'][0].unsqueeze(0)
+        if image is not None:
+            images = image.to(get_device())
+        else:
+            images = None
+
+        mm_utils = importlib.import_module('llava.mm_utils')
+        input_ids = mm_utils.tokenizer_image_token(
+            prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
+            return_tensors='pt').unsqueeze(0).to(get_device())
+
+        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
+                                                     input_ids)
+
+        with torch.inference_mode():
+            output_ids = self.model.generate(
+                input_ids,
+                images=images.half(),
+                do_sample=True,
+                temperature=0.2,
+                max_new_tokens=1024,
+                stopping_criteria=[stopping_criteria],
+            )
+
+        input_token_len = input_ids.shape[1]
+        n_diff_input_output = (input_ids !=
+                               output_ids[:, :input_token_len]).sum().item()
+        if n_diff_input_output > 0:
+            print(
+                f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'  # noqa
+            )
+        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
+                                              skip_special_tokens=True)[0]
+        outputs = outputs.strip()
+        if outputs.endswith(stop_str):
+            outputs = outputs[:-len(stop_str)]
+        output_text = outputs.strip()
+
+        data_sample.pred_answer = output_text
+        return data_sample
+
+    def forward(self, batch):
+        return self.generate(batch)
diff --git a/opencompass/multimodal/models/llava/prompt_constructor.py b/opencompass/multimodal/models/llava/prompt_constructor.py
new file mode 100644
index 00000000..c055c207
--- /dev/null
+++ b/opencompass/multimodal/models/llava/prompt_constructor.py
@@ -0,0 +1,59 @@
+import importlib
+from typing import Any
+
+DEFAULT_IMAGE_TOKEN = '<image>'
+DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
+DEFAULT_IM_START_TOKEN = '<im_start>'
+DEFAULT_IM_END_TOKEN = '<im_end>'
+
+
+class LLaVAMMBenchPromptConstructor:
+    """Prompt constructor for LLaVA on MMBench.
+
+    Args:
+        conv_templates (Any): Conversation class to build prompt.
+        conv_mode (str): Version control args for different versions of LLaVA.
+        mm_use_im_start_end (bool):
+            Config arg. Use start and end tokens when building the prompt or not.
+    """
+
+    def __init__(self, conv_templates: Any, conv_mode: str,
+                 mm_use_im_start_end: bool) -> None:
+        self.conv_templates = conv_templates
+        self.conv_mode = conv_mode
+        self.mm_use_im_start_end = mm_use_im_start_end
+        conversation = importlib.import_module('llava.conversation')
+        self.SeparatorStyle = conversation.SeparatorStyle
+
+    def __call__(self, inputs: dict) -> tuple:
+        """Construct prompt.
+
+        Args:
+            inputs (dict): Input data containing images and data_samples.
+
+        Returns:
+            tuple: The constructed prompt and the stop string.
+ """ + data_samples = inputs['data_samples'] + assert len(data_samples) == 1 + question = data_samples[0].get('question') + options = data_samples[0].get('options') + context = data_samples[0].get('context') + if context is not None: + prompt = context + ' ' + question + ' ' + options + else: + prompt = question + ' ' + options + if self.mm_use_im_start_end: + prompt = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + + DEFAULT_IM_END_TOKEN + '\n' + prompt) + else: + prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt # noqa + + conv = self.conv_templates[self.conv_mode].copy() + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + output_prompt = conv.get_prompt() + + stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2 # noqa + + return output_prompt, stop_str diff --git a/opencompass/multimodal/models/visualglm/__init__.py b/opencompass/multimodal/models/visualglm/__init__.py new file mode 100644 index 00000000..69b12b4a --- /dev/null +++ b/opencompass/multimodal/models/visualglm/__init__.py @@ -0,0 +1,5 @@ +from .post_processor import VisualGLMPostProcessor +from .prompt_constructor import VisualGLMPromptConstructor +from .visualglm import VisualGLM + +__all__ = ['VisualGLM', 'VisualGLMPostProcessor', 'VisualGLMPromptConstructor'] diff --git a/opencompass/multimodal/models/visualglm/post_processor.py b/opencompass/multimodal/models/visualglm/post_processor.py new file mode 100644 index 00000000..ce048ea9 --- /dev/null +++ b/opencompass/multimodal/models/visualglm/post_processor.py @@ -0,0 +1,14 @@ +from typing import Any + +import torch + + +class VisualGLMPostProcessor: + """"Post processor for VisualGLM on MMBench.""" + + def __init__(self) -> None: + pass + + def __call__(self, output_token: torch.tensor, tokenizer: Any, + input_len: int) -> str: + return tokenizer.decode(output_token[input_len:]) diff --git a/opencompass/multimodal/models/visualglm/prompt_constructor.py b/opencompass/multimodal/models/visualglm/prompt_constructor.py new file mode 100644 index 00000000..3ff50f17 --- /dev/null +++ b/opencompass/multimodal/models/visualglm/prompt_constructor.py @@ -0,0 +1,55 @@ +import torch + + +class VisualGLMPromptConstructor: + """Prompt constructor for VisualGLM. + + The overall prompt will be formulated as + "system_prompt"+"human_prompt"+"image_prompt"+question+"assistant+prompt". + Args: + system_prompt (str): System prompt. (Default: '') + human_prompt (str): Human prompt. (Default: 'Q:') + image_prompt (str): Image prompt. (Default: '') + assistant_prompt (str): Assistant prompt. (Default: 'A:') + """ + + def __init__(self, + system_prompt: str = '', + human_prompt: str = 'Q:', + image_prompt: str = '', + assistant_prompt: str = 'A:') -> None: + self.image_prompt = image_prompt + self.system_prompt = system_prompt + self.human_prompt = human_prompt + self.assistant_prompt = assistant_prompt + + def __call__(self, batch: dict) -> tuple: + """Construct prompt. + + Args: + batch (dict): Input data containing image and data_samples. + + Returns: + tuple: A tuple containing prompt, images and data_samples. 
+ """ + + images = batch.pop('inputs') + images = torch.stack(images, dim=0) + + data_samples = batch.pop('data_samples') + questions = [sample.get('question') for sample in data_samples] + options = [sample.get('options') for sample in data_samples] + contexts = [sample.get('context') for sample in data_samples] + contexts = [c if c else '' for c in contexts] + + # generate text prompt + prompt = [ + '{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt, + self.human_prompt, context, question, + option, self.assistant_prompt) + for context, question, option in zip(contexts, questions, options) + ] + + image_position = 5 + + return images, prompt, data_samples, image_position diff --git a/opencompass/multimodal/models/visualglm/visualglm.py b/opencompass/multimodal/models/visualglm/visualglm.py new file mode 100644 index 00000000..e5b103bc --- /dev/null +++ b/opencompass/multimodal/models/visualglm/visualglm.py @@ -0,0 +1,98 @@ +from typing import Optional + +import mmengine +import torch +import torch.nn as nn +from mmengine.device import get_device +from transformers import AutoModel, AutoTokenizer + +from opencompass.registry import MM_MODELS + + +@MM_MODELS.register_module('visualglm') +class VisualGLM(nn.Module): + """Inference code of VisualGLM. + + We load the visualGLM model via Huggingface. + Args: + pretrained_path (str): Path to visualGLM checkpoint or repo id. + prompt_constructor (dict): The config of prompt constructor. + post_processor (dict): The config of post processor. + gen_kwargs (dict): Customize generate function arguments. + """ + + def __init__(self, + pretrained_path: str, + prompt_constructor: dict, + post_processor: dict, + gen_kwargs: Optional[dict] = None) -> None: + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path, + trust_remote_code=True) + self.model = AutoModel.from_pretrained(pretrained_path, + trust_remote_code=True).half() + self.prompt_constructor = mmengine.registry.build_from_cfg( + prompt_constructor, MM_MODELS) + self.post_processor = mmengine.registry.build_from_cfg( + post_processor, MM_MODELS) + + if gen_kwargs: + self.gen_kwargs = gen_kwargs + else: + self.gen_kwargs = dict() + + def encode_by_tokenizer(self, multi_prompts, image_position): + input_ids = [] + max_seq_length = 0 + for prompt in multi_prompts: + input0 = self.tokenizer.encode(prompt[:image_position], + add_special_tokens=False) + input1 = [self.tokenizer.pad_token_id] * self.model.image_length + input2 = self.tokenizer.encode(prompt[image_position:], + add_special_tokens=False) + input_all = sum([input0, input1, input2], []) + input_all = self.tokenizer.build_inputs_with_special_tokens( + input_all) + max_seq_length = max(max_seq_length, len(input_all)) + input_ids.append(input_all) + pre_image_len = len(input0) + + # padding + for i, _ in enumerate(input_ids): + pad_len = max_seq_length - len(input_ids[i]) + input_ids[i] = [self.tokenizer.pad_token_id + ] * pad_len + input_ids[i] + + return input_ids, pre_image_len + + def generate(self, batch): + # process input + image, prompt, data_sample, image_position = self.prompt_constructor( + batch) + image = image.to(self.model.dtype).to(get_device()) + + # tokenize + input_all, pre_image_len = self.encode_by_tokenizer( + prompt, image_position) + + input_all = torch.tensor(input_all, dtype=torch.long).to(get_device()) + + # build input param + inputs = { + 'input_ids': input_all, + 'pre_image_length': pre_image_len, + 'images': image + } + # generate answer + outputs = 
+        outputs = self.model.generate(**inputs, **self.gen_kwargs)
+
+        # format output
+        outputs = outputs.tolist()
+        for i, sample in enumerate(data_sample):
+            data_sample[i].pred_answer = self.post_processor(
+                outputs[i], self.tokenizer, input_all.shape[1])
+
+        return data_sample
+
+    def forward(self, batch):
+        return self.generate(batch)
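
For readers who want a feel for how the VisualGLM pieces in this patch fit together, the sketch below walks through the prompt construction and image-slot tokenization flow on a single MMBench-style sample. It only mirrors the logic of `VisualGLMPromptConstructor.__call__` and `VisualGLM.encode_by_tokenizer`; the `FakeTokenizer`, `IMAGE_LENGTH` value, and sample dict are illustrative stand-ins, not the real visualglm-6b assets.

```python
# Minimal sketch of the VisualGLM prompt/tokenization flow added above.
# FakeTokenizer, IMAGE_LENGTH and the sample dict are assumptions for
# illustration only; they do not match the real visualglm-6b tokenizer.
IMAGE_LENGTH = 32  # stand-in for model.image_length (placeholder slots for image features)


class FakeTokenizer:
    """Toy tokenizer: one integer id per whitespace-separated word."""
    pad_token_id = 0

    def encode(self, text, add_special_tokens=False):
        return [hash(word) % 30000 + 1 for word in text.split()]

    def build_inputs_with_special_tokens(self, ids):
        # The real ChatGLM tokenizer adds its special tokens here; pass through.
        return list(ids)


def build_prompt(sample,
                 system_prompt='',
                 image_prompt='<img></img>',
                 human_prompt='Q:',
                 assistant_prompt='A:'):
    # Same field ordering as VisualGLMPromptConstructor.__call__ above.
    context = sample.get('context') or ''
    return '{}{}{}{}{}{}{}'.format(system_prompt, image_prompt, human_prompt,
                                   context, sample['question'],
                                   sample['options'], assistant_prompt)


def encode_with_image_slot(prompt, tokenizer, image_position=5):
    # Mirrors VisualGLM.encode_by_tokenizer for one prompt: text before the
    # image slot, IMAGE_LENGTH pad ids reserved for image features, text after.
    before = tokenizer.encode(prompt[:image_position], add_special_tokens=False)
    slot = [tokenizer.pad_token_id] * IMAGE_LENGTH
    after = tokenizer.encode(prompt[image_position:], add_special_tokens=False)
    input_ids = tokenizer.build_inputs_with_special_tokens(before + slot + after)
    return input_ids, len(before)


if __name__ == '__main__':
    sample = {'question': 'What is shown in the image?',
              'options': 'A. a cat B. a dog C. a bird D. a car',
              'context': None}
    prompt = build_prompt(sample)
    input_ids, pre_image_len = encode_with_image_slot(prompt, FakeTokenizer())
    print(prompt)
    print(len(input_ids), pre_image_len)
```

In the real model, the reserved pad positions are overwritten by the projected image features during generation, which is why `pre_image_length` is passed to `model.generate` alongside `input_ids` and `images`.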