[Feat] Support visualglm and llava for MMBench evaluation. (#211)

* [Feat] Support visualglm inference on MMBench.

* [Feat] Support llava inference on MMBench.

* [Fix] Fix pre-commit format.

* [Fix] Add docstring for LLaVA.

* [Fix] Fix multi-process inference error of LLaVA and add comments.
1. Set `low_cpu_mem_usage` to False to address device issue.
2. Add docstring and type hints.
3. Rename class and remove registry.

* [Fix] Pre-commit fix.

* [Fix] Add forward entry and dynamic import to SEEDBench.

* [Fix] Fix pre-commit.

* [Fix] Fix missing context.

* [Fix] Fix docstring.
Yike Yuan 2023-08-21 15:57:30 +08:00 committed by GitHub
parent a6552224cb
commit 8d368d1cd6
12 changed files with 478 additions and 1 deletion

@@ -0,0 +1,10 @@
# LLaVA
### Prepare the environment
```sh
cd opencompass/multimodal/models/llava
git clone https://github.com/haotian-liu/LLaVA.git
```
Then prepare the environment according to the [installation instructions](https://github.com/haotian-liu/LLaVA/tree/main#install).
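
A quick, optional sanity check that the cloned repo is importable (the wrapper appends the clone to `sys.path` at runtime in the same way):
```python
# Optional sanity check, run from opencompass/multimodal/models/llava.
import sys

sys.path.append('LLaVA')
import llava  # noqa: E402

print('LLaVA import OK')
```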

@@ -0,0 +1,43 @@
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(
type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'category', 'l2-category', 'context', 'index',
'options_dict', 'options', 'split'
],
),
]
dataset = dict(type='opencompass.MMBenchDataset',
data_file='data/mmbench/mmbench_test_20230712.tsv',
pipeline=val_pipeline)
mmbench_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# model settings
llava_model = dict(
type='llava',
model_path='/path/to/llava',
) # noqa
# evaluation settings
mmbench_evaluator = [
dict(type='opencompass.DumpResults',
save_path='work_dirs/llava-7b-mmbench.xlsx')
]
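
The `llava_model` dict above is built through the `MM_MODELS` registry (the wrapper is registered under the name `'llava'` later in this diff). A minimal sketch of constructing it by hand, assuming a local checkpoint path:

```python
from opencompass.registry import MM_MODELS

# '/path/to/llava' is a placeholder for a local LLaVA checkpoint directory.
llava = MM_MODELS.build(dict(type='llava', model_path='/path/to/llava'))
```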

@@ -0,0 +1,41 @@
from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor, VisualGLMPromptConstructor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'options', 'category', 'l2-category', 'context',
'index', 'options_dict'
])
]
dataset = dict(type='opencompass.MMBenchDataset',
data_file='data/mmbench/mmbench_test_20230712.tsv',
pipeline=val_pipeline)
mmbench_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
visualglm_model = dict(
type='visualglm',
pretrained_path='/path/to/visualglm', # or Huggingface repo id
prompt_constructor=dict(type=VisualGLMPromptConstructor),
post_processor=dict(type=VisualGLMPostProcessor)
)
# evaluation settings
mmbench_evaluator = [
dict(type='opencompass.DumpResults',
save_path='work_dirs/visualglm-6b-mmbench.xlsx')
]
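
Because `gen_kwargs` is forwarded verbatim to the Hugging Face `generate()` call (see `visualglm.py` below), decoding behaviour can be tuned directly from this config. A hedged sketch with illustrative parameter values:

```python
from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor,
                                                     VisualGLMPromptConstructor)

visualglm_model = dict(
    type='visualglm',
    pretrained_path='/path/to/visualglm',  # or Hugging Face repo id
    prompt_constructor=dict(type=VisualGLMPromptConstructor),
    post_processor=dict(type=VisualGLMPostProcessor),
    # standard transformers generate() arguments, passed through unchanged
    gen_kwargs=dict(max_new_tokens=64, num_beams=5, do_sample=False),
)
```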

@@ -1,8 +1,8 @@
+import importlib
 import json
 import os.path as osp
 from typing import List
-import av
 import numpy as np
 import torch
 from decord import VideoReader, cpu
@@ -116,6 +116,7 @@ class SEEDBenchDataset(Dataset):
         if use_pyav:
             # using pyav for videos in evaluation dimension 12
+            av = importlib.import_module('av')
             reader = av.open(data_path)
             frames = [
                 torch.from_numpy(f.to_rgb().to_ndarray())

@@ -2,4 +2,7 @@ from opencompass.utils import satisfy_requirement
 if satisfy_requirement('salesforce-lavis'):
     from .instructblip import * # noqa: F401, F403
+from .llava import * # noqa: F401, F403
 from .minigpt_4 import * # noqa: F401, F403
+from .visualglm import * # noqa: F401, F403

@@ -0,0 +1,3 @@
from .llava import LLaVA
__all__ = ['LLaVA']

@@ -0,0 +1,145 @@
import importlib
import os
import sys
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteria
from opencompass.registry import MM_MODELS
from .prompt_constructor import LLaVAMMBenchPromptConstructor
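# Sentinel id that LLaVA's tokenizer_image_token uses to mark where image
# features are spliced into the text token sequence.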
IMAGE_TOKEN_INDEX = -200
def load_package():
"""Load required packages from LLaVA."""
current_file_path = os.path.abspath(__file__)
current_folder_path = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_folder_path, 'LLaVA')) # noqa
return
class KeywordsStoppingCriteria(StoppingCriteria):
"""Keyword stopping criteria implemented for llava."""
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
**kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
outputs = self.tokenizer.batch_decode(output_ids[:,
self.start_len:],
skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
@MM_MODELS.register_module('llava')
class LLaVA(nn.Module):
"""Inference code of LLaVA. Need to clone LLaVA official repo first. Please
check out the README in config.
Args:
model_path (str): The path of llava checkpoint.
"""
def __init__(self, model_path: str) -> None:
super().__init__()
self.dtype = torch.float16
# load LLaVA modules
load_package()
mm_utils = importlib.import_module('llava.mm_utils')
builder = importlib.import_module('llava.model.builder')
conversation = importlib.import_module('llava.conversation')
self.SeparatorStyle = conversation.SeparatorStyle
self.conv_templates = conversation.conv_templates
# load pretrained LLaVA
        # Note: when encountering device-related errors, try setting
        # `low_cpu_mem_usage` to False inside `load_pretrained_model`.
model_name = mm_utils.get_model_name_from_path(model_path)
tokenizer, model, _, _ = builder.load_pretrained_model(
model_path, None, model_name)
vision_tower = model.get_vision_tower()
vision_tower.to(device=get_device(), dtype=self.dtype)
model.to(device=get_device(), dtype=self.dtype)
# load prompt constructor and post processor
if 'v1' in model_path.lower():
conv_mode = 'llava_v1'
elif 'mpt' in model_path.lower():
conv_mode = 'mpt_multimodal'
else:
conv_mode = 'multimodal'
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
False)
self.model = model
self.tokenizer = tokenizer
self.prompt_constructor = LLaVAMMBenchPromptConstructor(
conv_templates=conversation.conv_templates,
conv_mode=conv_mode,
mm_use_im_start_end=mm_use_im_start_end)
def generate(self, batch):
prompt, stop_str = self.prompt_constructor(batch)
keywords = [stop_str]
data_sample = batch['data_samples'][0]
image = batch['inputs'][0].unsqueeze(0)
if image is not None:
images = image.to(get_device())
else:
images = None
mm_utils = importlib.import_module('llava.mm_utils')
input_ids = mm_utils.tokenizer_image_token(
prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
return_tensors='pt').unsqueeze(0).to(get_device())
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
input_ids)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=images.half(),
do_sample=True,
temperature=0.2,
max_new_tokens=1024,
stopping_criteria=[stopping_criteria],
)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids !=
output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(
f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids' # noqa
)
outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
output_text = outputs.strip()
data_sample.pred_answer = output_text
return data_sample
def forward(self, batch):
return self.generate(batch)
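
For reference, `generate` consumes the pseudo-collated batches produced by the MMBench dataloader configured earlier: image tensors under `'inputs'` and data samples carrying the packed `algorithm_keys` under `'data_samples'`. An illustrative sketch of that layout (plain dicts stand in for the mmengine data samples; a real run needs a loaded checkpoint):

```python
import torch

# One MMBench sample, shaped the way LLaVA.generate expects it.
dummy_batch = {
    'inputs': [torch.zeros(3, 224, 224)],   # transformed image tensor
    'data_samples': [                        # packed question metadata
        dict(question='What is shown in the image?',
             options='A. cat\nB. dog',
             context=None),
    ],
}
```

Note that decoding uses sampling (`do_sample=True`, `temperature=0.2`), so repeated runs may produce slightly different answers.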

@@ -0,0 +1,59 @@
import importlib
from typing import Any
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'
class LLaVAMMBenchPromptConstructor:
"""Prompt constructor for LLaVA on MMBench.
Args:
conv_templates (Any): Conversation class to build prompt.
conv_mode (str): Version control args for different version of LLaVA.
mm_use_im_start_end (bool):
Config arg. Use start and end token when build prompt or not.
"""
def __init__(self, conv_templates: Any, conv_mode: str,
mm_use_im_start_end: bool) -> None:
self.conv_templates = conv_templates
self.conv_mode = conv_mode
self.mm_use_im_start_end = mm_use_im_start_end
conversation = importlib.import_module('llava.conversation')
self.SeparatorStyle = conversation.SeparatorStyle
def __call__(self, inputs: dict) -> tuple:
"""Construct prompt.
Args:
inputs (dict): Input data containing images and data_samples.
Returns:
            tuple: A tuple of the constructed prompt string and the stop string.
"""
data_samples = inputs['data_samples']
assert len(data_samples) == 1
question = data_samples[0].get('question')
options = data_samples[0].get('options')
context = data_samples[0].get('context')
if context is not None:
prompt = context + ' ' + question + ' ' + options
else:
prompt = question + ' ' + options
if self.mm_use_im_start_end:
prompt = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
DEFAULT_IM_END_TOKEN + '\n' + prompt)
else:
prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt # noqa
conv = self.conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
output_prompt = conv.get_prompt()
stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2 # noqa
return output_prompt, stop_str

@@ -0,0 +1,5 @@
from .post_processor import VisualGLMPostProcessor
from .prompt_constructor import VisualGLMPromptConstructor
from .visualglm import VisualGLM
__all__ = ['VisualGLM', 'VisualGLMPostProcessor', 'VisualGLMPromptConstructor']

@@ -0,0 +1,14 @@
from typing import Any
import torch
class VisualGLMPostProcessor:
""""Post processor for VisualGLM on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer: Any,
input_len: int) -> str:
return tokenizer.decode(output_token[input_len:])
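
A tiny self-contained check of the behaviour: only the tokens after `input_len` (the prompt portion) are decoded into the predicted answer. The toy tokenizer below is a stand-in that implements just the single method used here:

```python
import torch

from opencompass.multimodal.models.visualglm import VisualGLMPostProcessor


class ToyTokenizer:
    """Stand-in exposing only the decode() call the post processor needs."""

    def decode(self, ids):
        return ' '.join(str(int(i)) for i in ids)


output_token = torch.tensor([11, 12, 13, 14, 15])
# Tokens before input_len=3 are treated as the prompt and dropped.
print(VisualGLMPostProcessor()(output_token, ToyTokenizer(), input_len=3))  # '14 15'
```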

@@ -0,0 +1,55 @@
import torch
class VisualGLMPromptConstructor:
"""Prompt constructor for VisualGLM.
    The overall prompt will be formulated as
    system_prompt + image_prompt + human_prompt + context + question + options
    + assistant_prompt.
Args:
system_prompt (str): System prompt. (Default: '')
human_prompt (str): Human prompt. (Default: 'Q:')
image_prompt (str): Image prompt. (Default: '<img></img>')
assistant_prompt (str): Assistant prompt. (Default: 'A:')
"""
def __init__(self,
system_prompt: str = '',
human_prompt: str = 'Q:',
image_prompt: str = '<img></img>',
assistant_prompt: str = 'A:') -> None:
self.image_prompt = image_prompt
self.system_prompt = system_prompt
self.human_prompt = human_prompt
self.assistant_prompt = assistant_prompt
def __call__(self, batch: dict) -> tuple:
"""Construct prompt.
Args:
batch (dict): Input data containing image and data_samples.
Returns:
            tuple: A tuple of images, prompt, data_samples and image_position.
"""
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
data_samples = batch.pop('data_samples')
questions = [sample.get('question') for sample in data_samples]
options = [sample.get('options') for sample in data_samples]
contexts = [sample.get('context') for sample in data_samples]
contexts = [c if c else '' for c in contexts]
# generate text prompt
prompt = [
'{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
self.human_prompt, context, question,
option, self.assistant_prompt)
for context, question, option in zip(contexts, questions, options)
]
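        # image placeholder tokens are inserted after the first 5 characters of
        # the prompt, i.e. right after '<img>' with the default empty system_prompt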
image_position = 5
return images, prompt, data_samples, image_position
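
As a quick sanity check of the formatting above, the constructor can be exercised on a toy batch; plain dicts stand in for the mmengine data samples since only `.get()` is used:

```python
import torch

from opencompass.multimodal.models.visualglm import VisualGLMPromptConstructor

batch = {
    'inputs': [torch.zeros(3, 224, 224)],
    'data_samples': [dict(question='What is shown?', options='A. cat B. dog',
                          context=None)],
}
images, prompts, samples, image_position = VisualGLMPromptConstructor()(batch)
print(prompts[0])  # -> '<img></img>Q:What is shown?A. cat B. dogA:'
```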

@@ -0,0 +1,98 @@
from typing import Optional
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModel, AutoTokenizer
from opencompass.registry import MM_MODELS
@MM_MODELS.register_module('visualglm')
class VisualGLM(nn.Module):
"""Inference code of VisualGLM.
We load the visualGLM model via Huggingface.
Args:
pretrained_path (str): Path to visualGLM checkpoint or repo id.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
gen_kwargs (dict): Customize generate function arguments.
"""
def __init__(self,
pretrained_path: str,
prompt_constructor: dict,
post_processor: dict,
gen_kwargs: Optional[dict] = None) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
trust_remote_code=True)
self.model = AutoModel.from_pretrained(pretrained_path,
trust_remote_code=True).half()
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
if gen_kwargs:
self.gen_kwargs = gen_kwargs
else:
self.gen_kwargs = dict()
def encode_by_tokenizer(self, multi_prompts, image_position):
input_ids = []
max_seq_length = 0
for prompt in multi_prompts:
input0 = self.tokenizer.encode(prompt[:image_position],
add_special_tokens=False)
input1 = [self.tokenizer.pad_token_id] * self.model.image_length
input2 = self.tokenizer.encode(prompt[image_position:],
add_special_tokens=False)
input_all = sum([input0, input1, input2], [])
input_all = self.tokenizer.build_inputs_with_special_tokens(
input_all)
max_seq_length = max(max_seq_length, len(input_all))
input_ids.append(input_all)
pre_image_len = len(input0)
# padding
for i, _ in enumerate(input_ids):
pad_len = max_seq_length - len(input_ids[i])
input_ids[i] = [self.tokenizer.pad_token_id
] * pad_len + input_ids[i]
return input_ids, pre_image_len
def generate(self, batch):
# process input
image, prompt, data_sample, image_position = self.prompt_constructor(
batch)
image = image.to(self.model.dtype).to(get_device())
# tokenize
input_all, pre_image_len = self.encode_by_tokenizer(
prompt, image_position)
input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
# build input param
inputs = {
'input_ids': input_all,
'pre_image_length': pre_image_len,
'images': image
}
# generate answer
outputs = self.model.generate(**inputs, **self.gen_kwargs)
# format output
outputs = outputs.tolist()
for i, sample in enumerate(data_sample):
data_sample[i].pred_answer = self.post_processor(
outputs[i], self.tokenizer, input_all.shape[1])
return data_sample
def forward(self, batch):
return self.generate(batch)
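
For completeness, the `@MM_MODELS.register_module('visualglm')` decorator above is what lets the config build this class from the string name `'visualglm'`; `build_from_cfg` also accepts classes themselves as `type`, which is how `prompt_constructor` and `post_processor` are specified in the config. A minimal, hedged build sketch with a placeholder checkpoint path:

```python
from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor,
                                                     VisualGLMPromptConstructor)
from opencompass.registry import MM_MODELS

# '/path/to/visualglm' is a placeholder; loading requires the actual weights.
visualglm = MM_MODELS.build(
    dict(type='visualglm',
         pretrained_path='/path/to/visualglm',
         prompt_constructor=dict(type=VisualGLMPromptConstructor),
         post_processor=dict(type=VisualGLMPostProcessor)))
```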