mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] Support visualglm and llava for MMBench evaluation. (#211)
* [Feat] Support visualglm inference on MMBench.
* [Feat] Support llava inference on MMBench.
* [Fix] Fix pre-commit format.
* [Fix] Add docstring for llava.
* [Fix] Fix multi-process inference error of LLaVA and add comments.
  1. Set `low_cpu_mem_usage` to False to address device issue.
  2. Add docstring and type hints.
  3. Rename class and remove registry.
* [Fix] Pre-commit fix.
* [Fix] Add forward entry, add dynamic import to seedbench.
* [Fix] Fix pre-commit.
* [Fix] Fix missing context.
* [Fix] Fix docstring.

This commit is contained in:
parent a6552224cb
commit 8d368d1cd6
10 configs/multimodal/llava/README.md Normal file
@@ -0,0 +1,10 @@
# LLaVA

### Prepare the environment

```sh
cd opencompass/multimodal/models/llava
git clone https://github.com/haotian-liu/LLaVA.git
```

Then prepare the environment according to the [installation instructions](https://github.com/haotian-liu/LLaVA/tree/main#install).
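After cloning, the remaining setup typically installs the cloned package into the current environment. This is only a sketch; the linked installation instructions are authoritative, and the editable-install command below is an assumption about them:

```sh
cd LLaVA          # the repository cloned in the step above
pip install -e .  # assumed editable install of LLaVA and its dependencies
```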
43 configs/multimodal/llava/llava_7b_mmbench.py Normal file
@@ -0,0 +1,43 @@
# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(
        type='mmpretrain.torchvision/Normalize',
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=[
            'question', 'category', 'l2-category', 'context', 'index',
            'options_dict', 'options', 'split'
        ],
    ),
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

mmbench_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type='pseudo_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

# model settings
llava_model = dict(
    type='llava',
    model_path='/path/to/llava',
)  # noqa

# evaluation settings
mmbench_evaluator = [
    dict(type='opencompass.DumpResults',
         save_path='work_dirs/llava-7b-mmbench.xlsx')
]
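For reference, the `val_pipeline` above corresponds roughly to the plain torchvision preprocessing below (a sketch for illustration only; the evaluation itself builds these transforms through the mmpretrain registry, and `example.jpg` is a placeholder path):

```python
from PIL import Image
from torchvision import transforms

# CLIP-style preprocessing mirroring the mmpretrain.torchvision/* entries above
preprocess = transforms.Compose([
    transforms.Resize((224, 224),
                      interpolation=transforms.InterpolationMode.BICUBIC),  # interpolation=3
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

image = preprocess(Image.open('example.jpg').convert('RGB'))  # float tensor of shape (3, 224, 224)
```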
41 configs/multimodal/visualglm/visualglm_6b_mmbench.py Normal file
@@ -0,0 +1,41 @@
from opencompass.multimodal.models.visualglm import (VisualGLMPostProcessor, VisualGLMPromptConstructor)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'options', 'category', 'l2-category', 'context',
             'index', 'options_dict'
         ])
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

mmbench_dataloader = dict(batch_size=1,
                          num_workers=4,
                          dataset=dataset,
                          collate_fn=dict(type='pseudo_collate'),
                          sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
visualglm_model = dict(
    type='visualglm',
    pretrained_path='/path/to/visualglm',  # or Huggingface repo id
    prompt_constructor=dict(type=VisualGLMPromptConstructor),
    post_processor=dict(type=VisualGLMPostProcessor)
)

# evaluation settings
mmbench_evaluator = [
    dict(type='opencompass.DumpResults',
         save_path='work_dirs/visualglm-6b-mmbench.xlsx')
]
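Either config can be loaded and inspected with mmengine before running an evaluation. A minimal sketch, assuming OpenCompass and mmengine are installed and the repository root is the working directory (the printed values are the placeholders from the config above):

```python
from mmengine.config import Config

cfg = Config.fromfile('configs/multimodal/visualglm/visualglm_6b_mmbench.py')

print(cfg.visualglm_model['pretrained_path'])  # '/path/to/visualglm' -- replace with a real checkpoint or repo id
print(cfg.mmbench_dataloader['batch_size'])    # 1
print(cfg.mmbench_evaluator[0]['save_path'])   # 'work_dirs/visualglm-6b-mmbench.xlsx'
```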
@@ -1,8 +1,8 @@
import importlib
import json
import os.path as osp
from typing import List

import av
import numpy as np
import torch
from decord import VideoReader, cpu
@@ -116,6 +116,7 @@ class SEEDBenchDataset(Dataset):

        if use_pyav:
            # using pyav for videos in evaluation dimension 12
            av = importlib.import_module('av')
            reader = av.open(data_path)
            frames = [
                torch.from_numpy(f.to_rgb().to_ndarray())
@@ -2,4 +2,7 @@ from opencompass.utils import satisfy_requirement

if satisfy_requirement('salesforce-lavis'):
    from .instructblip import *  # noqa: F401, F403

from .llava import *  # noqa: F401, F403
from .minigpt_4 import *  # noqa: F401, F403
from .visualglm import *  # noqa: F401, F403
3 opencompass/multimodal/models/llava/__init__.py Normal file
@@ -0,0 +1,3 @@
from .llava import LLaVA

__all__ = ['LLaVA']
145 opencompass/multimodal/models/llava/llava.py Normal file
@@ -0,0 +1,145 @@
import importlib
import os
import sys

import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteria

from opencompass.registry import MM_MODELS

from .prompt_constructor import LLaVAMMBenchPromptConstructor

IMAGE_TOKEN_INDEX = -200


def load_package():
    """Load required packages from LLaVA."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)

    sys.path.append(os.path.join(current_folder_path, 'LLaVA'))  # noqa
    return


class KeywordsStoppingCriteria(StoppingCriteria):
    """Keyword stopping criteria implemented for llava."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
                 **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            outputs = self.tokenizer.batch_decode(output_ids[:,
                                                             self.start_len:],
                                                  skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


@MM_MODELS.register_module('llava')
class LLaVA(nn.Module):
    """Inference code of LLaVA. Need to clone LLaVA official repo first. Please
    check out the README in config.

    Args:
        model_path (str): The path of llava checkpoint.
    """

    def __init__(self, model_path: str) -> None:
        super().__init__()
        self.dtype = torch.float16

        # load LLaVA modules
        load_package()
        mm_utils = importlib.import_module('llava.mm_utils')
        builder = importlib.import_module('llava.model.builder')
        conversation = importlib.import_module('llava.conversation')
        self.SeparatorStyle = conversation.SeparatorStyle
        self.conv_templates = conversation.conv_templates

        # load pretrained LLaVA
        # Note: when encountering device-related errors,
        # try setting `low_cpu_mem_usage` in `load_pretrained_model` to False
        model_name = mm_utils.get_model_name_from_path(model_path)
        tokenizer, model, _, _ = builder.load_pretrained_model(
            model_path, None, model_name)
        vision_tower = model.get_vision_tower()
        vision_tower.to(device=get_device(), dtype=self.dtype)
        model.to(device=get_device(), dtype=self.dtype)

        # load prompt constructor and post processor
        if 'v1' in model_path.lower():
            conv_mode = 'llava_v1'
        elif 'mpt' in model_path.lower():
            conv_mode = 'mpt_multimodal'
        else:
            conv_mode = 'multimodal'
        mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
                                      False)

        self.model = model
        self.tokenizer = tokenizer
        self.prompt_constructor = LLaVAMMBenchPromptConstructor(
            conv_templates=conversation.conv_templates,
            conv_mode=conv_mode,
            mm_use_im_start_end=mm_use_im_start_end)

    def generate(self, batch):

        prompt, stop_str = self.prompt_constructor(batch)
        keywords = [stop_str]
        data_sample = batch['data_samples'][0]

        image = batch['inputs'][0].unsqueeze(0)
        if image is not None:
            images = image.to(get_device())
        else:
            images = None

        mm_utils = importlib.import_module('llava.mm_utils')
        input_ids = mm_utils.tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
            return_tensors='pt').unsqueeze(0).to(get_device())

        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
                                                     input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images.half(),
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria],
            )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids !=
                               output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(
                f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'  # noqa
            )
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
                                              skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        output_text = outputs.strip()

        data_sample.pred_answer = output_text
        return data_sample

    def forward(self, batch):
        return self.generate(batch)
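A usage sketch for the class above, assuming the LLaVA repository has been cloned as described in the README, a compatible checkpoint exists at the given path, and a GPU is available:

```python
from opencompass.registry import MM_MODELS

# Same dict as `llava_model` in configs/multimodal/llava/llava_7b_mmbench.py.
llava_cfg = dict(type='llava', model_path='/path/to/llava')
model = MM_MODELS.build(llava_cfg)  # instantiates LLaVA(model_path='/path/to/llava')

# `batch` would come from `mmbench_dataloader`; `forward` (i.e. `generate`) returns
# the data sample with `pred_answer` filled in:
# data_sample = model.forward(batch)
```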
59 opencompass/multimodal/models/llava/prompt_constructor.py Normal file
@@ -0,0 +1,59 @@
import importlib
from typing import Any

DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'


class LLaVAMMBenchPromptConstructor:
    """Prompt constructor for LLaVA on MMBench.

    Args:
        conv_templates (Any): Conversation class to build prompt.
        conv_mode (str): Version control arg for different versions of LLaVA.
        mm_use_im_start_end (bool):
            Config arg. Whether to wrap the image token with start and end
            tokens when building the prompt.
    """

    def __init__(self, conv_templates: Any, conv_mode: str,
                 mm_use_im_start_end: bool) -> None:
        self.conv_templates = conv_templates
        self.conv_mode = conv_mode
        self.mm_use_im_start_end = mm_use_im_start_end
        conversation = importlib.import_module('llava.conversation')
        self.SeparatorStyle = conversation.SeparatorStyle

    def __call__(self, inputs: dict) -> tuple:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing images and data_samples.

        Returns:
            tuple: A tuple containing the prompt and the stop string.
        """
        data_samples = inputs['data_samples']
        assert len(data_samples) == 1
        question = data_samples[0].get('question')
        options = data_samples[0].get('options')
        context = data_samples[0].get('context')
        if context is not None:
            prompt = context + ' ' + question + ' ' + options
        else:
            prompt = question + ' ' + options
        if self.mm_use_im_start_end:
            prompt = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
                      DEFAULT_IM_END_TOKEN + '\n' + prompt)
        else:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt  # noqa

        conv = self.conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        output_prompt = conv.get_prompt()

        stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2  # noqa

        return output_prompt, stop_str
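To make the construction concrete, this is the text that gets wrapped into the conversation template for one MMBench sample (an illustration only; the sample values are made up, and the real class additionally needs the cloned `llava.conversation` module for the template and stop string):

```python
DEFAULT_IMAGE_TOKEN = '<image>'

question = 'What is shown in the image?'
options = 'A. a cat B. a dog C. a bird D. a fish'
context = None

if context is not None:
    prompt = context + ' ' + question + ' ' + options
else:
    prompt = question + ' ' + options
prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt  # mm_use_im_start_end == False branch

print(prompt)
# <image>
# What is shown in the image? A. a cat B. a dog C. a bird D. a fish
```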
5 opencompass/multimodal/models/visualglm/__init__.py Normal file
@@ -0,0 +1,5 @@
from .post_processor import VisualGLMPostProcessor
from .prompt_constructor import VisualGLMPromptConstructor
from .visualglm import VisualGLM

__all__ = ['VisualGLM', 'VisualGLMPostProcessor', 'VisualGLMPromptConstructor']
14 opencompass/multimodal/models/visualglm/post_processor.py Normal file
@@ -0,0 +1,14 @@
from typing import Any

import torch


class VisualGLMPostProcessor:
    """Post processor for VisualGLM on MMBench."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.tensor, tokenizer: Any,
                 input_len: int) -> str:
        return tokenizer.decode(output_token[input_len:])
55 opencompass/multimodal/models/visualglm/prompt_constructor.py Normal file
@@ -0,0 +1,55 @@
import torch


class VisualGLMPromptConstructor:
    """Prompt constructor for VisualGLM.

    The overall prompt is formulated as
    "system_prompt" + "image_prompt" + "human_prompt" + context + question +
    options + "assistant_prompt".

    Args:
        system_prompt (str): System prompt. (Default: '')
        human_prompt (str): Human prompt. (Default: 'Q:')
        image_prompt (str): Image prompt. (Default: '<img></img>')
        assistant_prompt (str): Assistant prompt. (Default: 'A:')
    """

    def __init__(self,
                 system_prompt: str = '',
                 human_prompt: str = 'Q:',
                 image_prompt: str = '<img></img>',
                 assistant_prompt: str = 'A:') -> None:
        self.image_prompt = image_prompt
        self.system_prompt = system_prompt
        self.human_prompt = human_prompt
        self.assistant_prompt = assistant_prompt

    def __call__(self, batch: dict) -> tuple:
        """Construct prompt.

        Args:
            batch (dict): Input data containing image and data_samples.

        Returns:
            tuple: A tuple containing images, prompt, data_samples and the
                image position.
        """

        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)

        data_samples = batch.pop('data_samples')
        questions = [sample.get('question') for sample in data_samples]
        options = [sample.get('options') for sample in data_samples]
        contexts = [sample.get('context') for sample in data_samples]
        contexts = [c if c else '' for c in contexts]

        # generate text prompt
        prompt = [
            '{}{}{}{}{}{}{}'.format(self.system_prompt, self.image_prompt,
                                    self.human_prompt, context, question,
                                    option, self.assistant_prompt)
            for context, question, option in zip(contexts, questions, options)
        ]

        image_position = 5

        return images, prompt, data_samples, image_position
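A quick worked example of the string this constructor produces for a single sample with the default prompts (the sample values are made up for illustration):

```python
system_prompt, image_prompt = '', '<img></img>'
human_prompt, assistant_prompt = 'Q:', 'A:'
context = 'The sign is red. '
question = 'What should the driver do? '
option = 'A. stop B. speed up '

prompt = '{}{}{}{}{}{}{}'.format(system_prompt, image_prompt, human_prompt,
                                 context, question, option, assistant_prompt)
print(prompt)
# <img></img>Q:The sign is red. What should the driver do? A. stop B. speed up A:

# image_position == 5 points just after '<img>', which is where VisualGLM later
# splices in its image tokens (see encode_by_tokenizer below).
```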
98 opencompass/multimodal/models/visualglm/visualglm.py Normal file
@@ -0,0 +1,98 @@
from typing import Optional

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModel, AutoTokenizer

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('visualglm')
class VisualGLM(nn.Module):
    """Inference code of VisualGLM.

    We load the VisualGLM model via Huggingface.

    Args:
        pretrained_path (str): Path to VisualGLM checkpoint or repo id.
        prompt_constructor (dict): The config of prompt constructor.
        post_processor (dict): The config of post processor.
        gen_kwargs (dict): Customized generate function arguments.
    """

    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: dict,
                 post_processor: dict,
                 gen_kwargs: Optional[dict] = None) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
                                                       trust_remote_code=True)
        self.model = AutoModel.from_pretrained(pretrained_path,
                                               trust_remote_code=True).half()
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)

        if gen_kwargs:
            self.gen_kwargs = gen_kwargs
        else:
            self.gen_kwargs = dict()

    def encode_by_tokenizer(self, multi_prompts, image_position):
        input_ids = []
        max_seq_length = 0
        for prompt in multi_prompts:
            input0 = self.tokenizer.encode(prompt[:image_position],
                                           add_special_tokens=False)
            input1 = [self.tokenizer.pad_token_id] * self.model.image_length
            input2 = self.tokenizer.encode(prompt[image_position:],
                                           add_special_tokens=False)
            input_all = sum([input0, input1, input2], [])
            input_all = self.tokenizer.build_inputs_with_special_tokens(
                input_all)
            max_seq_length = max(max_seq_length, len(input_all))
            input_ids.append(input_all)
            pre_image_len = len(input0)

        # padding
        for i, _ in enumerate(input_ids):
            pad_len = max_seq_length - len(input_ids[i])
            input_ids[i] = [self.tokenizer.pad_token_id
                            ] * pad_len + input_ids[i]

        return input_ids, pre_image_len

    def generate(self, batch):
        # process input
        image, prompt, data_sample, image_position = self.prompt_constructor(
            batch)
        image = image.to(self.model.dtype).to(get_device())

        # tokenize
        input_all, pre_image_len = self.encode_by_tokenizer(
            prompt, image_position)

        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())

        # build input param
        inputs = {
            'input_ids': input_all,
            'pre_image_length': pre_image_len,
            'images': image
        }
        # generate answer
        outputs = self.model.generate(**inputs, **self.gen_kwargs)

        # format output
        outputs = outputs.tolist()
        for i, sample in enumerate(data_sample):
            data_sample[i].pred_answer = self.post_processor(
                outputs[i], self.tokenizer, input_all.shape[1])

        return data_sample

    def forward(self, batch):
        return self.generate(batch)
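The left-padding in `encode_by_tokenizer` can be illustrated in isolation (a toy sketch; `pad_token_id = 0` is an assumption, real ids come from the VisualGLM tokenizer):

```python
pad_token_id = 0  # assumed value for the sketch
input_ids = [[5, 6, 7, 8, 9], [5, 6, 7]]

max_seq_length = max(len(ids) for ids in input_ids)
for i, ids in enumerate(input_ids):
    pad_len = max_seq_length - len(ids)
    input_ids[i] = [pad_token_id] * pad_len + ids  # pad on the left, as in the model code

print(input_ids)  # [[5, 6, 7, 8, 9], [0, 0, 5, 6, 7]]
```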