[Feat] Support visualglm and llava for MMBench evaluation. (#211)

* [Feat] Support visualglm inference on MMBench.

* [Feat] Support llava inference on MMBench.

* [Fix] Fix pre-commit format.

* [Fix] Add docstring for llava

* [Fix] Fix multi-process inference error of LLaVA and add comments.
1. Set `low_cpu_mem_usage` to False to address device issue.
2. Add docstring and type hints.
3. Rename class and remove registry.

* [Fix] Pre-commit fix.

* [Fix] Add forward entry and dynamic import to SEEDBench.

* [Fix] Fix pre-commit.

* [Fix] Fix missing context.

* [Fix] Fix docstring.
Yike Yuan 2023-08-21 15:57:30 +08:00 committed by GitHub
parent a6552224cb
commit 8d368d1cd6
12 changed files with 478 additions and 1 deletion


@@ -0,0 +1,10 @@
# LLaVA
### Prepare the environment
```sh
cd opencompass/multimodal/models/llava
git clone https://github.com/haotian-liu/LLaVA.git
```
Then prepare the environment according to the [installation instructions](https://github.com/haotian-liu/LLaVA/tree/main#install).
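
For reference, a minimal sketch of what that setup typically looks like; the environment name and Python version below are assumptions, and the linked instructions are authoritative:

```sh
# Illustrative only -- follow the linked install instructions for details.
conda create -n llava python=3.10 -y
conda activate llava
cd LLaVA
pip install --upgrade pip
pip install -e .  # install LLaVA and its dependencies in editable mode
```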


@@ -0,0 +1,43 @@
# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(
        type='mmpretrain.torchvision/Normalize',
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=[
            'question', 'category', 'l2-category', 'context', 'index',
            'options_dict', 'options', 'split'
        ],
    ),
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

mmbench_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

# model settings
llava_model = dict(
    type='llava',
    model_path='/path/to/llava',
)  # noqa

# evaluation settings
mmbench_evaluator = [
    dict(type='opencompass.DumpResults',
         save_path='work_dirs/llava-7b-mmbench.xlsx')
]
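
With this config in place, the evaluation is launched through OpenCompass's entry point. The command below is a hedged sketch: the config path is illustrative, and the exact multimodal launch flags may differ by version, so check the OpenCompass docs.

```sh
# Assumed invocation; consult the OpenCompass docs for the multimodal runner.
python run.py configs/multimodal/llava/llava-7b-mmbench.py
```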


@@ -0,0 +1,41 @@
from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor,
                                                     VisualGLMPromptConstructor)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'options', 'category', 'l2-category', 'context',
             'index', 'options_dict'
         ])
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

mmbench_dataloader = dict(batch_size=1,
                          num_workers=4,
                          dataset=dataset,
                          collate_fn=dict(type='pseudo_collate'),
                          sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
visualglm_model = dict(
    type='visualglm',
    pretrained_path='/path/to/visualglm',  # or Huggingface repo id
    prompt_constructor=dict(type=VisualGLMPromptConstructor),
    post_processor=dict(type=VisualGLMPostProcessor)
)

# evaluation settings
mmbench_evaluator = [
    dict(type='opencompass.DumpResults',
         save_path='work_dirs/visualglm-6b-mmbench.xlsx')
]
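
As a hedged smoke-test sketch (not part of this PR), the wrapper can also be built directly; `THUDM/visualglm-6b` is the public Hugging Face repo id, and the `gen_kwargs` value here is an illustrative choice:

```python
from opencompass.multimodal.models.visualglm import (
    VisualGLM, VisualGLMPostProcessor, VisualGLMPromptConstructor)

# Build the wrapper directly instead of through the config/registry machinery.
model = VisualGLM(
    pretrained_path='THUDM/visualglm-6b',  # or a local checkpoint path
    prompt_constructor=dict(type=VisualGLMPromptConstructor),
    post_processor=dict(type=VisualGLMPostProcessor),
    gen_kwargs=dict(max_new_tokens=64),  # illustrative generation setting
)
```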


@@ -1,8 +1,8 @@
+import importlib
 import json
 import os.path as osp
 from typing import List

-import av
 import numpy as np
 import torch
 from decord import VideoReader, cpu
@@ -116,6 +116,7 @@ class SEEDBenchDataset(Dataset):
         if use_pyav:
             # using pyav for videos in evaluation dimension 12
+            av = importlib.import_module('av')
             reader = av.open(data_path)
             frames = [
                 torch.from_numpy(f.to_rgb().to_ndarray())


@@ -2,4 +2,7 @@ from opencompass.utils import satisfy_requirement

if satisfy_requirement('salesforce-lavis'):
    from .instructblip import *  # noqa: F401, F403

from .llava import *  # noqa: F401, F403
from .minigpt_4 import *  # noqa: F401, F403
from .visualglm import *  # noqa: F401, F403


@@ -0,0 +1,3 @@
from .llava import LLaVA

__all__ = ['LLaVA']


@@ -0,0 +1,145 @@
import importlib
import os
import sys

import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteria

from opencompass.registry import MM_MODELS

from .prompt_constructor import LLaVAMMBenchPromptConstructor

IMAGE_TOKEN_INDEX = -200


def load_package():
    """Load required packages from LLaVA."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)

    sys.path.append(os.path.join(current_folder_path, 'LLaVA'))  # noqa

    return


class KeywordsStoppingCriteria(StoppingCriteria):
    """Keyword stopping criteria implemented for llava."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
                 **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:],
                                                  skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


@MM_MODELS.register_module('llava')
class LLaVA(nn.Module):
    """Inference code of LLaVA. Need to clone the LLaVA official repo first.
    Please check out the README in config.

    Args:
        model_path (str): The path of the llava checkpoint.
    """

    def __init__(self, model_path: str) -> None:
        super().__init__()
        self.dtype = torch.float16

        # load LLaVA modules
        load_package()
        mm_utils = importlib.import_module('llava.mm_utils')
        builder = importlib.import_module('llava.model.builder')
        conversation = importlib.import_module('llava.conversation')
        self.SeparatorStyle = conversation.SeparatorStyle
        self.conv_templates = conversation.conv_templates

        # load pretrained LLaVA
        # Note: When encountering device-related errors, try setting
        # `low_cpu_mem_usage` in `load_pretrained_model` to False
        model_name = mm_utils.get_model_name_from_path(model_path)
        tokenizer, model, _, _ = builder.load_pretrained_model(
            model_path, None, model_name)
        vision_tower = model.get_vision_tower()
        vision_tower.to(device=get_device(), dtype=self.dtype)
        model.to(device=get_device(), dtype=self.dtype)

        # load prompt constructor and post processor
        if 'v1' in model_path.lower():
            conv_mode = 'llava_v1'
        elif 'mpt' in model_path.lower():
            conv_mode = 'mpt_multimodal'
        else:
            conv_mode = 'multimodal'
        mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
                                      False)

        self.model = model
        self.tokenizer = tokenizer
        self.prompt_constructor = LLaVAMMBenchPromptConstructor(
            conv_templates=conversation.conv_templates,
            conv_mode=conv_mode,
            mm_use_im_start_end=mm_use_im_start_end)

    def generate(self, batch):
        prompt, stop_str = self.prompt_constructor(batch)
        keywords = [stop_str]
        data_sample = batch['data_samples'][0]

        image = batch['inputs'][0].unsqueeze(0)
        if image is not None:
            images = image.to(get_device())
        else:
            images = None

        mm_utils = importlib.import_module('llava.mm_utils')

        input_ids = mm_utils.tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
            return_tensors='pt').unsqueeze(0).to(get_device())

        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
                                                     input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images.half(),
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria],
            )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids !=
                               output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(
                f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'  # noqa
            )
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
                                              skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        output_text = outputs.strip()

        data_sample.pred_answer = output_text
        return data_sample

    def forward(self, batch):
        return self.generate(batch)
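
For orientation, a hedged sketch of the batch structure `generate` expects, mirroring what the MMBench dataloader above produces; the checkpoint path and field values are made up:

```python
import torch
from mmengine.structures import BaseDataElement
from opencompass.multimodal.models.llava import LLaVA

model = LLaVA(model_path='/path/to/llava')  # hypothetical checkpoint path

sample = BaseDataElement()
sample.question = 'What is shown in the image?'
sample.options = 'A. a cat B. a dog'

batch = {
    'inputs': [torch.rand(3, 224, 224)],  # one CHW image from the pipeline
    'data_samples': [sample],
}
pred = model.forward(batch)  # returns the sample with `pred_answer` set
```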


@@ -0,0 +1,59 @@
import importlib
from typing import Any

DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'


class LLaVAMMBenchPromptConstructor:
    """Prompt constructor for LLaVA on MMBench.

    Args:
        conv_templates (Any): Conversation class to build the prompt.
        conv_mode (str): Version control arg for different versions of LLaVA.
        mm_use_im_start_end (bool):
            Config arg. Whether to wrap the image token with start and end
            tokens when building the prompt.
    """

    def __init__(self, conv_templates: Any, conv_mode: str,
                 mm_use_im_start_end: bool) -> None:
        self.conv_templates = conv_templates
        self.conv_mode = conv_mode
        self.mm_use_im_start_end = mm_use_im_start_end
        conversation = importlib.import_module('llava.conversation')
        self.SeparatorStyle = conversation.SeparatorStyle

    def __call__(self, inputs: dict) -> tuple:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing images and data_samples.

        Returns:
            tuple: A tuple containing the prompt and the stop string.
        """
        data_samples = inputs['data_samples']
        assert len(data_samples) == 1
        question = data_samples[0].get('question')
        options = data_samples[0].get('options')
        context = data_samples[0].get('context')
        if context is not None:
            prompt = context + ' ' + question + ' ' + options
        else:
            prompt = question + ' ' + options
        if self.mm_use_im_start_end:
            prompt = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
                      DEFAULT_IM_END_TOKEN + '\n' + prompt)
        else:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt  # noqa

        conv = self.conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        output_prompt = conv.get_prompt()

        stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2  # noqa

        return output_prompt, stop_str
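
As a rough illustration: with `conv_mode='llava_v1'` and `mm_use_im_start_end=False`, the returned prompt takes roughly the following shape (template text paraphrased from the LLaVA repo), and `stop_str` is the template's second separator, `'</s>'`:

```text
{llava_v1 system message} USER: <image>
{context} {question} {options} ASSISTANT:
```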


@@ -0,0 +1,5 @@
from .post_processor import VisualGLMPostProcessor
from .prompt_constructor import VisualGLMPromptConstructor
from .visualglm import VisualGLM

__all__ = ['VisualGLM', 'VisualGLMPostProcessor', 'VisualGLMPromptConstructor']


@@ -0,0 +1,14 @@
from typing import Any

import torch


class VisualGLMPostProcessor:
    """Post processor for VisualGLM on MMBench."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.tensor, tokenizer: Any,
                 input_len: int) -> str:
        return tokenizer.decode(output_token[input_len:])


@@ -0,0 +1,55 @@
import torch


class VisualGLMPromptConstructor:
    """Prompt constructor for VisualGLM.

    The overall prompt will be formulated as
    "system_prompt" + "image_prompt" + "human_prompt" + context + question +
    options + "assistant_prompt".

    Args:
        system_prompt (str): System prompt. (Default: '')
        human_prompt (str): Human prompt. (Default: 'Q:')
        image_prompt (str): Image prompt. (Default: '<img></img>')
        assistant_prompt (str): Assistant prompt. (Default: 'A:')
    """

    def __init__(self,
                 system_prompt: str = '',
                 human_prompt: str = 'Q:',
                 image_prompt: str = '<img></img>',
                 assistant_prompt: str = 'A:') -> None:
        self.image_prompt = image_prompt
        self.system_prompt = system_prompt
        self.human_prompt = human_prompt
        self.assistant_prompt = assistant_prompt

    def __call__(self, batch: dict) -> tuple:
        """Construct prompt.

        Args:
            batch (dict): Input data containing image and data_samples.

        Returns:
            tuple: A tuple containing images, prompt, data_samples and
                image_position.
        """
        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)

        data_samples = batch.pop('data_samples')
        questions = [sample.get('question') for sample in data_samples]
        options = [sample.get('options') for sample in data_samples]
        contexts = [sample.get('context') for sample in data_samples]
        contexts = [c if c else '' for c in contexts]

        # generate text prompt
        prompt = [
            '{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
                                    self.human_prompt, context, question,
                                    option, self.assistant_prompt)
            for context, question, option in zip(contexts, questions, options)
        ]
        image_position = 5

        return images, prompt, data_samples, image_position
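
Concretely, with the defaults and a made-up sample, each entry of `prompt` looks like the string below; `image_position = 5` points just past the literal `'<img>'` (five characters), which is where `encode_by_tokenizer` in the model wrapper splices in the image placeholder tokens:

```python
# Hedged illustration with the default prompts:
# system_prompt='' + image_prompt='<img></img>' + human_prompt='Q:'
#   + context + question + options + assistant_prompt='A:'
prompt = '<img></img>Q:Some context. What is shown? A. a cat B. a dog A:'
image_position = 5  # len('<img>'); image tokens go between <img> and </img>
```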


@@ -0,0 +1,98 @@
from typing import Optional

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModel, AutoTokenizer

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('visualglm')
class VisualGLM(nn.Module):
    """Inference code of VisualGLM.

    We load the VisualGLM model via Huggingface.

    Args:
        pretrained_path (str): Path to VisualGLM checkpoint or repo id.
        prompt_constructor (dict): The config of prompt constructor.
        post_processor (dict): The config of post processor.
        gen_kwargs (dict): Customized generate function arguments.
    """

    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: dict,
                 post_processor: dict,
                 gen_kwargs: Optional[dict] = None) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
                                                       trust_remote_code=True)
        self.model = AutoModel.from_pretrained(pretrained_path,
                                               trust_remote_code=True).half()
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)

        if gen_kwargs:
            self.gen_kwargs = gen_kwargs
        else:
            self.gen_kwargs = dict()

    def encode_by_tokenizer(self, multi_prompts, image_position):
        input_ids = []
        max_seq_length = 0
        for prompt in multi_prompts:
            # split the prompt around the image slot and reserve
            # `image_length` pad tokens for the image embeddings
            input0 = self.tokenizer.encode(prompt[:image_position],
                                           add_special_tokens=False)
            input1 = [self.tokenizer.pad_token_id] * self.model.image_length
            input2 = self.tokenizer.encode(prompt[image_position:],
                                           add_special_tokens=False)
            input_all = sum([input0, input1, input2], [])
            input_all = self.tokenizer.build_inputs_with_special_tokens(
                input_all)
            max_seq_length = max(max_seq_length, len(input_all))
            input_ids.append(input_all)
            pre_image_len = len(input0)

        # left-pad every sequence to the longest one in the batch
        for i, _ in enumerate(input_ids):
            pad_len = max_seq_length - len(input_ids[i])
            input_ids[i] = [self.tokenizer.pad_token_id] * pad_len + input_ids[i]

        return input_ids, pre_image_len

    def generate(self, batch):
        # process input
        image, prompt, data_sample, image_position = self.prompt_constructor(
            batch)
        image = image.to(self.model.dtype).to(get_device())

        # tokenize
        input_all, pre_image_len = self.encode_by_tokenizer(
            prompt, image_position)
        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())

        # build input param
        inputs = {
            'input_ids': input_all,
            'pre_image_length': pre_image_len,
            'images': image
        }

        # generate answer
        outputs = self.model.generate(**inputs, **self.gen_kwargs)

        # format output
        outputs = outputs.tolist()
        for i, sample in enumerate(data_sample):
            data_sample[i].pred_answer = self.post_processor(
                outputs[i], self.tokenizer, input_all.shape[1])

        return data_sample

    def forward(self, batch):
        return self.generate(batch)