mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* refine gitignore * [Feature]: Add minigpt-4 * [Feature]: Add mm local runner * [Feature]: Add instructblip * add otter and llama-adapter * add owl * add llama2-adapter and owl * lint * [Feature]: Add minigpt-4 * [Feature]: Add instructblip * add otter and llama-adapter * add owl * add llama2-adapter and owl * lint * lint * update * lint * lint * add __init__.py * update * update * update --------- Co-authored-by: liuyuan <3463423099@qq.com>
87 lines
3.3 KiB
Python
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
# Load via Huggingface Style
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.processing_mplug_owl import (MplugOwlImageProcessor,
                                            MplugOwlProcessor)
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer

from opencompass.registry import MM_MODELS

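# Thin OpenCompass wrapper around the official mPLUG-Owl implementation;
# registering it in MM_MODELS lets it be built from an evaluation config.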
@MM_MODELS.register_module('mplug_owl')
class MplugOwl(nn.Module):

    def __init__(self,
                 prompt_constructor: dict,
                 post_processor: dict,
                 model_path='MAGAer13/mplug-owl-llama-7b',
                 mode: str = 'generation') -> None:
        super().__init__()
        pretrained_ckpt = model_path
        self.model = MplugOwlForConditionalGeneration.from_pretrained(
            pretrained_ckpt,
            torch_dtype=torch.bfloat16,
        ).cuda()
        self.image_processor = MplugOwlImageProcessor.from_pretrained(
            pretrained_ckpt)
        self.tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
        self.processor = MplugOwlProcessor(self.image_processor,
                                           self.tokenizer)
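        # Default decoding settings forwarded to HuggingFace ``generate``:
        # deterministic beam search (3 beams) with a short ``max_length``.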
        self.generate_kwargs = {
            'do_sample': False,
            'top_k': 5,
            'max_length': 20,
            'num_beams': 3,
        }

        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)

        self.mode = mode

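    # OpenCompass drives the model through ``forward`` with a collated
    # dataloader batch; only the 'generation' mode is implemented here.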
    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)

    def generate(self, batch):
        images = [image.unsqueeze(0) for image in batch['inputs']]
        data_samples = [data_sample for data_sample in batch['data_samples']]
        images = torch.cat(images, dim=0).to(get_device())
        inputs = {'image': images, 'data_samples': data_samples}
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']

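        # Keep the first data sample so the prediction can be attached to it,
        # then wrap the constructed prompt in mPLUG-Owl's conversation
        # template; the '<image>' token marks where the image features go.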
        data_sample = data_samples[0]
        owl_template = """The following is a conversation
        between a curious human and AI assistant.
        The assistant gives helpful, detailed, and
        polite answers to the user's questions.
        Human: <image>
        Human: {text_input}
        AI: """
        prompt = owl_template.format(text_input=prompt)
        inputs = self.processor(text=[prompt], return_tensors='pt')
        inputs['pixel_values'] = image
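        # Match the model's bfloat16 weights: cast floating-point tensors
        # (the pixel values) and move everything onto the model's device.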
        inputs = {
            k: v.bfloat16() if v.dtype == torch.float else v
            for k, v in inputs.items()
        }
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            res = self.model.generate(**inputs, **self.generate_kwargs)
        output_text = self.tokenizer.decode(res.tolist()[0],
                                            skip_special_tokens=True)
        output_text = self.post_processor(output_text)
        data_sample.pred_answer = output_text
        return data_sample
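
# A minimal usage sketch, assuming prompt-constructor / post-processor
# classes are registered in MM_MODELS under the (hypothetical) names below;
# the real type names come from the evaluation config:
#
#   model = MplugOwl(
#       prompt_constructor=dict(type='SomeMplugOwlPromptConstructor'),
#       post_processor=dict(type='SomeMplugOwlPostProcessor'),
#       model_path='MAGAer13/mplug-owl-llama-7b')
#   # batch follows the multimodal dataloader format: a list of image
#   # tensors plus the corresponding data samples.
#   pred_sample = model.forward({'inputs': [img], 'data_samples': [sample]})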