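"""OpenCompass wrapper for the Ola multimodal model.

Only the text modality is exercised: zero-filled placeholder image and
speech tensors are passed to the underlying model during generation.
"""
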
import os

# These environment variables configure Ola's image/video preprocessing
# (resizing limits, padding, early vision-tower loading). They are set
# before the Ola imports below so the model code can read them.
os.environ['LOWRES_RESIZE'] = '384x32'
os.environ['HIGHRES_BASE'] = '0x32'
os.environ['VIDEO_RESIZE'] = '0x64'
os.environ['VIDEO_MAXRES'] = '480'
os.environ['VIDEO_MINRES'] = '288'
os.environ['MAXRES'] = '1536'
os.environ['MINRES'] = '0'
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
os.environ['LOAD_VISION_EARLY'] = '1'
os.environ['PAD2STRIDE'] = '1'

import argparse
import copy
import re
import sys
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
import transformers
from PIL import Image

from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.models.ola.constants import (DEFAULT_IMAGE_TOKEN,
                                              DEFAULT_SPEECH_TOKEN,
                                              IGNORE_INDEX,
                                              IMAGE_TOKEN_INDEX,
                                              SPEECH_TOKEN_INDEX)
from opencompass.models.ola.conversation import SeparatorStyle, conv_templates
from opencompass.models.ola.datasets.preprocess import (
    tokenizer_image_token, tokenizer_speech_image_token,
    tokenizer_speech_question_image_token, tokenizer_speech_token)
from opencompass.models.ola.mm_utils import (KeywordsStoppingCriteria,
                                             process_anyres_highres_image,
                                             process_anyres_video)
from opencompass.models.ola.model.builder import load_pretrained_model
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


class OlaModel(BaseModel):

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_path: Optional[str] = None,
                 model_config: Optional[str] = None,
                 meta_template: Optional[Dict] = None):
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

        # Load the pretrained Ola checkpoint and move it to GPU in bf16.
        tokenizer, model, _, _ = load_pretrained_model(path, None)
        model = model.to('cuda').eval()
        model = model.bfloat16()
        self.tokenizer = tokenizer
        self.model = model
        # Default generation settings; temperature > 0 enables sampling.
        self.gen_kwargs = {
            'max_new_tokens': 1024,
            'temperature': 0.2,
            'top_p': None,
            'num_beams': 1,
        }
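
    # Note: this wrapper evaluates the text modality only. ``generate``
    # passes zero-filled placeholder tensors for the image and speech
    # inputs, since the Ola ``generate`` call below takes those arguments
    # regardless of modality.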
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        assert len(inputs) == 1  # only batch size 1 is supported
        image_path = None
        audio_path = None
        video_path = None
        text = inputs[0]

        # Zero-filled placeholder image tensors for the unused vision path.
        images = [
            torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16,
                                           device='cuda',
                                           non_blocking=True)
        ]
        images_highres = [
            torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16,
                                           device='cuda',
                                           non_blocking=True)
        ]
        image_sizes = [(224, 224)]

        # Zero-filled placeholder speech tensors for the unused audio path.
        USE_SPEECH = False
        speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
        speech_lengths = [torch.LongTensor([3000]).to('cuda')]
        speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
        speech_chunks = [torch.LongTensor([1]).to('cuda')]

        # Build the prompt with the qwen_1_5 conversation template.
        conv_mode = 'qwen_1_5'
        qs = text if text else ''
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
            return_tensors='pt').unsqueeze(0).to('cuda')

        # Pad token id of the Qwen tokenizer ('<|endoftext|>').
        pad_token_ids = 151643
        attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')

        # Stop generation at the conversation template's separator string.
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
                                                     input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images,
                images_highres=images_highres,
                image_sizes=image_sizes,
                modalities=['text'],
                speech=speechs,
                speech_lengths=speech_lengths,
                speech_chunks=speech_chunks,
                speech_wav=speech_wavs,
                attention_mask=attention_masks,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                do_sample=self.gen_kwargs['temperature'] > 0,
                temperature=self.gen_kwargs['temperature'],
                top_p=self.gen_kwargs['top_p'],
                num_beams=self.gen_kwargs['num_beams'],
                # NOTE: the requested max_out_len is not forwarded; the
                # default from __init__ is used instead.
                max_new_tokens=self.gen_kwargs['max_new_tokens'],
            )

        outputs = self.tokenizer.batch_decode(output_ids,
                                              skip_special_tokens=True)
        out = []
        for output in outputs:
            output = output.strip()
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
            out.append(output)
        print('prompt---->', prompt)
        print('out---->', out)
        print()
        return out
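

# A minimal usage sketch (hypothetical checkpoint path; in practice
# OpenCompass instantiates the model from an evaluation config):
#
#     model = OlaModel(path='/path/to/ola-checkpoint')
#     predictions = model.generate(['What is the capital of France?'],
#                                  max_out_len=1024)
#     print(predictions[0])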