diff --git a/configs/multimodal/llama_adapter_v2_multimodal/README.md b/configs/multimodal/llama_adapter_v2_multimodal/README.md
new file mode 100644
index 00000000..781cd877
--- /dev/null
+++ b/configs/multimodal/llama_adapter_v2_multimodal/README.md
@@ -0,0 +1,24 @@
+# Llama Adapter V2
+
+### Prepare the environment
+
+```sh
+cd opencompass/multimodal/models/llama_adapter_v2_multimodal
+git clone https://github.com/OpenGVLab/LLaMA-Adapter.git
+```
+
+### Start evaluation
+
+#### Slurm
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
+```
+
+#### PyTorch
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval
+```
\ No newline at end of file
diff --git a/configs/multimodal/llama_adapter_v2_multimodal/llama_adapter_v2_mm_7b_mmbench.py b/configs/multimodal/llama_adapter_v2_multimodal/llama_adapter_v2_mm_7b_mmbench.py
new file mode 100644
index 00000000..44c42c60
--- /dev/null
+++ b/configs/multimodal/llama_adapter_v2_multimodal/llama_adapter_v2_mm_7b_mmbench.py
@@ -0,0 +1,45 @@
+from opencompass.multimodal.models.llama_adapter_v2_multimodal import (
+    LlamaAadapterMMBenchPostProcessor, LlamaAadapterMMBenchPromptConstructor)
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'answer', 'options', 'category', 'l2-category',
+             'index', 'context', 'options_dict'
+         ])
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+               data_file='data/mmbench/mmbench_test_20230712.tsv',
+               pipeline=val_pipeline)
+
+llama_adapter_mmbench_dataloader = dict(batch_size=1,
+                                        num_workers=4,
+                                        dataset=dataset,
+                                        collate_fn=dict(type='pseudo_collate'),
+                                        sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+llama_adapter_model = dict(
+    type='LLaMA-adapter-v2',
+    llama_dir=  # noqa
+    '/llama_adapter_v2_multimodal',
+    prompt_constructor=dict(type=LlamaAadapterMMBenchPromptConstructor),
+    post_processor=dict(type=LlamaAadapterMMBenchPostProcessor),
+)
+
+# evaluation settings
+llama_adapter_evaluator = [
+    dict(
+        type='opencompass.DumpResults',
+        save_path='work_dirs/llama-adapter-v2-multimodal-mmagibench-v0.1.0.xlsx'
+    )
+]
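
The config above only defines the dataloader, model, and evaluator; `configs/multimodal/tasks.py` (whose tail is patched further down in this diff) has to reference them before `run.py` can pick them up. A minimal sketch of what that might look like, assuming `tasks.py` keeps its existing `read_base`-style layout with `models`/`datasets`/`evaluators` lists (the surrounding structure is an assumption, not part of this patch):

```python
from mmengine.config import read_base

with read_base():
    from .llama_adapter_v2_multimodal.llama_adapter_v2_mm_7b_mmbench import (
        llama_adapter_evaluator, llama_adapter_mmbench_dataloader,
        llama_adapter_model)

models = [llama_adapter_model]
datasets = [llama_adapter_mmbench_dataloader]
evaluators = [llama_adapter_evaluator]
# load_froms is omitted here: this config loads its checkpoint inside the
# model class and does not define a separate load_from entry.

num_gpus = 8
num_procs = 8
launcher = 'pytorch'
```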
diff --git a/configs/multimodal/mplug_owl/README.md b/configs/multimodal/mplug_owl/README.md
new file mode 100644
index 00000000..7425f94b
--- /dev/null
+++ b/configs/multimodal/mplug_owl/README.md
@@ -0,0 +1,24 @@
+# MplugOwl
+
+### Prepare the environment
+
+```sh
+cd opencompass/multimodal/models/mplug_owl
+git clone https://github.com/X-PLUG/mPLUG-Owl.git
+```
+
+### Start evaluation
+
+#### Slurm
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
+```
+
+#### PyTorch
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval
+```
\ No newline at end of file
diff --git a/configs/multimodal/mplug_owl/mplug_owl-7b-mmbench.py b/configs/multimodal/mplug_owl/mplug_owl-7b-mmbench.py
new file mode 100644
index 00000000..322c041f
--- /dev/null
+++ b/configs/multimodal/mplug_owl/mplug_owl-7b-mmbench.py
@@ -0,0 +1,48 @@
+from opencompass.multimodal.models.mplug_owl import (
+    MplugOwlMMBenchPostProcessor, MplugOwlMMBenchPromptConstructor)
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(224, 224),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(
+        type='mmpretrain.torchvision/Normalize',
+        mean=(0.48145466, 0.4578275, 0.40821073),
+        std=(0.26862954, 0.26130258, 0.27577711),
+    ),
+    dict(
+        type='mmpretrain.PackInputs',
+        algorithm_keys=[
+            'question', 'answer', 'category', 'l2-category', 'context',
+            'index', 'options_dict', 'options'
+        ],
+    ),
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+               data_file='data/mmbench/mmbench_test_20230712.tsv',
+               pipeline=val_pipeline)
+
+mplug_owl_mmbench_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    dataset=dataset,
+    collate_fn=dict(type='pseudo_collate'),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
+
+# model settings
+mplug_owl_mmbench_model = dict(
+    type='mplug_owl-7b',
+    model_path='/mplug-owl-llama-7b-ft',
+    prompt_constructor=dict(type=MplugOwlMMBenchPromptConstructor),
+    post_processor=dict(type=MplugOwlMMBenchPostProcessor)
+)  # noqa
+
+# evaluation settings
+mplug_owl_mmbench_evaluator = [
+    dict(type='opencompass.DumpResults',
+         save_path='work_dirs/mplug_owl-7b-mmagibench-v0.1.0.xlsx')
+]
diff --git a/configs/multimodal/tasks.py b/configs/multimodal/tasks.py
index ef6bd417..56dee084 100644
--- a/configs/multimodal/tasks.py
+++ b/configs/multimodal/tasks.py
@@ -13,4 +13,4 @@
 load_froms = [minigpt_4_mmbench_load_from]
 num_gpus = 8
 num_procs = 8
-launcher = 'pytorch'
\ No newline at end of file
+launcher = 'pytorch'
diff --git a/opencompass/multimodal/datasets/__init__.py b/opencompass/multimodal/datasets/__init__.py
index dcb96607..39dde918 100644
--- a/opencompass/multimodal/datasets/__init__.py
+++ b/opencompass/multimodal/datasets/__init__.py
@@ -1,5 +1,6 @@
-from .mmbench import MMBenchDataset
-from .mme import MMEDataset
-from .seedbench import SEEDBenchDataset
+from .mmbench import MMBenchDataset  # noqa: F401, F403
+from .mme import MMEDataset  # noqa: F401, F403
+from .seedbench import SEEDBenchDataset  # noqa: F401, F403
 
-__all__ = ['MMBenchDataset', 'SEEDBenchDataset', 'MMEDataset']
+__all__ = ['MMBenchDataset',
+           'SEEDBenchDataset', 'MMEDataset']
diff --git a/opencompass/multimodal/models/__init__.py b/opencompass/multimodal/models/__init__.py
index 2cebe9f9..f62458a4 100644
--- a/opencompass/multimodal/models/__init__.py
+++ b/opencompass/multimodal/models/__init__.py
@@ -8,7 +8,9 @@ if satisfy_requirement('salesforce-lavis'):
 if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
     from .minigpt_4 import *  # noqa: F401, F403
 
+from .llama_adapter_v2_multimodal import *  # noqa: F401, F403
 from .llava import *  # noqa: F401, F403
+from .mplug_owl import *  # noqa: F401, F403
 from .openflamingo import *  # noqa: F401, F403
 from .otter import *  # noqa: F401, F403
 from .visualglm import *  # noqa: F401, F403
diff --git a/opencompass/multimodal/models/llama_adapter_v2_multimodal/__init__.py b/opencompass/multimodal/models/llama_adapter_v2_multimodal/__init__.py
new file mode 100644
index 00000000..34c55ee8
--- /dev/null
+++ b/opencompass/multimodal/models/llama_adapter_v2_multimodal/__init__.py
@@ -0,0 +1,8 @@
+from .llama_adapter import LLaMA_adapter_v2
+from .post_processor import LlamaAadapterMMBenchPostProcessor
+from .prompt_constructor import LlamaAadapterMMBenchPromptConstructor  # noqa
+
+__all__ = [
+    'LLaMA_adapter_v2', 'LlamaAadapterMMBenchPostProcessor',
+    'LlamaAadapterMMBenchPromptConstructor'
+]
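
Note that the configs above pass the constructor and post-processor classes themselves as the `type` key. In the model code that follows, such dicts are materialized with `mmengine.registry.build_from_cfg`, which accepts either a registered name or a class object. A small illustrative snippet, not part of the patch:

```python
import mmengine

from opencompass.multimodal.models.llama_adapter_v2_multimodal import (
    LlamaAadapterMMBenchPromptConstructor)
from opencompass.registry import MM_MODELS

# `type` may be a class object; build_from_cfg then calls it directly with
# the remaining keys as keyword arguments instead of looking up a name.
cfg = dict(type=LlamaAadapterMMBenchPromptConstructor, reply_prompt='')
prompt_constructor = mmengine.registry.build_from_cfg(cfg, MM_MODELS)
```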
diff --git a/opencompass/multimodal/models/llama_adapter_v2_multimodal/llama_adapter.py b/opencompass/multimodal/models/llama_adapter_v2_multimodal/llama_adapter.py
new file mode 100644
index 00000000..b65d50f4
--- /dev/null
+++ b/opencompass/multimodal/models/llama_adapter_v2_multimodal/llama_adapter.py
@@ -0,0 +1,306 @@
+import json
+import os
+from pathlib import Path
+
+import clip
+import mmengine
+import torch
+import torch.nn as nn
+from llama_adapter_v2_multimodal7b.llama.llama import ModelArgs, Transformer
+from llama_adapter_v2_multimodal7b.llama.tokenizer import Tokenizer
+from llama_adapter_v2_multimodal7b.llama.utils import sample_top_p
+from mmengine.device import get_device
+from timm.models.vision_transformer import Block
+
+from opencompass.registry import MM_MODELS
+
+
+class LLaMA_adapter(nn.Module):
+
+    def __init__(self,
+                 llama_ckpt_dir,
+                 llama_tokenizer,
+                 max_seq_len=512,
+                 max_batch_size=1,
+                 clip_model='ViT-L/14',
+                 v_embed_dim=768,
+                 v_depth=8,
+                 v_num_heads=16,
+                 v_mlp_ratio=4.0,
+                 query_len=10,
+                 query_layer=31,
+                 w_bias=False,
+                 w_lora=False,
+                 lora_rank=16,
+                 prompt_constructor=None,
+                 post_processor=None):
+        super().__init__()
+
+        self.device = get_device()
+        # load llama configs
+        with open(os.path.join(llama_ckpt_dir, 'params.json'), 'r') as f:
+            params = json.loads(f.read())
+        model_args = ModelArgs(max_seq_len=max_seq_len,
+                               max_batch_size=max_batch_size,
+                               **params)
+
+        # 1. clip and clip projector
+        self.clip, self.clip_transform = clip.load(clip_model)
+
+        clip_dim = self.clip.visual.proj.shape[1]
+        self.clip_proj = nn.Linear(clip_dim, v_embed_dim)
+        self.clip_proj_norm = nn.LayerNorm(v_embed_dim)
+
+        self.query_len = query_len
+        self.query_layer = query_layer
+
+        # 2. visual query, blocks and projector
+        self.visual_query = nn.Embedding(query_len, v_embed_dim)
+        self.visual_blocks = nn.ModuleList([
+            Block(v_embed_dim, v_num_heads, v_mlp_ratio, qkv_bias=True)
+            for _ in range(v_depth)
+        ])
+        self.visual_proj = nn.Linear(v_embed_dim, model_args.dim)
+        self.visual_proj_norm = nn.LayerNorm(model_args.dim)
+
+        # 3. adapter query
+        self.adapter_query = nn.Embedding(query_len * query_layer,
+                                          model_args.dim)
+
+        # 4. tokenizer
+        self.tokenizer = Tokenizer(model_path=llama_tokenizer)
+
+        # 5. llama
+        model_args.vocab_size = self.tokenizer.n_words
+        model_args.w_bias = w_bias
+        model_args.w_lora = w_lora
+        model_args.lora_rank = lora_rank
+        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        self.llama = Transformer(model_args)
+        torch.set_default_tensor_type(torch.FloatTensor)
+
+        ckpts = sorted(Path(llama_ckpt_dir).glob('*.pth'))
+        for ckpt in ckpts:
+            ckpt = torch.load(ckpt, map_location='cpu')
+            self.llama.load_state_dict(ckpt, strict=False)
+
+        self.prompt_constructor = mmengine.registry.build_from_cfg(
+            prompt_constructor, MM_MODELS)
+        if post_processor is not None:
+            self.post_processor = mmengine.registry.build_from_cfg(
+                post_processor, MM_MODELS)
+
+    def clip_encode_image(self, x):
+        # modified from CLIP
+        x = self.clip.visual.conv1(x)  # shape = [*, width, grid, grid]
+        # shape = [*, width, grid ** 2]
+        x = x.reshape(x.shape[0], x.shape[1], -1)
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([
+            self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(
+                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
+        ],
+                      dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.clip.visual.positional_embedding.to(x.dtype)
+        x = self.clip.visual.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.clip.visual.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        # preserve all spatial tokens
+        x = self.clip.visual.ln_post(x[:, :, :])
+
+        if self.clip.visual.proj is not None:
+            x = x @ self.clip.visual.proj
+
+        return x
+
+    def forward_visual(self, imgs):
+        clip_feats = self.clip_encode_image(imgs)
+        clip_feats = self.clip_proj_norm(self.clip_proj(clip_feats.float()))
+
+        visual_query = self.visual_query.weight.unsqueeze(0).repeat(
+            len(imgs), 1, 1)
+        visual_query = torch.cat([visual_query, clip_feats], dim=1)
+        for block in self.visual_blocks:
+            visual_query = block(visual_query)
+
+        visual_query = visual_query[:, :self.query_len, :]
+        visual_query = self.visual_proj(visual_query)
+        visual_query = self.visual_proj_norm(visual_query)
+
+        return visual_query
+
+    @torch.inference_mode()
+    def forward(self, visual_query, tokens, start_pos: int):
+        _bsz, seqlen = tokens.shape
+        h = self.llama.tok_embeddings(tokens)
+        freqs_cis = self.llama.freqs_cis.to(h.device)
+        freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
+        mask = None
+        mask = torch.full((1, 1, seqlen, seqlen),
+                          float('-inf'),
+                          device=h.device)
+        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+        for layer in self.llama.layers[:-1 * self.query_layer]:
+            h = layer(h, start_pos, freqs_cis, mask)
+
+        adapter = self.adapter_query.weight.reshape(self.query_layer,
+                                                    self.query_len,
+                                                    -1).unsqueeze(1)
+        adapter_index = 0
+        for layer in self.llama.layers[-1 * self.query_layer:]:
+            dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
+            dynamic_adapter = dynamic_adapter + visual_query
+            h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
+            adapter_index = adapter_index + 1
+
+        h = self.llama.norm(h)
+        output = self.llama.output(h[:, -1, :])
+
+        return output.float()
+
+    def pack_inputs(self, batch):
+        images = [image.unsqueeze(0) for image in batch['inputs']]
+        data_samples = [data_sample for data_sample in batch['data_samples']]
+        images = torch.cat(images, dim=0).to(get_device())
+        inputs = {'image': images, 'data_samples': data_samples}
+        return inputs
+
+    @torch.inference_mode()
+    def generate(self, batch):
+        max_gen_len = 256
+        temperature = 0.1
+        top_p = 0.75
+        inputs = self.pack_inputs(batch)
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        prompts = inputs['prompt']
+        data_samples = inputs['data_samples']
+
+        data_sample = data_samples[0]
+
+        prompts = [prompts]
+        imgs = image
+
+        bsz = len(imgs)
+        params = self.llama.params
+
+        with torch.cuda.amp.autocast():
+            visual_query = self.forward_visual(imgs)
+
+        if isinstance(prompts[0], str):
+            prompts = [
+                self.tokenizer.encode(x, bos=True, eos=False) for x in prompts
+            ]
+
+        min_prompt_size = min([len(t) for t in prompts])
+        max_prompt_size = max([len(t) for t in prompts])
+
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+        tokens = torch.full((bsz, total_len),
+                            self.tokenizer.pad_id).cuda().long()
+
+        for k, t in enumerate(prompts):
+            if len(t) <= total_len:
+                tokens[k, :len(t)] = torch.tensor(t).cuda().long()
+            else:
+                tokens[k, :total_len] = torch.tensor(
+                    t[:total_len]).cuda().long()
+
+        input_text_mask = tokens != self.tokenizer.pad_id
+        start_pos = min_prompt_size
+        prev_pos = 0
+        for cur_pos in range(start_pos, total_len):
+            with torch.cuda.amp.autocast():
+                logits = self.forward(visual_query,
+                                      tokens[:, prev_pos:cur_pos], prev_pos)
+            if temperature > 0:
+                probs = torch.softmax(logits / temperature, dim=-1)
+                next_token = sample_top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits, dim=-1)
+            next_token = next_token.reshape(-1)
+
+            next_token = torch.where(input_text_mask[:, cur_pos],
+                                     tokens[:, cur_pos], next_token)
+            tokens[:, cur_pos] = next_token
+            # trick: early stop if bsz==1
+            if bsz == 1 and next_token[0] == self.tokenizer.eos_id:
+                break
+            prev_pos = cur_pos
+
+        decoded = []
+        for i, t in enumerate(tokens.tolist()):
+
+            # cut to max gen len
+            t = t[len(prompts[i]):len(prompts[i]) + max_gen_len]
+            # cut to eos tok if any
+            try:
+                t = t[:t.index(self.tokenizer.eos_id)]
+            except ValueError:
+                pass
+            decoded.append(self.tokenizer.decode(t))
+
+        output_text = self.post_processor(decoded[0])
+        data_sample.pred_answer = output_text
+        return data_sample
+
+
+@MM_MODELS.register_module('LLaMA-adapter-v2')
+class LLaMA_adapter_v2(nn.Module):
+
+    def __init__(self,
+                 llama_dir,
+                 prompt_constructor: dict,
+                 post_processor: dict,
+                 mode: str = 'generation',
+                 device='cuda' if torch.cuda.is_available() else 'cpu',
+                 download_root='ckpts'):
+        super().__init__()
+        name = 'BIAS-7B'
+
+        # BIAS-7B or https://xxx/sha256_BIAS-7B.pth -> 7B
+        llama_type = name.split('.')[0].split('-')[-1]
+        llama_ckpt_dir = os.path.join(llama_dir, llama_type)
+        llama_tokenizer_path = os.path.join(llama_dir, 'tokenizer.model')
+
+        # load llama_adapter weights and model_cfg
+        print(f'Loading LLaMA-Adapter from {llama_dir}')
+        ckpt = torch.load(
+            f'{llama_dir}/7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth',  # noqa: E501
+            map_location='cpu')
+        model_cfg = ckpt.get('config', {})
+
+        self.model = LLaMA_adapter(
+            llama_ckpt_dir,
+            llama_tokenizer_path,
+            max_seq_len=512,
+            max_batch_size=1,
+            clip_model='ViT-L/14',
+            v_embed_dim=768,
+            v_depth=8,
+            v_num_heads=16,
+            v_mlp_ratio=4.0,
+            query_len=10,
+            query_layer=31,
+            w_bias=model_cfg.get('w_bias', False),
+            w_lora=model_cfg.get('w_lora', False),
+            lora_rank=model_cfg.get('lora_rank', 16),
+            prompt_constructor=prompt_constructor,
+            post_processor=post_processor,
+        )
+
+        self.model.load_state_dict(ckpt['model'], strict=False)
+        self.mode = mode
+
+    def forward(self, batch):
+        if self.mode == 'generation':
+            return self.model.generate(batch)
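
For reference, `sample_top_p`, imported above from the LLaMA-Adapter repository, is standard nucleus sampling. A sketch of how such a helper is commonly implemented (this is not part of the patch, and the upstream version may differ in details):

```python
import torch


def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample from the smallest token set whose cumulative probability
    exceeds p (nucleus sampling)."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens outside the nucleus, renormalise, then sample.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```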
diff --git a/opencompass/multimodal/models/llama_adapter_v2_multimodal/post_processor.py b/opencompass/multimodal/models/llama_adapter_v2_multimodal/post_processor.py
new file mode 100644
index 00000000..fd9073ab
--- /dev/null
+++ b/opencompass/multimodal/models/llama_adapter_v2_multimodal/post_processor.py
@@ -0,0 +1,12 @@
+class LlamaAadapterMMBenchPostProcessor:
+    """Post processor for Llama Adapter V2 on MMBench."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: str) -> str:
+        # Strip a leading option index such as 'A.' from the decoded answer.
+        if len(output_token) >= 2:
+            if output_token[1] == '.':
+                output_token = output_token[2:].strip()
+        return output_token
diff --git a/opencompass/multimodal/models/llama_adapter_v2_multimodal/prompt_constructor.py b/opencompass/multimodal/models/llama_adapter_v2_multimodal/prompt_constructor.py
new file mode 100644
index 00000000..2657447c
--- /dev/null
+++ b/opencompass/multimodal/models/llama_adapter_v2_multimodal/prompt_constructor.py
@@ -0,0 +1,51 @@
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class LlamaAadapterMMBenchPromptConstructor:
+    """Prompt constructor for Llama Adapter v2 on MMBench.
+
+    Args:
+        image_prompt (str): Image prompt. Defaults to `''`.
+        reply_prompt (str): Reply prompt. Defaults to `''`.
+    """
+
+    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+        self.image_prompt = image_prompt
+        self.reply_prompt = reply_prompt
+
+    def __call__(self, inputs: dict) -> dict:
+        """Construct prompt.
+
+        Args:
+            inputs (dict): Input data containing image and data_samples.
+
+        Returns:
+            dict: A dict containing prompt, images and data_samples.
+        """
+        data_samples = inputs['data_samples']
+        prompt = self._process(data_samples)
+        inputs.update({'prompt': prompt})
+
+        return inputs
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        # The dataloader uses batch size 1, so read the single sample.
+        question = data_samples[0].get('question')
+        options = data_samples[0].get('options')
+        context = data_samples[0].get('context')
+        if context is not None:
+            prompts = context + ' ' + question + ' ' + options
+        else:
+            prompts = question + ' ' + options
+
+        return prompts
+ """ + # import pdb;pdb.set_trace() + question = [ + data_sample.get('question') for data_sample in data_samples + ] + options = [data_sample.get('options') for data_sample in data_samples] + if data_samples[0].get('context') is not None: + context = [ + data_sample.get('context') for data_sample in data_samples + ] + else: + context = '' + + prompts = context + ' ' + question + ' ' + options # noqa + + return prompts diff --git a/opencompass/multimodal/models/mplug_owl/__init__.py b/opencompass/multimodal/models/mplug_owl/__init__.py new file mode 100644 index 00000000..a29cf270 --- /dev/null +++ b/opencompass/multimodal/models/mplug_owl/__init__.py @@ -0,0 +1,8 @@ +from .mplug_owl import MplugOwl +from .post_processor import MplugOwlMMBenchPostProcessor +from .prompt_constructor import MplugOwlMMBenchPromptConstructor # noqa + +__all__ = [ + 'MplugOwl', 'MplugOwlMMBenchPostProcessor', + 'MplugOwlMMBenchPromptConstructor' +] diff --git a/opencompass/multimodal/models/mplug_owl/mplug_owl.py b/opencompass/multimodal/models/mplug_owl/mplug_owl.py new file mode 100644 index 00000000..49da472a --- /dev/null +++ b/opencompass/multimodal/models/mplug_owl/mplug_owl.py @@ -0,0 +1,86 @@ +import mmengine +import torch +import torch.nn as nn +from mmengine.device import get_device +# Load via Huggingface Style +from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration +from mplug_owl.processing_mplug_owl import (MplugOwlImageProcessor, + MplugOwlProcessor) +from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer + +from opencompass.registry import MM_MODELS + + +@MM_MODELS.register_module('mplug_owl') +class MplugOwl(nn.Module): + + def __init__(self, + prompt_constructor: dict, + post_processor: dict, + model_path='MAGAer13/mplug-owl-llama-7b', + mode: str = 'generation') -> None: + super().__init__() + pretrained_ckpt = model_path + # import pdb;pdb.set_trace() + self.model = MplugOwlForConditionalGeneration.from_pretrained( + pretrained_ckpt, + torch_dtype=torch.bfloat16, + ).cuda() + self.image_processor = MplugOwlImageProcessor.from_pretrained( + pretrained_ckpt) + self.tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt) + self.processor = MplugOwlProcessor(self.image_processor, + self.tokenizer) + self.generate_kwargs = { + 'do_sample': False, + 'top_k': 5, + 'max_length': 20, + 'num_beams': 3, + } + + self.prompt_constructor = mmengine.registry.build_from_cfg( + prompt_constructor, MM_MODELS) + if post_processor is not None: + self.post_processor = mmengine.registry.build_from_cfg( + post_processor, MM_MODELS) + + self.mode = mode + + def forward(self, batch): + if self.mode == 'generation': + return self.generate(batch) + + def generate(self, batch): + images = [image.unsqueeze(0) for image in batch['inputs']] + data_samples = [data_sample for data_sample in batch['data_samples']] + images = torch.cat(images, dim=0).to(get_device()) + inputs = {'image': images, 'data_samples': data_samples} + inputs = self.prompt_constructor(inputs) + image = inputs['image'] + prompt = inputs['prompt'] + data_samples = inputs['data_samples'] + + data_sample = data_samples[0] + owl_template = """The following is a conversation + between a curious human and AI assistant. + The assistant gives helpful, detailed, and + polite answers to the user's questions. 
+        Human: <image>
+        Human: {text_input}
+        AI: """
+        prompt = owl_template.format(text_input=prompt)
+        inputs = self.processor(text=[prompt], return_tensors='pt')
+        inputs['pixel_values'] = image
+        # inputs['pixel_values'] = torch.zeros_like(samples['image'])
+        inputs = {
+            k: v.bfloat16() if v.dtype == torch.float else v
+            for k, v in inputs.items()
+        }
+        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            res = self.model.generate(**inputs, **self.generate_kwargs)
+        output_text = self.tokenizer.decode(res.tolist()[0],
+                                            skip_special_tokens=True)
+        output_text = self.post_processor(output_text)
+        data_sample.pred_answer = output_text
+        return data_sample
diff --git a/opencompass/multimodal/models/mplug_owl/post_processor.py b/opencompass/multimodal/models/mplug_owl/post_processor.py
new file mode 100644
index 00000000..2538349b
--- /dev/null
+++ b/opencompass/multimodal/models/mplug_owl/post_processor.py
@@ -0,0 +1,16 @@
+import re
+
+
+class MplugOwlMMBenchPostProcessor:
+    """Post processor for MplugOwl on MMBench."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: str) -> str:
+        # Extract the predicted option letter from patterns like 'A.'.
+        pattern = re.compile(r'([A-Z]\.)')
+        res = pattern.findall(output_token)
+        if len(res) > 0:
+            output_token = res[0][:-1]
+        return output_token
diff --git a/opencompass/multimodal/models/mplug_owl/prompt_constructor.py b/opencompass/multimodal/models/mplug_owl/prompt_constructor.py
new file mode 100644
index 00000000..b3998710
--- /dev/null
+++ b/opencompass/multimodal/models/mplug_owl/prompt_constructor.py
@@ -0,0 +1,51 @@
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class MplugOwlMMBenchPromptConstructor:
+    """Prompt constructor for MplugOwl on MMBench.
+
+    Args:
+        image_prompt (str): Image prompt. Defaults to `''`.
+        reply_prompt (str): Reply prompt. Defaults to `''`.
+    """
+
+    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+        self.image_prompt = image_prompt
+        self.reply_prompt = reply_prompt
+
+    def __call__(self, inputs: dict) -> dict:
+        """Construct prompt.
+
+        Args:
+            inputs (dict): Input data containing image and data_samples.
+
+        Returns:
+            dict: A dict containing prompt, images and data_samples.
+        """
+        data_samples = inputs['data_samples']
+        prompt = self._process(data_samples)
+        inputs.update({'prompt': prompt})
+
+        return inputs
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        # The dataloader uses batch size 1, so read the single sample.
+        question = data_samples[0].get('question')
+        options = data_samples[0].get('options')
+        context = data_samples[0].get('context')
+        if context is not None:
+            prompts = context + ' ' + question + ' ' + options
+        else:
+            prompts = question + ' ' + options
+
+        return prompts
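
To sanity-check the pieces above end to end, the prompt constructor and post processor can be exercised without loading a model. A rough usage sketch, not part of the patch; `FakeSample` is a hypothetical stand-in for mmpretrain's `DataSample`, exposing only the `.get()` interface used here:

```python
from opencompass.multimodal.models.mplug_owl import (
    MplugOwlMMBenchPostProcessor, MplugOwlMMBenchPromptConstructor)


class FakeSample(dict):
    """Minimal stand-in for DataSample with dict-style .get()."""


sample = FakeSample(question='What is shown in the image?',
                    options='A. a cat B. a dog C. a bird D. a fish',
                    context=None)

constructor = MplugOwlMMBenchPromptConstructor()
inputs = constructor({'image': None, 'data_samples': [sample]})
print(inputs['prompt'])
# -> "What is shown in the image? A. a cat B. a dog C. a bird D. a fish"

post_processor = MplugOwlMMBenchPostProcessor()
print(post_processor('The answer is B. a dog'))  # -> "B"
```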