OpenCompass/opencompass/multimodal/models/otter/otter.py
import mmengine
import torch
import torch.nn as nn

from opencompass.registry import MM_MODELS

from .Otter.models.otter.modeling_otter import OtterForConditionalGeneration


@MM_MODELS.register_module('otter-9b')
class Otter(nn.Module):
"""Inference code of OTTER.
Model details:
OTTER: a multi-modal model based on OpenFlamingo
(open-sourced version of DeepMind's Flamingo)
https://github.com/Luodian/Otter
Args:
model_path (str): The path of OTTER model
in Huggingface model hub format.
load_bit (str): The bit of OTTER model, can be "fp32" or "bf16".
"""
    def __init__(self,
                 model_path,
                 load_bit,
                 prompt_constructor,
                 post_processor,
                 mode='generation') -> None:
        super().__init__()
        torch_dtype = torch.bfloat16 if load_bit == 'bf16' else torch.float32
        self.model = OtterForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch_dtype)
        self.tokenizer = self.model.text_tokenizer
        self.tokenizer.padding_side = 'left'
        self.model_dtype = next(self.model.parameters()).dtype
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        # `forward` dispatches on this attribute.
        self.mode = mode

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            # No `loss` method is defined on this class, so 'loss' mode
            # is effectively unsupported here.
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

    def generate(self, batch):
        inputs = self.prompt_constructor(batch)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
        # Insert the extra media and frame dimensions expected by the
        # Otter vision encoder and cast to the model's parameter dtype.
        vision_x = image.unsqueeze(1).unsqueeze(0).to(dtype=self.model_dtype)
        lang_x = self.model.text_tokenizer([prompt], return_tensors='pt')
        # Forbid the model from opening a new dialogue turn by itself.
        bad_words_id = self.model.text_tokenizer(['User:', 'GPT:']).input_ids
        generated_text = self.model.generate(
            vision_x=vision_x.to(self.model.device),
            lang_x=lang_x['input_ids'].to(self.model.device),
            attention_mask=lang_x['attention_mask'].to(self.model.device),
            do_sample=False,
            max_new_tokens=512,
            num_beams=3,
            bad_words_ids=bad_words_id,
            no_repeat_ngram_size=3,
        )
        # Decode each generated sequence and attach the prediction to the
        # corresponding data sample.
        for i, data_sample in enumerate(data_samples):
            output_text = self.post_processor(generated_text[i],
                                              self.model.text_tokenizer)
            data_sample.pred_answer = output_text
            data_samples[i] = data_sample
        return data_samples
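
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): builds the wrapper through the MM_MODELS
# registry and runs generation on a dataloader batch. The checkpoint id and
# the `type` names of the prompt constructor and post processor below are
# assumptions for illustration, not pinned by this file.
#
# from opencompass.registry import MM_MODELS
#
# otter = MM_MODELS.build(
#     dict(
#         type='otter-9b',
#         model_path='luodian/OTTER-Image-MPT7B',  # assumed checkpoint id
#         load_bit='bf16',
#         prompt_constructor=dict(type='OTTERMMBenchPromptConstructor'),
#         post_processor=dict(type='OTTERMMBenchPostProcessor'),
#     ))
# data_samples = otter.forward(batch)  # `batch` as produced by the dataloader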