import importlib

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device

from opencompass.registry import MM_MODELS

@MM_MODELS.register_module('otter-9b')
class Otter(nn.Module):
    """Inference code of OTTER.

    Model details:
        OTTER: a multi-modal model based on OpenFlamingo
        (an open-sourced version of DeepMind's Flamingo).
        https://github.com/Luodian/Otter

    Args:
        model_path (str): Path of the OTTER model,
            in Huggingface model hub format.
        load_bit (str): Precision to load the model in,
            either "fp32" or "bf16".
        prompt_constructor (dict): Config for the prompt constructor,
            built through the MM_MODELS registry.
        post_processor (dict): Config for the output post-processor,
            built through the MM_MODELS registry. May be None.
        mode (str): Inference mode, 'generation' or 'loss'.
            Defaults to 'generation'.
    """

    def __init__(self,
                 model_path,
                 load_bit,
                 prompt_constructor,
                 post_processor,
                 mode='generation') -> None:
        super().__init__()
        # Load in bf16 when requested, otherwise fall back to fp32.
        torch_dtype = torch.bfloat16 if load_bit == 'bf16' else torch.float32
        # Import lazily so the Otter package is only required when this
        # model is actually instantiated.
        otter_ai = importlib.import_module('otter_ai')
        self.model = otter_ai.OtterForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch_dtype, device_map=get_device())
        self.tokenizer = self.model.text_tokenizer
        # Left padding so batched generation continues from the prompt end.
        self.tokenizer.padding_side = 'left'
        self.model_dtype = next(self.model.parameters()).dtype
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.mode = mode
    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            # NOTE: no `loss` method is defined in this class, so 'loss'
            # mode currently raises AttributeError when dispatched.
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

    def generate(self, batch):
        # Build the multimodal prompt (image tensor, formatted text and the
        # data samples to fill) from the raw batch.
        inputs = self.prompt_constructor(batch)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']

        # Add the extra media/frame axes that Flamingo-style models expect
        # for their vision input.
        vision_x = image.unsqueeze(1).unsqueeze(0).to(dtype=self.model_dtype)
        lang_x = self.model.text_tokenizer([prompt], return_tensors='pt')
        # Ban the dialogue role tags so the model does not hallucinate
        # further conversation turns.
        bad_words_id = self.model.text_tokenizer(['User:', 'GPT:']).input_ids
        generated_text = self.model.generate(
            vision_x=vision_x.to(self.model.device),
            lang_x=lang_x['input_ids'].to(self.model.device),
            attention_mask=lang_x['attention_mask'].to(self.model.device),
            do_sample=False,
            max_new_tokens=512,
            num_beams=3,
            bad_words_ids=bad_words_id,
            no_repeat_ngram_size=3,
        )
        for i, data_sample in enumerate(data_samples):
            output_text = self.post_processor(generated_text[i],
                                              self.model.text_tokenizer)
            data_sample.pred_answer = output_text
            data_samples[i] = data_sample

        return data_samples
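

if __name__ == '__main__':
    # A minimal usage sketch, assuming a local OTTER checkpoint and some
    # registered constructor/processor types; the path and the two `type`
    # names below are hypothetical placeholders, not values from this repo.
    cfg = dict(
        type='otter-9b',
        model_path='/path/to/OTTER-9B',  # hypothetical checkpoint path
        load_bit='bf16',
        prompt_constructor=dict(type='YourPromptConstructor'),  # placeholder
        post_processor=dict(type='YourPostProcessor'),  # placeholder
    )
    model = MM_MODELS.build(cfg)  # registry resolves 'otter-9b' to Otter
    print(type(model).__name__, model.mode)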