from typing import Optional

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModel, AutoTokenizer

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('visualglm')
class VisualGLM(nn.Module):
"""Inference code of VisualGLM.
|
|
|
|
|
|
|
|
We load the visualGLM model via Huggingface.
|
|
|
|
Args:
|
|
|
|
pretrained_path (str): Path to visualGLM checkpoint or repo id.
|
|
|
|
prompt_constructor (dict): The config of prompt constructor.
|
|
|
|
post_processor (dict): The config of post processor.
|
2023-08-25 15:44:32 +08:00
|
|
|
is_caption_task (bool): Whether the task is caption task.
|
|
|
|
Defaults to False.
|
2023-08-21 15:57:30 +08:00
|
|
|
gen_kwargs (dict): Customize generate function arguments.
|
2023-08-25 15:44:32 +08:00
|
|
|
Defaults to None.
|
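
    Examples:
        A minimal usage sketch. The constructor and post-processor type
        names below are assumptions; use whatever is actually registered
        in ``MM_MODELS`` for your task::

            model = VisualGLM(
                pretrained_path='THUDM/visualglm-6b',
                prompt_constructor=dict(
                    type='VisualGLMBasePromptConstructor'),
                post_processor=dict(type='VisualGLMBasePostProcessor'))
            data_sample = model.generate(batch)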
    """

    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: dict,
                 post_processor: dict,
                 is_caption_task: bool = False,
                 gen_kwargs: Optional[dict] = None) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
                                                       trust_remote_code=True)
        # VisualGLM weights are loaded in half precision for inference.
        self.model = AutoModel.from_pretrained(pretrained_path,
                                               trust_remote_code=True).half()
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)

        if gen_kwargs:
            self.gen_kwargs = gen_kwargs
        else:
            self.gen_kwargs = dict(max_length=1024,
                                   min_length=100,
                                   do_sample=True,
                                   temperature=0.8,
                                   top_p=0.4,
                                   top_k=100,
                                   repetition_penalty=1.2)

        self.is_caption_task = is_caption_task

    def encode_by_tokenizer(self, prompt, image_position):
        # Tokenize the text before and after the image separately, and
        # reserve `image_length` unk-token slots in between as placeholders
        # that the model fills with image features.
        input0 = self.tokenizer.encode(prompt[:image_position],
                                       add_special_tokens=False)
        input1 = [self.tokenizer.unk_token_id] * self.model.image_length
        input2 = self.tokenizer.encode(prompt[image_position:],
                                       add_special_tokens=False)
        input_all = sum([input0, input1, input2], [])
        input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
        input_all = input_all.unsqueeze(0)

        pre_image_len = len(input0)

        return input_all, pre_image_len

    def generate(self, batch):
        # Process input: split the batch into the image tensor, the text
        # prompt, the data sample and the image's position in the prompt.
        image, prompt, data_sample, image_position = self.prompt_constructor(
            batch)
        image = image.to(self.model.dtype).to(get_device())

        # Tokenize the prompt around the image placeholder.
        input_all, pre_image_len = self.encode_by_tokenizer(
            prompt, image_position)

        # Build the input parameters for the model's generate function.
        inputs = {
            'input_ids': input_all,
            'pre_image_length': pre_image_len,
            'images': image
        }

        # Generate the answer.
        outputs = self.model.generate(**inputs, **self.gen_kwargs)

        # Format output: drop the prompt tokens and decode the rest.
        outputs = outputs.tolist()[0][input_all.shape[1]:]
        answer = self.post_processor(outputs, self.tokenizer)

        if self.is_caption_task:
            data_sample.pred_caption = answer
        else:
            data_sample.pred_answer = answer

        return data_sample

    def forward(self, batch):
        return self.generate(batch)
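
# A minimal config sketch for reference, kept as a comment since this module
# is normally driven by the OpenCompass config system. The prompt constructor
# and post processor type names are assumptions; point them at the classes
# actually registered in MM_MODELS for your task:
#
# visualglm_model = dict(
#     type='visualglm',
#     pretrained_path='THUDM/visualglm-6b',  # or a local checkpoint path
#     prompt_constructor=dict(type='VisualGLMBasePromptConstructor'),
#     post_processor=dict(type='VisualGLMBasePostProcessor'),
#     gen_kwargs=dict(max_length=256, do_sample=False))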