diff --git a/examples/eval_ola.py b/examples/eval_ola.py
new file mode 100644
index 00000000..f3d9b169
--- /dev/null
+++ b/examples/eval_ola.py
@@ -0,0 +1,10 @@
+from mmengine.config import read_base
+
+with read_base():
+    from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import gsm8k_datasets
+    from opencompass.configs.datasets.demo.demo_math_chat_gen import math_datasets
+    from opencompass.configs.models.ola.ola import models as ola_models
+
+
+datasets = math_datasets
+models = ola_models
\ No newline at end of file
diff --git a/opencompass/configs/models/ola/ola.py b/opencompass/configs/models/ola/ola.py
new file mode 100644
index 00000000..0c0f595f
--- /dev/null
+++ b/opencompass/configs/models/ola/ola.py
@@ -0,0 +1,12 @@
+from opencompass.models import OlaModel
+models = [
+    dict(
+        type=OlaModel,
+        path="THUdyh/Ola-7b",
+        max_seq_len=2048,
+        abbr='ola',
+        max_out_len=1024,
+        batch_size=1,
+        run_cfg=dict(num_gpus=1),
+    )
+]
diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 580402d4..bbe3b2b0 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -35,6 +35,7 @@ from .openai_api import OpenAI  # noqa: F401
 from .openai_api import OpenAISDK  # noqa: F401
 from .pangu_api import PanGu  # noqa: F401
 from .qwen_api import Qwen  # noqa: F401
+from .ola_model import OlaModel  # noqa: F401
 from .rendu_api import Rendu  # noqa: F401
 from .sensetime_api import SenseTime  # noqa: F401
 from .stepfun_api import StepFun  # noqa: F401
diff --git a/opencompass/models/ola/arguments.py b/opencompass/models/ola/arguments.py
new file mode 100644
index 00000000..199c5d5b
--- /dev/null
+++ b/opencompass/models/ola/arguments.py
@@ -0,0 +1,65 @@
+import transformers
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    version: Optional[str] = field(default="v0")
+    freeze_backbone: bool = field(default=False)
+    tune_speech_projector: bool = field(default=False)
+    tune_speech_encoder: bool = field(default=False)
+    tune_speech_generator_only: bool = field(default=False)
+    speech_encoder_type: Optional[str] = field(default=None)
+    speech_encoder: Optional[str] = field(default=None)
+    pretrain_speech_projector: Optional[str] = field(default=None)
+    speech_projector_type: Optional[str] = field(default='linear')
+    speech_encoder_ds_rate: int = 5
+    speech_encoder_hidden_size: int = 1280
+
+
+@dataclass
+class DataArguments:
+    data_path: str = field(default=None,
+                           metadata={"help": "Path to the training data."})
+    is_multimodal: bool = False
+    input_type: str = field(default="mel")
+    speech_normalize: bool = False
+    mel_size: int = 128
+    has_tgt_units: bool = False
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    freeze_speech_projector: bool = field(default=False)
+    model_max_length: int = field(
+        default=512,
+        metadata={
+            "help":
+            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
+    )
+    double_quant: bool = field(
+        default=True,
+        metadata={"help": "Compress the quantization statistics through double quantization."}
+    )
+    quant_type: str = field(
+        default="nf4",
+        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+    )
+    bits: int = field(
+        default=16,
+        metadata={"help": "How many bits to use."}
+    )
+    lora_enable: bool = False
+    lora_r: int = 64
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+    lora_weight_path: str = ""
+    lora_bias: str = "none"
+    speech_projector_lr: Optional[float] = None
+    group_by_modality_length: bool = field(default=False)
\ No newline at end of file
diff --git a/opencompass/models/ola/constants.py b/opencompass/models/ola/constants.py
new file mode 100644
index 00000000..9b903d94
--- /dev/null
+++ b/opencompass/models/ola/constants.py
@@ -0,0 +1,14 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+SPEECH_TOKEN_INDEX = -200
+DEFAULT_SPEECH_TOKEN = "<speech>"
+IMAGE_TOKEN_INDEX = -300
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
\ No newline at end of file
diff --git a/opencompass/models/ola/conversation.py b/opencompass/models/ola/conversation.py
new file mode 100644
index 00000000..79001154
--- /dev/null
+++ b/opencompass/models/ola/conversation.py
@@ -0,0 +1,254 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any, Union, Tuple
+import base64
+from io import BytesIO
+from PIL import Image
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    TWO = auto()
+    PLAIN = auto()
+    CHATML = auto()
+    LLAMA_2 = auto()
+    LLAMA_3 = auto()
+    QWEN2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    tokenizer_id: str = ""
+    tokenizer: Any = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Union[str, List[str]] = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] = None
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.LLAMA_3:
+            wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
+            ret = "<|begin_of_text|>" + wrap_sys(self.system)
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += message.strip() + self.sep2
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i == 0:
+                        message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+
ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + raise ValueError("Tuple not supported in CHATML") + message, images = message + message = "" * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.QWEN2: + start = '<|im_start|>' + end = '<|im_end|>\n' + ret = start + 'system\n' + self.system + end + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + + if message.endswith('<|endoftext|>'): + message = message.replace('<|endoftext|>', '') + ret += start + role + "\n" + message + end + '<|endoftext|>' + else: + assert not '<|endoftext|>' in message, f"Invalid message: {message}" + ret += start + role + "\n" + message + end + else: + ret += start + role + "\n" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, speech = msg + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llama_3 = Conversation( + system="You are a helpful language and speech assistant. 
" "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", + roles=("user", "assistant"), + version="llama_v3", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_3, + sep="", + sep2="<|eot_id|>" +) + + +conv_qwen_v1 = Conversation( + system="You are a helpful assistant.", + roles=("user", "assistant"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.QWEN2, +) + +conv_plain = Conversation( + system="", + roles=("", ""), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + +conv_qwen = Conversation( + system="""<|im_start|>system +You are a helpful assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + version="qwen", + messages=[], + offset=0, + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", +) + +default_conversation = conv_llama_3 +conv_templates = { + "v1": conv_vicuna_v1, + "plain": conv_plain, + "llama_2": conv_llama_2, + "llama_3": conv_llama_3, + 'v1_qwen2': conv_qwen_v1, + "qwen_1_5": conv_qwen, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/opencompass/models/ola/datasets/__init__.py b/opencompass/models/ola/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/opencompass/models/ola/datasets/preprocess.py b/opencompass/models/ola/datasets/preprocess.py new file mode 100644 index 00000000..d8621e94 --- /dev/null +++ b/opencompass/models/ola/datasets/preprocess.py @@ -0,0 +1,231 @@ +import copy +import torch +import transformers +import tokenizers + +from typing import Dict, Sequence + +from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX +from opencompass.models.ola import conversation as conversation_lib +from opencompass.models.ola.model import * +from opencompass.models.ola.arguments import DataArguments +from opencompass.models.ola.constants import SPEECH_TOKEN_INDEX + +from packaging import version + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is 
not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("\nUser's question in speech: \n")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + nl_tokens = tokenizer("\n").input_ids + special_chunks = [image_token_index, nl_tokens, tokenizer("User's question in speech: ").input_ids, speech_token_idx, nl_tokens] + + for x in insert_separator(prompt_chunks, [special_chunks] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if 
len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # FIXME: tokenizer bug + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_SPEECH_TOKEN in source[0]['value'] + source[0]['value'] = DEFAULT_SPEECH_TOKEN + conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
+ """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_speech=has_speech) + raise NotImplementedError \ No newline at end of file diff --git a/opencompass/models/ola/mm_utils.py b/opencompass/models/ola/mm_utils.py new file mode 100644 index 00000000..3bf2189b --- /dev/null +++ b/opencompass/models/ola/mm_utils.py @@ -0,0 +1,271 @@ +from PIL import Image +import base64 +import math +import ast + +import torch +from transformers import StoppingCriteria +import os +import io + +if 'VIDEO_RESIZE' in os.environ: + # highresxpatch + VIDEO_RESIZE = os.environ['VIDEO_RESIZE'] + video_base, video_ps = VIDEO_RESIZE.split('x') + video_base = int(video_base) + video_ps = int(video_ps) + print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}") +else: + HIGHRES_BASE = None + +if 'HIGHRES_BASE' in os.environ: + # highresxpatch + HIGHRES_BASE = os.environ['HIGHRES_BASE'] + highres_base, highres_ps = HIGHRES_BASE.split('x') + highres_base = int(highres_base) + highres_ps = int(highres_ps) + print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}") +else: + HIGHRES_BASE = None + +if 'MAXRES' in os.environ: + # highresxpatch + MAXRES = int(os.environ['MAXRES']) + print(f"MAXRES is set as {MAXRES}") +else: + MAXRES = 1536 + +if 'MINRES' in os.environ: + # highresxpatch + MINRES = int(os.environ['MINRES']) + print(f"MINRES is set as {MINRES}") +else: + MINRES = 0 + +if 'VIDEO_MAXRES' in os.environ: + # highresxpatch + VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES']) + print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}") +else: + VIDEO_MAXRES = 1536 + +if 'VIDEO_MINRES' in os.environ: + # highresxpatch + VIDEO_MINRES = int(os.environ['VIDEO_MINRES']) + print(f"VIDEO_MINRES is set as {VIDEO_MINRES}") +else: + MINRES = 0 + +if 'PAD2STRIDE' in os.environ: + # highresxpatch + PAD2STRIDE = True + print(f"PAD2STRIDE is set") +else: + PAD2STRIDE = False + +if 'LOWRES_RESIZE' in os.environ: + LOWRES_RESIZE = os.environ['LOWRES_RESIZE'] + print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}") + if 'x' in LOWRES_RESIZE: + size, ps = LOWRES_RESIZE.split('x') + size = int(size) + ps = int(ps) + LOWRES_RESIZE = (size, ps) + else: + LOWRES_RESIZE = int(LOWRES_RESIZE) +else: + LOWRES_RESIZE = None + + +def pad_image(image, target_resolution, value=0): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. 
+ """ + original_width, original_height = image.size + target_width, target_height = target_resolution + # Create a new image with the target size and paste the resized image onto it + new_image = Image.new('RGB', (target_width, target_height), (value, value, value)) + paste_x = (target_width - original_width) // 2 + paste_y = (target_height - original_height) // 2 + new_image.paste(image, (paste_x, paste_y)) + return new_image + +def resize_images(image, patch_size=14, base_size=896): + h, w = image.size + if base_size == 0: + if h * w > MAXRES * MAXRES: + # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}') + scale = MAXRES * MAXRES / (h * w) + scale = math.sqrt(scale) + elif h * w < MINRES * MINRES: + # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}') + scale = MINRES * MINRES / (h * w) + scale = math.sqrt(scale) + else: + scale = None + else: + scale = base_size * base_size / (h * w) + scale = math.sqrt(scale) + + + if scale is not None: + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + new_h = max(new_h, patch_size) + new_w = max(new_w, patch_size) + image = image.resize((new_h, new_w)) + elif PAD2STRIDE: + if h % patch_size == 0: + new_h = h + else: + new_h = (h // patch_size + 1) * patch_size + + if w % patch_size == 0: + new_w = w + else: + new_w = (w // patch_size + 1) * patch_size + image = pad_image(image, (new_h, new_w), value=127) + else: + scale = 1.0 + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + new_h = max(new_h, patch_size) + new_w = max(new_w, patch_size) + image = image.resize((new_h, new_w)) + + return image + +def resize_video(image, patch_size=14, base_size=896): + h, w = image.size + if base_size == 0: + if h * w > VIDEO_MAXRES * VIDEO_MAXRES: + # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}') + scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w) + scale = math.sqrt(scale) + elif h * w < VIDEO_MINRES * VIDEO_MINRES: + # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}') + scale = VIDEO_MINRES * VIDEO_MINRES / (h * w) + scale = math.sqrt(scale) + else: + scale = None + else: + scale = base_size * base_size / (h * w) + scale = math.sqrt(scale) + + if scale is not None: + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + image = image.resize((new_h, new_w)) + elif PAD2STRIDE: + if h % patch_size == 0: + new_h = h + else: + new_h = (h // patch_size + 1) * patch_size + + if w % patch_size == 0: + new_w = w + else: + new_w = (w // patch_size + 1) * patch_size + image = pad_image(image, (new_h, new_w), value=127) + else: + scale = 1.0 + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + image = image.resize((new_h, new_w)) + + return image + +def process_anyres_video(image, processor): + if VIDEO_RESIZE is not None: + image = resize_video(image, patch_size=video_ps, base_size=video_base) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return image.unsqueeze(0) + else: + raise ValueError("VIDEO_RESIZE is not set") + +def process_anyres_highres_image(image, processor): + processor2 = None + if type(processor) is tuple: + processor, processor2 = processor[0], processor[1] + + if HIGHRES_BASE is not None: + image = resize_images(image, patch_size=highres_ps, base_size=highres_base) + + if processor2 is not None: + image_original_resize = 
image.resize((processor2.size['shortest_edge'], processor.size['shortest_edge'])) + image_patches = [image_original_resize] + [image_original_resize] + image_patches = [processor2.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches] + else: + if LOWRES_RESIZE is not None: + if type(LOWRES_RESIZE) is int: + image_original_resize = resize_images(image, patch_size=14, base_size=LOWRES_RESIZE) + else: + image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0]) + else: + image_original_resize = image.resize((336, 336)) + image_patches = [image_original_resize] + image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches] + image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0) + +def read_image_patch(patch_info): + if 'img_path' in patch_info.keys(): + image = Image.open(patch_info['img_path']).convert('RGB') + else: + if 'image_encoing' in patch_info.keys(): + patch_info['image_encoding'] = patch_info['image_encoing'] + image_file_name = patch_info['patch'] + start_bytes = int(patch_info['start_num']) + file_size = int(patch_info['size']) + + with open(image_file_name, 'rb') as f: + f.seek(start_bytes) + if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64': + image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB") + else: + image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB") + return image + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO + offset = min(output_ids.shape[1] - self.start_len, 3) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if output_ids[0, -keyword_id.shape[0]:] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False diff --git a/opencompass/models/ola/model/__init__.py b/opencompass/models/ola/model/__init__.py new file mode 100644 index 00000000..7599857d --- /dev/null +++ b/opencompass/models/ola/model/__init__.py @@ -0,0 +1 @@ +from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen \ No newline at end of file diff --git a/opencompass/models/ola/model/builder.py b/opencompass/models/ola/model/builder.py new file mode 100644 index 00000000..0849b6b8 --- /dev/null +++ b/opencompass/models/ola/model/builder.py @@ -0,0 +1,91 @@ +import os +import warnings 
+import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +import torch +from opencompass.models.ola.model import * +from opencompass.models.ola.model.speech_encoder.builder import build_speech_encoder + +def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs): + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.bfloat16 + + if use_flash_attn: + kwargs['attn_implementation'] = 'flash_attention_2' + + model_cls = OlaQwenForCausalLM + + # Load Ola model + if is_lora: + assert model_base is not None, "model_base is required for LoRA models." + from ola.model.language_model.ola_qwen import OlaConfigQwen + lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading Ola from base model...') + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs) + print('Loading additional Ola weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + print('Loading Ola from base model...') + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs) + + speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu') + speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()} + model.load_state_dict(speech_projector_weights, strict=False) + model = model.to(device=device) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = model_cls.from_pretrained( + model_path, + low_cpu_mem_usage=False, + **kwargs + ) + model = model.to(device=device) + + model.get_model().speech_encoder = build_speech_encoder(model.config) + model.get_model().speech_encoder.to(device=device, dtype=torch.float16) + + image_processor = None + model.resize_token_embeddings(len(tokenizer)) + vision_tower = model.get_vision_tower() + print("Loading vision tower...") + if not vision_tower.is_loaded: + vision_tower.load_model(device_map=device) + if device != "auto": + vision_tower.to(device="cuda", dtype=torch.bfloat16) + else: + vision_tower.to(device="cuda:0", dtype=torch.bfloat16) + image_processor = vision_tower.image_processor + print("Loading vision tower 
succeeded.") + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 16384 + + return tokenizer, model, image_processor, context_len diff --git a/opencompass/models/ola/model/language_model/ola_qwen.py b/opencompass/models/ola/model/language_model/ola_qwen.py new file mode 100644 index 00000000..fd88538c --- /dev/null +++ b/opencompass/models/ola/model/language_model/ola_qwen.py @@ -0,0 +1,237 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +import transformers +from transformers import AutoConfig, AutoModelForCausalLM + + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM +from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM + + +class OlaConfigQwen(Qwen2Config): + model_type = "ola_qwen" + + +class OlaQwenModel(OlaMetaModel, Qwen2Model): + config_class = OlaConfigQwen + + def __init__(self, config: Qwen2Config): + super(OlaQwenModel, self).__init__(config) + + +class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM): + config_class = OlaConfigQwen + + def __init__(self, config): + super(Qwen2ForCausalLM, self).__init__(config) + + config.rope_scaling = None + self.model = OlaQwenModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + speech_chunks: Optional[torch.LongTensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + images: Optional[torch.FloatTensor] = None, + images_highres: Optional[List[torch.FloatTensor]] = None, + image_sizes: Optional[List[List[int]]] = None, + modalities: Optional[List[str]] = ["image"], + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_speech_vision_text( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + speech_chunks, + speech_wav, + images, + modalities, + image_sizes, + images_highres + ) + + if labels is None: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + else: + return self.forward_llm_efficient( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + + def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_dim = hidden_states.size(-1) + shift_labels = labels[..., 1:].contiguous().reshape(-1) + shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim) + assert shift_labels.size(0) == shift_hidden_states.size(0) + mask = shift_labels > -1 + assert mask.float().sum() > 0 + shift_labels = shift_labels[mask] + shift_hidden_states = shift_hidden_states[mask, :] + logits = self.lm_head(shift_hidden_states) + logits = logits.float() + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits, shift_labels) + + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + speech_chunks: Optional[torch.Tensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + images: Optional[torch.Tensor] = None, + images_highres: Optional[List[torch.FloatTensor]] = None, + image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_speech_vision_text( + inputs, + position_ids, + attention_mask, + None, + None, + speech, + speech_lengths, + speech_chunks, + speech_wav, + images, + modalities, + image_sizes, + images_highres + ) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + speech_chunks = kwargs.pop("speech_chunks", None) + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if 
speech is not None: + inputs['speech'] = speech + inputs['speech_lengths'] = speech_lengths + inputs['speech_chunks'] = speech_chunks + if images is not None: + inputs["images"] = images + if image_sizes is not None: + inputs["image_sizes"] = image_sizes + return inputs + +AutoConfig.register("ola_qwen", OlaConfigQwen) +AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM) diff --git a/opencompass/models/ola/model/multimodal_encoder/builder.py b/opencompass/models/ola/model/multimodal_encoder/builder.py new file mode 100644 index 00000000..154a20b1 --- /dev/null +++ b/opencompass/models/ola/model/multimodal_encoder/builder.py @@ -0,0 +1,9 @@ +import os +from .oryx_vit import SigLIPViTAnysizeWrapper + +def build_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None)) + is_absolute_path_exists = os.path.exists(vision_tower) + print(f"Buiding OryxViTWrapper from {vision_tower}...") + # path = vision_tower.split(":")[1] + return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs) \ No newline at end of file diff --git a/opencompass/models/ola/model/multimodal_encoder/oryx_vit.py b/opencompass/models/ola/model/multimodal_encoder/oryx_vit.py new file mode 100644 index 00000000..1499dcee --- /dev/null +++ b/opencompass/models/ola/model/multimodal_encoder/oryx_vit.py @@ -0,0 +1,1075 @@ +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import ( + Callable, + Dict, + Final, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +from torch.utils.checkpoint import checkpoint +import torch +import torch.nn as nn +import torch.nn.functional as F +try: + from timm.layers import ( + AttentionPoolLatent, + DropPath, + LayerType, + Mlp, + PatchDropout, + PatchEmbed, + resample_abs_pos_embed, + ) + from timm.models._manipulate import checkpoint_seq, named_apply +except: + print('Wrong timm version') + +from flash_attn import flash_attn_func, flash_attn_varlen_func + +from typing import Optional + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed +import os + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) # noqa: E741 + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype. + Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn + from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + + with torch.no_grad(): + dtype = tensor.dtype + tensor_fp32 = tensor.float() + tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) + tensor_dtype = tensor_fp32.to(dtype=dtype) + tensor.copy_(tensor_dtype) + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim**-0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + # self.fused_attn = use_fused_attn() + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if cu_slens is not None: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item() + x = flash_attn_varlen_func( + q.squeeze(0), + k.squeeze(0), + 
v.squeeze(0), + cu_seqlens_q=cu_slens, + cu_seqlens_k=cu_slens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=self.scale, + causal=False, + ) + + x = x.reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + else: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c + + x = x.reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + # if self.fused_attn: + # x = F.scaled_dot_product_attention( + # q, + # k, + # v, + # dropout_p=self.attn_drop.p if self.training else 0.0, + # ) + # else: + # q = q * self.scale + # attn = q @ k.transpose(-2, -1) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = attn @ v + + # x = x.transpose(1, 2).reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + 
patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + strict_img_size: bool = False, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False, + add_patch2x2: bool = False, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Mumber of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ("", "avg", "token", "map") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + strict_img_size=strict_img_size, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + + + # deepspeed.zero.register_external_parameter(self, self.pos_embed) + # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.weight) + # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.bias) + # print(self.patch_embed.state_dict().keys()) + + + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + + + if add_patch2x2: + if add_patch2x2 == 'v2': + self.downsample = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim*2, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(embed_dim*2, embed_dim*4, 1) + ) + else: + mid_dim = embed_dim * 2 + self.downsample = nn.Sequential( + nn.Conv2d(embed_dim, mid_dim, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(mid_dim, mid_dim, 1) + ) + + else: + self.downsample = None + + + # self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # # Classifier Head + # if global_pool == "map": + # AttentionPoolLatent.init_weights = init_weights + # self.attn_pool = AttentionPoolLatent( + # self.embed_dim, + # num_heads=num_heads, + # mlp_ratio=mlp_ratio, + # norm_layer=norm_layer, + # ) + # else: + # self.attn_pool = None + # self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + # self.head_drop = nn.Dropout(drop_rate) + # self.head = ( + # nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + # ) + + # if weight_init != "skip": + # self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def 
reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def rescale_positional_embedding(self, out_size): + h, w = out_size + pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5) + if (h, w) == (pos_embed_shape, pos_embed_shape): + return self.pos_embed + rescaled_positional_embedding = \ + self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2]) + pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape) + pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w) + rescaled_positional_embedding[0] = pe_2d.T.contiguous() + return rescaled_positional_embedding + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set( + range(num_blocks - n, num_blocks) if isinstance(n, int) else n + ) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """Intermediate layer accessor (NOTE: This is a WIP experiment). 
+ Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if in is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features_list(self, x_list): + x_all = [] + image_sizes = [] + for x in x_list: + bs, _, h, w = x.shape + + # fix patch size=14 in datasets + pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0] + pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + bs, _, h, w = x.shape + + h = h // self.patch_embed.patch_size[0] + w = w // self.patch_embed.patch_size[1] + + x = self.patch_embed(x) + # x = self._pos_embed(x) + x = x + self.rescale_positional_embedding(out_size=(h, w)) + x = self.patch_drop(x) + x = self.norm_pre(x) + x_all.append(x) + image_sizes.append((h, w)) + + slen = [xi.size(1) for xi in x_all] + x = torch.cat(x_all, dim=1) + + cu_indices = [0, ] + for i in slen: + cu_indices.append(cu_indices[-1] + i) + + cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device) + for idx, blk in enumerate(self.blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, cu_slens, use_reentrant=True) + else: + x = blk(x, cu_slens=cu_slens) + feats = x.split(slen, dim=1) #[(1, slen, c)] + + if self.downsample is not None: + new_feats = [] + new_sizes = [] + for f, s in zip(feats, image_sizes): + h, w = s + b, n, c = f.size() + f = f.reshape(b, h, w, c).permute(0, 3, 1, 2) + f = self.downsample(f) + b, c, h, w = f.size() + f = f.permute(0, 2, 3, 1).reshape(b, h*w, c) + new_feats.append(f) + new_sizes.append((h, w)) + return new_feats, new_sizes + + + return feats, image_sizes + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + bs, _, h, w = x.shape + h = h // self.patch_embed.patch_size[0] + w = w // self.patch_embed.patch_size[1] + + x = self.patch_embed(x) + # x = self._pos_embed(x) + x = x + self.rescale_positional_embedding(out_size=(h, w)) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + if self.downsample is not None: + b, n, c = x.size() + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + x = self.downsample(x) + b, c, h, w = x.size() + x = x.permute(0, 2, 3, 1).reshape(b, h*w, c) + new_feats = x + new_sizes = (h, w) + return new_feats, new_sizes + + return x, (h, w) + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.norm(x) + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == "avg": + x = x[:, self.num_prefix_tokens :].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x, cal_attn_pool=False): + # import pdb;pdb.set_trace() + if type(x) is list: + x, image_sizes = 
self.forward_features_list(x) + return x, image_sizes, None + else: + x, image_sizes = self.forward_features(x) + return x, image_sizes, None + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 384, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, + }, +} + + +def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'): + # interpolate position embedding + orig_size = 24 + new_size = 128 + pos_tokens = model.pos_embed + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True) + return model + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + path: str = "", + gradient_checkpointing: bool = False, + **kwargs, +): + assert ( + model_name in SigLIP_MODEL_CONFIG.keys() + ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + + + if 'patch2x2' or 'patch4x4' in path: + add_patch2x2 = True + else: + add_patch2x2 = False + + if 'patch4x4pool' in path or 'patch2x2from4x4' in path: + add_patch2x2 = 'v2' + + if FORCE_NO_DOWNSAMPLE: + add_patch2x2 = False + + model = VisionTransformer( + img_size=2048, + patch_size=16, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + dynamic_img_pad=False, + strict_img_size=False, + ignore_head=kwargs.get("ignore_head", False), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0, + add_patch2x2=add_patch2x2 + ) + + if gradient_checkpointing: + model.set_grad_checkpointing(True) + return model + +import os +if 'LOAD_VISION_EARLY' in os.environ: + print("LOAD_VISION_EARLY is set") + LOAD_VISION_EARLY = True +else: + LOAD_VISION_EARLY = False + +if 'VIT_WITH_GRAD' in os.environ: + print("VIT_WITH_GRAD is set") + VIT_WITH_GRAD = True +else: + VIT_WITH_GRAD = False + +if 'FIX_SIZE' in os.environ: + print("FIX_SIZE is set") + FIX_SIZE = True +else: + FIX_SIZE = False + +if 
'ANYRES_SPLIT' in os.environ: + ANYRES_SPLIT = int(os.environ['ANYRES_SPLIT']) + print(f"ANYRES_SPLIT is set as {ANYRES_SPLIT}") +else: + ANYRES_SPLIT = None + + +if 'FORCE_NO_DOWNSAMPLE' in os.environ: + print("FORCE_NO_DOWNSAMPLE is set") + FORCE_NO_DOWNSAMPLE = True +else: + FORCE_NO_DOWNSAMPLE = False + +from transformers import CLIPImageProcessor +import torch.distributed as dist + +class SigLIPViTAnysizeWrapper(nn.Module): + def __init__(self, vision_tower, path, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.args = args + self.path = path + + self.select_layer = -1 + if self.select_layer < -1: self.select_layer += 1 + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + + self.output_dim = 1152 + if not FORCE_NO_DOWNSAMPLE: + if 'patch2x2' or 'patch4x4' in path: + self.output_dim = 1152*2 + + if 'patch4x4pool' in path or 'patch2x2from4x4' in path: + self.output_dim = 1152*4 + + if not delay_load or LOAD_VISION_EARLY: + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + if self.args.mm_projector_type == "conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp_woconv": + self.image_processor.crop_size['height'] = 384 + self.image_processor.crop_size['width'] = 384 + self.image_processor.size['shortest_edge'] = 384 + print("Resizeing clip processor to 384...") + self.image_processor.image_mean = [0.5, 0.5, 0.5] + self.image_processor.image_std = [0.5, 0.5, 0.5] + print("Loading vision model...") + if VIT_WITH_GRAD: + self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384', + gradient_checkpointing=True) + self.vision_tower.train() + else: + self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384', + gradient_checkpointing=False) + for p in self.vision_tower.parameters(): + p.requires_grad = False + self.vision_tower.eval() + self.is_loaded = True + + def train(self, mode = True): + self.training = mode + + if self.is_loaded and not VIT_WITH_GRAD: + self.vision_tower.eval() + + def split_images(self, images, split_res=512, base_size=32): + split_images = [] + sub_images_info = [] + for image in images: + now_sub_images = [] + _, c, h, w = image.shape + if h * w <= split_res * split_res: + split_images.append(image) + sub_images_info.append( + ( + 1, 1, 1, h // base_size, w // base_size, [(0, h // base_size, 0, w // base_size)] + ) + ) + continue + nsplit_h = math.ceil(h / split_res) + nsplit_w = math.ceil(w / split_res) + sub_h = int(h / nsplit_h / base_size) * base_size + sub_w = int(w / nsplit_w / base_size) * base_size + crop_infos = [] + for i in range(nsplit_h): + for j in range(nsplit_w): + begin_h = i * sub_h + begin_w = j * sub_w + + if i == nsplit_h - 1: + end_h = h + else: + end_h = (i + 1) * sub_h + + if j == nsplit_w - 1: + end_w = w + else: + end_w = (j + 1) * sub_w + + assert (end_h - begin_h) % base_size == 0 and (end_w - begin_w) % base_size == 0 + + sub_image = image[:, :, begin_h:end_h, 
begin_w:end_w] + now_sub_images.append(sub_image) + crop_infos.append( + (begin_h // base_size, end_h // base_size, begin_w // base_size, end_w // base_size) + ) + + split_images += now_sub_images + sub_images_info.append( + ( + len(now_sub_images), nsplit_h, nsplit_w, h // base_size, w // base_size, crop_infos + ) + ) + + return split_images, sub_images_info + + + def unsplit_images(self, features, sizes, sub_images_info): + new_features = [] + for feature, size in zip(features, sizes): + h, w = size + new_features.append( + feature.reshape(1, h, w, -1) + ) + + fused_images = [] + images_sizes = [] + sub_count = 0 + for n_split, nsplit_h, nsplit_w, total_h, total_w, crop_infos in sub_images_info: + sub_features = new_features[sub_count:sub_count+n_split] + sub_count += n_split + + total_feature = new_features[0].new_zeros(1, total_h, total_w, self.hidden_size) + for feature, (begin_h, end_h, begin_w, end_w) in zip(sub_features, crop_infos): + total_feature[:, begin_h:end_h, begin_w:end_w] += feature + + fused_images.append(total_feature.reshape(1, total_h * total_w, self.hidden_size)) + images_sizes.append((total_h, total_w)) + + return fused_images, images_sizes + + + + def forward_func(self, images, force_fix_size=False, cal_attn_pool=False): + if type(images) is list: + xs = [x.to(self.dtype) for x in images] + image_features, img_size, cls_token = self.vision_tower(xs, cal_attn_pool=cal_attn_pool) + image_features = [x.to(images[0].dtype) for x in image_features] + + else: + image_forward_outs, img_size, cls_token = self.vision_tower(images.to(self.dtype), cal_attn_pool=cal_attn_pool) + image_features = image_forward_outs.to(images.dtype) + + return image_features, img_size, cls_token + + def forward(self, images, cal_attn_pool=False): + if VIT_WITH_GRAD: + image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool) + return image_features, img_size + else: + with torch.no_grad(): + image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool) + return image_features, img_size + + + @property + def dummy_feature(self): + return torch.zeros(1, 1152, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.pos_embed.dtype + + @property + def device(self): + return self.vision_tower.pos_embed.device + + @property + def hidden_size(self): + return self.output_dim + + @property + def config(self): + return type('LLaVAConfigWrapper', (), { + # 'image_size': 224, + 'patch_size': 16, + })() diff --git a/opencompass/models/ola/model/multimodal_projector/builder.py b/opencompass/models/ola/model/multimodal_projector/builder.py new file mode 100644 index 00000000..681d933b --- /dev/null +++ b/opencompass/models/ola/model/multimodal_projector/builder.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import re + +import math + +from .pooler_projector import NormalizedDwPooler +import os +import math + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": 'identity'} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels), + nn.GELU(), + nn.Linear(channels, channels) + ) + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + +class OlaMLP(nn.Module): + def __init__(self, in_channels, 
out_channels, twoview=False): + super().__init__() + + self.proj1 = nn.Linear(in_channels, out_channels) + self.proj2 = nn.Linear(out_channels, out_channels) + self.act = nn.GELU() + self.pooler = NormalizedDwPooler(out_channels) + + embed_std = 1 / math.sqrt(out_channels) + self.image_newline = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + self.image_begin = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + self.image_end = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + + if twoview: + self.image_sep = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + + def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'): + + if modalities in ['image', 'text']: + h, w = size + dtype = x.dtype + x = x.reshape(x.shape[0], h, w, -1) + x = self.proj1(x) + x = self.pooler(x, forward_type='2x') + x = self.act(x) + x = self.proj2(x) + + + b, h, w, c = x.shape + x = torch.cat([ + x, + self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype) + ], dim=2) + x = x.reshape(b, -1, c) + + if x2 is not None: + h2, w2 = size2 + x2 = x2.reshape(x2.shape[0], h2, w2, -1) + x2 = self.proj1(x2) + x2 = self.pooler(x2, forward_type='2x') + x2 = self.act(x2) + x2 = self.proj2(x2) + + b2, h2, w2, c2 = x2.shape + x2 = torch.cat([ + x2, + self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype) + ], dim=2) + x2 = x2.reshape(b, -1, c) + sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype) + x = torch.cat([x, sep, x2], dim=1) + + begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype) + end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype) + x = torch.cat([begin, x, end], dim=1) + return x + elif modalities in ['video']: + # x2 is the true feature, ignore x + h, w = size + dtype = x.dtype + x = x.reshape(x.shape[0], h, w, -1) + x1 = self.proj1(x) + x1 = self.pooler(x1, forward_type='2x') + x1 = self.proj2(x1).mean() * 0.0 + + h2, w2 = size2 + x2 = x2.reshape(x2.shape[0], h2, w2, -1) + x2 = self.proj1(x2) + x2 = self.pooler(x2, forward_type='2x') + x2 = self.act(x2) + x2 = self.proj2(x2) + + b2, h2, w2, c = x2.shape + x2 = torch.cat([ + x2, + self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype) + ], dim=2) + + x2 = x2.reshape(b2, -1, c) + + sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype) + x2 = torch.cat([x2, sep], dim=1) + + x2 = x2.flatten(0, 1) + + begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype) + end = self.image_end.reshape(1, -1).expand(1, c).to(dtype) + x2 = torch.cat([begin, x2, end], dim=0) + x2 = x2.unsqueeze(0) + return x2 + else: + raise ValueError(f'Unknown modalities: {modalities}') + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'linear') + + if projector_type == 'linear': + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + elif projector_type == 'ola_mlp': + return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True) + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type) + if mlp_gelu_resnet_match: + mlp_depth = int(mlp_gelu_resnet_match.group(1)) + res_depth 
= int(mlp_gelu_resnet_match.group(2)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + for _ in range(res_depth): + modules.append(SimpleResBlock(config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') diff --git a/opencompass/models/ola/model/multimodal_projector/pooler_projector.py b/opencompass/models/ola/model/multimodal_projector/pooler_projector.py new file mode 100644 index 00000000..8e8aaaf7 --- /dev/null +++ b/opencompass/models/ola/model/multimodal_projector/pooler_projector.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from transformers.models.clip.modeling_clip import CLIPVisionModel +import os + +class PoolerProjector(nn.Module): + def __init__(self, config, vision_cfg): + super().__init__() + self._config = config + self.hw = vision_cfg.image_size // vision_cfg.patch_size + + self.conv_pool = nn.Conv2d( + config.mm_hidden_size, config.hidden_size, + kernel_size=2, stride=2 + ) + + self.proj = nn.Sequential( + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + + def forward(self, x, *args, **kwargs): + height = width = self.hw + assert height * width == x.shape[1] + x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) + x = self.conv_pool(x) + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + @property + def config(self): + return {"mm_projector_type": 'pooler'} + + +class NormalizedDwPooler(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + self.predictor = nn.Sequential( + nn.Linear(dim*2, dim), + nn.GELU(), + nn.Linear(dim, dim), + ) + + def forward(self, x, forward_type='2x'): + B, H, W, C = x.shape + + if forward_type == '2x': + new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C) + pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1) + fused_x = torch.cat([new_x, pooled_x], dim=-1) + elif forward_type == '1x': + new_x = x.reshape(B, H, W, 1, C) + fused_x = torch.cat([new_x, new_x], dim=-1) + elif forward_type == '4x': + new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C) + pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1) + fused_x = torch.cat([new_x, pooled_x], dim=-1) + + score = self.predictor(fused_x) + normalized_score = F.softmax(score, dim=-2) + new_x = (new_x * normalized_score).sum(dim=-2) + return new_x diff --git a/opencompass/models/ola/model/multimodal_resampler/builder.py b/opencompass/models/ola/model/multimodal_resampler/builder.py new file mode 100644 index 00000000..994f816c --- /dev/null +++ b/opencompass/models/ola/model/multimodal_resampler/builder.py @@ -0,0 +1,20 @@ +import torch + +class IdentityMap(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_resampler_type": None} + +def build_vision_resampler(model_args, delay_load=False, **kwargs): + # import pdb;pdb.set_trace() + resampler_type = getattr(model_args, 'mm_resampler_type', None) + if resampler_type is None: + return IdentityMap() + else: + raise ValueError(f'Unknown resampler type: {resampler_type}') diff --git a/opencompass/models/ola/model/ola_arch.py 
b/opencompass/models/ola/model/ola_arch.py new file mode 100644 index 00000000..c91db5d0 --- /dev/null +++ b/opencompass/models/ola/model/ola_arch.py @@ -0,0 +1,397 @@ +from abc import ABC, abstractmethod + +import torch + +from .speech_encoder.builder import build_speech_encoder +from .speech_projector.builder import build_speech_projector +from opencompass.models.ola.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX +from opencompass.models.ola.utils import lengths_to_padding_mask + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_resampler.builder import build_vision_resampler +from .multimodal_projector.builder import build_vision_projector + +from opencompass.models.ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +class OlaMetaModel: + + def __init__(self, config): + super(OlaMetaModel, self).__init__(config) + if hasattr(config, "speech_encoder"): + self.speech_encoder = build_speech_encoder(config) + self.speech_projector = build_speech_projector(config) + + if hasattr(config, "mm_vision_tower"): + self.vision_tower = build_vision_tower(config, delay_load=True) + self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower) + self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) + + def get_speech_encoder(self): + speech_encoder = getattr(self, 'speech_encoder', None) + if type(speech_encoder) is list: + speech_encoder = speech_encoder[0] + return speech_encoder + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_speech_modules(self, model_args, fsdp=None): + self.config.speech_encoder = getattr(model_args, "speech_encoder", None) + self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None) + self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear') + self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5) + self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280) + self.config.music_encoder = getattr(model_args, 'music_encoder', None) + + if self.get_speech_encoder() is None: + speech_encoder = build_speech_encoder(self.config) + if fsdp is not None and len(fsdp) > 0: + self.speech_encoder = [speech_encoder] + else: + self.speech_encoder = speech_encoder + + if getattr(self, 'speech_projector', None) is None: + self.speech_projector = build_speech_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.speech_projector.parameters(): + p.requires_grad = True + + if model_args.pretrain_speech_projector is not None: + pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + print('Loading pretrain speech projector weights') + + msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False) + print(msg) + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + + self.config.mm_vision_tower = 
vision_tower + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) + ## Get the mm_spatial_pool_mode and mm_spatial_pool_stride + for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() # weights already loaded, so skip + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size) + + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + + if getattr(self, 'mm_projector', None) is None: + self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) + else: + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + print('Loading pretrain mm projector weights') + incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False) + print(incompatible_keys) + +class OlaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_speech_encoder(self): + return self.get_model().get_speech_encoder() + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_speech_projector(self): + return self.get_model().speech_projector + + def encode_speech(self, speech, speech_lengths, speech_wav): + # import pdb; pdb.set_trace() + speech_encoder_type = self.config.speech_encoder_type + speech_encoder = self.get_speech_encoder() + if "whisper" in speech_encoder_type.lower(): + encoder_outs = speech_encoder(speech.permute(0, 2, 1)) + speech_lengths = (speech_lengths + 1) // 2 + else: + encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=speech_wav) + speech_lengths = (speech_lengths + 1) // 2 + speech_projector_type = self.config.speech_projector_type + speech_projector = self.get_speech_projector() + if speech_projector_type == "linear": + encoder_outs = speech_projector(encoder_outs) + speech_lengths = speech_lengths // speech_projector.k + else: + raise ValueError(f'Unknown speech projector: {speech_projector_type}') + # speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))] + return encoder_outs + + def prepare_inputs_labels_for_speech_vision_text( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, images_highres=None + ): + speech_encoder = self.get_speech_encoder() + vision_tower = 
self.get_vision_tower() + + if speech_encoder is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if vision_tower is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + # encode speech + if not isinstance(speech, list): + speech = torch.split(speech, speech_chunks.tolist(), dim=0) + speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0) + speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0) + speech_features = [] + for idx in range(len(speech)): + speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx])) + + # encode vision + if isinstance(modalities, str): + modalities = [modalities] + + video_idx_in_batch = [] + for modal in range(len(modalities)): + if 'video' in modalities[modal]: + video_idx_in_batch.append(modal) + + aimg = images[-1] + lowres_img = [] + for idx, img_feat in enumerate(images): + if idx in video_idx_in_batch: + img_feat = aimg.new(1, 3, 128, 128).fill_(0) + lowres_img.append(img_feat) + + lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img) + highres_img_features = [] + highres_img_sizes = [] + for idx, img_feat in enumerate(images_highres): + if img_feat.ndim == 5: + img_feat = img_feat.squeeze(1) + highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat) + highres_img_features.append(highres_img_feature) + highres_img_sizes.append(highres_img_size) + image_features = [] + for idx in range(len(modalities)): + img_feat = self.get_model().mm_projector(lowres_img_features[idx], + lowres_img_sizes[idx], + highres_img_features[idx], + highres_img_sizes[idx], + modalities[idx]) + image_features.append(img_feat.flatten(0, 1)) + + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + + num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + + if num_speech_images == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_images_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + cur_image_idx += 1 + continue + speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + 
[cur_input_ids.shape[0]] + + cur_input_ids_nospeech_image = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech_image = [] + for i in range(len(speech_image_token_indices) - 1): + cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_nospeech_image] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image)) + cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_speech_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i]) + cur_new_labels.append(cur_labels_nospeech_image[i]) + if i < num_speech_images: + if i < num_images: + cur_images_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_images_features) + cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + else: + cur_speech_features = speech_features[cur_speech_idx] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + if num_images == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0) + cur_image_idx += 1 + + if num_speech == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0) + cur_speech_idx += 1 + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as speech features can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + 
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Number of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False \ No newline at end of file diff --git a/opencompass/models/ola/model/speech_encoder/beats/BEATs.py b/opencompass/models/ola/model/speech_encoder/beats/BEATs.py new file mode 100644 index 00000000..9bbf36a8 --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/beats/BEATs.py @@ -0,0 +1,182 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +# import torchaudio.compliance.kaldi as ta_kaldi + +from .kaldi import fbank as kaldi_fbank + +from .backbone import ( + TransformerEncoder, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class BEATsConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # patch size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # label predictor + self.finetuned_model: bool = False # whether the model is a fine-tuned model. 
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor + self.predictor_class: int = 527 # target class number for the predictor + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class BEATs(nn.Module): + def __init__( + self, + cfg: BEATsConfig, + ) -> None: + super().__init__() + logger.info(f"BEATs Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + if cfg.finetuned_model: + self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) + self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) + else: + self.predictor = None + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + feature_only=False, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + if not feature_only and self.predictor is not None: + x = self.predictor_dropout(x) + logits = self.predictor(x) + + if padding_mask is not None and padding_mask.any(): + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) + else: + logits = logits.mean(dim=1) + + lprobs = torch.sigmoid(logits) + + return lprobs, padding_mask + else: + return x, padding_mask \ No newline at end of file diff --git a/opencompass/models/ola/model/speech_encoder/beats/Tokenizers.py b/opencompass/models/ola/model/speech_encoder/beats/Tokenizers.py new file mode 100644 index 
00000000..597c8902 --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/beats/Tokenizers.py @@ -0,0 +1,174 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +# import torchaudio.compliance.kaldi as ta_kaldi + +from .kaldi import fbank as kaldi_fbank + +from .backbone import ( + TransformerEncoder, +) +from .quantizer import ( + NormEMAVectorQuantizer, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TokenizersConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # patch size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # quantizer + self.quant_n: int = 1024 # codebook number in quantizer + self.quant_dim: int = 256 # codebook dimension in quantizer + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class Tokenizers(nn.Module): + def __init__( + self, + cfg: TokenizersConfig, + ) -> None: + super().__init__() + logger.info(f"Tokenizers Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + 
self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.quantize = NormEMAVectorQuantizer( + n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99, + ) + self.quant_n = cfg.quant_n + self.quantize_layer = nn.Sequential( + nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), + nn.Tanh(), + nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize + ) + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_labels( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + quantize_input = self.quantize_layer(x) + quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) + + return embed_ind diff --git a/opencompass/models/ola/model/speech_encoder/beats/__init__.py b/opencompass/models/ola/model/speech_encoder/beats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/opencompass/models/ola/model/speech_encoder/beats/backbone.py b/opencompass/models/ola/model/speech_encoder/beats/backbone.py new file mode 100644 index 00000000..ef6ba72a --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/beats/backbone.py @@ -0,0 +1,782 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import numpy as np +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +import torch.nn.functional as F +from torch.nn import LayerNorm, Parameter +from .modules import ( + GradMultiply, + SamePad, + get_activation_fn, + GLU_Linear, + quant_noise, +) + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + 
super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + deep_norm=args.deep_norm, + has_relative_attention_bias=self.relative_position_embedding, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + encoder_layers=args.encoder_layers, + ) + for i in range(args.encoder_layers) + ] + ) + if self.relative_position_embedding: + for i in range(1, args.encoder_layers): + del self.layers[i].self_attn.relative_attention_bias + self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + if args.deep_norm: + deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) + for i in range(args.encoder_layers): + nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) + + self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1) + + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + if self.layer_wise_gradient_decay_ratio != 1.0: + x = GradMultiply.apply(x, 
self.layer_wise_gradient_decay_ratio) + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + deep_norm: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + encoder_layers: int = 0, + ) -> None: + + super().__init__() + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.final_layer_norm = LayerNorm(self.embedding_dim) + + self.deep_norm = deep_norm + if self.deep_norm: + self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) + else: + self.deep_norm_alpha = 1 + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual * self.deep_norm_alpha + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual * self.deep_norm_alpha + x + x = self.final_layer_norm(x) + + return x, attn, 
pos_bias + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if 
bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
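+
+        Example (illustrative sketch only; the tensor sizes and relative-attention
+        settings below are arbitrary, not values shipped with Ola/BEATs)::
+
+            >>> mha = MultiheadAttention(
+            ...     embed_dim=768, num_heads=12, self_attention=True,
+            ...     has_relative_attention_bias=True, num_buckets=320,
+            ...     max_distance=800, gru_rel_pos=True)
+            >>> x = torch.randn(100, 2, 768)          # (time, batch, channel)
+            >>> out, attn_w, pos_bias = mha(query=x, key=x, value=x,
+            ...                             need_weights=False)
+            >>> out.shape, pos_bias.shape
+            (torch.Size([100, 2, 768]), torch.Size([24, 100, 100]))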
+ """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + alpha = 32 + q *= 1 / alpha + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = 
v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) + + attn_weights = attn_weights + attn_mask_rel_pos + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( 
+ key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
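+
+    Example (illustrative only; any ``nn.Module`` works, here a bare Linear)::
+
+        >>> layer = nn.Linear(16, 16)
+        >>> init_bert_params(layer)       # or: model.apply(init_bert_params)
+        >>> float(layer.bias.abs().sum())
+        0.0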
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) \ No newline at end of file diff --git a/opencompass/models/ola/model/speech_encoder/beats/kaldi.py b/opencompass/models/ola/model/speech_encoder/beats/kaldi.py new file mode 100644 index 00000000..f97fa853 --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/beats/kaldi.py @@ -0,0 +1,813 @@ +import math +from typing import Tuple + +import torch +# import torchaudio +from torch import Tensor + +__all__ = [ + "get_mel_banks", + "inverse_mel_scale", + "inverse_mel_scale_scalar", + "mel_scale", + "mel_scale_scalar", + "spectrogram", + "fbank", + "mfcc", + "vtln_warp_freq", + "vtln_warp_mel_freq", +] + +# numeric_limits::epsilon() 1.1920928955078125e-07 +EPSILON = torch.tensor(torch.finfo(torch.float).eps) +# 1 milliseconds = 0.001 seconds +MILLISECONDS_TO_SECONDS = 0.001 + +# window types +HAMMING = "hamming" +HANNING = "hanning" +POVEY = "povey" +RECTANGULAR = "rectangular" +BLACKMAN = "blackman" +WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] + + +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + r"""Returns the smallest power of 2 that is greater than x""" + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor: + r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) + representing how the window is shifted along the waveform. Each row is a frame. + + Args: + waveform (Tensor): Tensor of size ``num_samples`` + window_size (int): Frame length + window_shift (int): Frame shift + snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. 
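+            For example, at 16 kHz with Kaldi's 25 ms / 10 ms defaults
+            (``window_size=400``, ``window_shift=160``), a 16000-sample waveform
+            yields ``1 + (16000 - 400) // 160 = 98`` frames when
+            ``snip_edges=True`` and ``(16000 + 160 // 2) // 160 = 100`` frames
+            when ``snip_edges=False``.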
+ + Returns: + Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame + """ + assert waveform.dim() == 1 + num_samples = waveform.size(0) + strides = (window_shift * waveform.stride(0), waveform.stride(0)) + + if snip_edges: + if num_samples < window_size: + return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = torch.flip(waveform, [0]) + m = (num_samples + (window_shift // 2)) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect' + # but we want [2, 1, 0, 0, 1, 2] + pad_left = reversed_waveform[-pad:] + waveform = torch.cat((pad_left, waveform, pad_right), dim=0) + else: + # pad is negative so we want to trim the waveform at the front + waveform = torch.cat((waveform[-pad:], pad_right), dim=0) + + sizes = (m, window_size) + return waveform.as_strided(sizes, strides) + + +def _feature_window_function( + window_type: str, + window_size: int, + blackman_coeff: float, + device: torch.device, + dtype: int, +) -> Tensor: + r"""Returns a window function with the given type and size""" + if window_type == HANNING: + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) + elif window_type == HAMMING: + return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) + elif window_type == POVEY: + # like hanning but goes to zero at edges + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) + elif window_type == RECTANGULAR: + return torch.ones(window_size, device=device, dtype=dtype) + elif window_type == BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = torch.arange(window_size, device=device, dtype=dtype) + # can't use torch.blackman_window as they use different coefficients + return ( + blackman_coeff + - 0.5 * torch.cos(a * window_function) + + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function) + ).to(device=device, dtype=dtype) + else: + raise Exception("Invalid window type " + window_type) + + +def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor: + r"""Returns the log energy of size (m) for a strided_input (m,*)""" + device, dtype = strided_input.device, strided_input.dtype + log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) + if energy_floor == 0.0: + return log_energy + return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) + + +def _get_waveform_and_window_properties( + waveform: Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, int, int, int]: + r"""Gets the waveform and window properties""" + channel = max(channel, 0) + assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0)) + waveform = waveform[channel, :] # size (n) + window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) + window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) + padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size + + assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format( + window_size, len(waveform) + ) + assert 0 < window_shift, 
"`window_shift` must be greater than 0" + assert padded_window_size % 2 == 0, ( + "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`" + ) + assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]" + assert sample_frequency > 0, "`sample_frequency` must be greater than zero" + return waveform, window_shift, window_size, padded_window_size + + +def _get_window( + waveform: Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, Tensor]: + r"""Gets a window and its log energy + + Returns: + (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + # size (m, window_size) + strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) + + if dither != 0.0: + rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype) + strided_input = strided_input + rand_gauss * dither + + if remove_dc_offset: + # Subtract each row/frame by its mean + row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + if raw_energy: + # Compute the log energy of each row/frame before applying preemphasis and + # window function + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + if preemphasis_coefficient != 0.0: + # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j + offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze( + 0 + ) # size (m, window_size + 1) + strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] + + # Apply window_function to each row/frame + window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze( + 0 + ) # size (1, window_size) + strided_input = strided_input * window_function # size (m, window_size) + + # Pad columns with zero until we reach size (m, padded_window_size) + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = torch.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0 + ).squeeze(0) + + # Compute energy after window function (not the raw one) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: + # subtracts the column mean of the tensor size (m, n) if subtract_mean=True + # it returns size (m, n) + if subtract_mean: + col_means = torch.mean(tensor, dim=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + 
subtract_mean: bool = False, + window_type: str = POVEY, +) -> Tensor: + r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's + compute-spectrogram-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A spectrogram identical to what Kaldi would output. 
The shape is + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0) + + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1, 2) + fft = torch.fft.rfft(strided_input) + + # Convert the FFT into a power spectrum + power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum[:, 0] = signal_log_energy + + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def inverse_mel_scale(mel_freq: Tensor) -> Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def vtln_warp_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: Tensor, +) -> Tensor: + r"""This computes a VTLN warping function that is not the same as HTK's one, + but has similar inputs (this function has the advantage of never producing + empty bins). + + This function computes a warp function F(freq), defined between low_freq + and high_freq inclusive, with the following properties: + F(low_freq) == low_freq + F(high_freq) == high_freq + The function is continuous and piecewise linear with two inflection + points. + The lower inflection point (measured in terms of the unwarped + frequency) is at frequency l, determined as described below. + The higher inflection point is at a frequency h, determined as + described below. + If l <= f <= h, then F(f) = f/vtln_warp_factor. + If the higher inflection point (measured in terms of the unwarped + frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + Since (by the last point) F(h) == h/vtln_warp_factor, then + max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + = vtln_high_cutoff * min(1, vtln_warp_factor). 
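+    For example (with an illustrative cutoff, not a value fixed by this code),
+    if vtln_high_cutoff = 7800 Hz and vtln_warp_factor = 0.9, then
+    h = 7800 * 0.9 = 7020 Hz and F(h) = 7020 / 0.9 = 7800 Hz, so
+    max(h, F(h)) == vtln_high_cutoff as required.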
+ If the lower inflection point (measured in terms of the unwarped + frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + = vtln_low_cutoff * max(1, vtln_warp_factor) + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + freq (Tensor): given frequency in Hz + + Returns: + Tensor: Freq after vtln warp + """ + assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq" + assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]" + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l # F(l) + Fh = scale * h # F(h) + assert l > low_freq and h < high_freq + # slope of left part of the 3-piece linear function + scale_left = (Fl - low_freq) / (l - low_freq) + # [slope of center part is just "scale"] + + # slope of right part of the 3-piece linear function + scale_right = (high_freq - Fh) / (high_freq - h) + + res = torch.empty_like(freq) + + outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq + before_l = torch.lt(freq, l) # freq < l + before_h = torch.lt(freq, h) # freq < h + after_h = torch.ge(freq, h) # freq >= h + + # order of operations matter here (since there is overlapping frequency regions) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + + return res + + +def vtln_warp_mel_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, + high_freq: float, + vtln_warp_factor: float, + mel_freq: Tensor, +) -> Tensor: + r""" + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + mel_freq (Tensor): Given frequency in Mel + + Returns: + Tensor: ``mel_freq`` after vtln warp + """ + return mel_scale( + vtln_warp_freq( + vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq) + ) + ) + + +def get_mel_banks( + num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + vtln_warp_factor: float, +) -> Tuple[Tensor, Tensor]: + """ + Returns: + (Tensor, Tensor): The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). + """ + assert num_bins > 3, "Must have at least 3 mel bins" + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + + if high_freq <= 0.0: + high_freq += nyquist + + assert ( + (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq) + ), "Bad values in options: low-freq {} and high-freq {} vs. 
nyquist {}".format(low_freq, high_freq, nyquist) + + # fft-bin width [think of it as Nyquist-freq / half-window-length] + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + + # divide by num_bins+1 in next line because of end-effects where the bins + # spread out to the sides. + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + if vtln_high < 0.0: + vtln_high += nyquist + + assert vtln_warp_factor == 1.0 or ( + (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high) + ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format( + vtln_low, vtln_high, low_freq, high_freq + ) + + bin = torch.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) + center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1) + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) + + if vtln_warp_factor != 1.0: + left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel) + center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel) + right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel) + + center_freqs = inverse_mel_scale(center_mel) # size (num_bins) + # size(1, num_fft_bins) + mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0) + + # size (num_bins, num_fft_bins) + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + if vtln_warp_factor == 1.0: + # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values + bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + else: + # warping can move the order of left_mel, center_mel, right_mel anywhere + bins = torch.zeros_like(up_slope) + up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel + down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + + return bins, center_freqs + + +def fbank( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's + compute-fbank-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. 
(Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features + (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``) + use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A fbank identical to what Kaldi would output. 
The shape is (m, ``num_mel_bins + use_energy``) + where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0, device=device, dtype=dtype) + + # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1) + spectrum = torch.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.0) + + # size (num_mel_bins, padded_window_size // 2) + mel_energies, _ = get_mel_banks( + num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp + ) + mel_energies = mel_energies.to(device=device, dtype=dtype) + + # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) + mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) + + # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) + mel_energies = torch.mm(spectrum, mel_energies.T) + if use_log_fbank: + # avoid log of zero (which should be prevented anyway by dithering) + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() + + # if use_energy then add it as the last column for htk_compat == true else first column + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) + # returns size (m, num_mel_bins + 1) + if htk_compat: + mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1) + else: + mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) + + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: + # returns a dct matrix of size (num_mel_bins, num_ceps) + # size (num_mel_bins, num_mel_bins) + dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho") + # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) + # this would be the first column in the dct_matrix for torchaudio as it expects a + # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi + # expects a left multiply e.g. dct_matrix * vector). + dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) + dct_matrix = dct_matrix[:, :num_ceps] + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: + # returns size (num_ceps) + # Compute liftering coefficients (scaling on cepstral coeffs) + # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. 
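+    # lifter[i] = 1 + 0.5 * cepstral_lifter * sin(pi * i / cepstral_lifter), so with
+    # Kaldi's default cepstral_lifter=22.0 the scale is 1.0 at i=0 (C0 untouched)
+    # and peaks at 1 + 0.5 * 22 = 12.0 at i=11.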
+ i = torch.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter) + + +def mfcc( + waveform: Tensor, + blackman_coeff: float = 0.42, + cepstral_lifter: float = 22.0, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + num_ceps: int = 13, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's + compute-mfcc-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible + features (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. 
If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``"povey"``) + + Returns: + Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``) + where m is calculated in _get_strided + """ + assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins) + + device, dtype = waveform.device, waveform.dtype + + # The mel_energies should not be squared (use_power=True), not have mean subtracted + # (subtract_mean=False), and use log (use_log_fbank=True). + # size (m, num_mel_bins + use_energy) + feature = fbank( + waveform=waveform, + blackman_coeff=blackman_coeff, + channel=channel, + dither=dither, + energy_floor=energy_floor, + frame_length=frame_length, + frame_shift=frame_shift, + high_freq=high_freq, + htk_compat=htk_compat, + low_freq=low_freq, + min_duration=min_duration, + num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, + snip_edges=snip_edges, + subtract_mean=False, + use_energy=use_energy, + use_log_fbank=True, + use_power=True, + vtln_high=vtln_high, + vtln_low=vtln_low, + vtln_warp=vtln_warp, + window_type=window_type, + ) + + if use_energy: + # size (m) + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + # offset is 0 if htk_compat==True else 1 + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset : (num_mel_bins + mel_offset)] + + # size (num_mel_bins, num_ceps) + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) + + # size (m, num_ceps) + feature = feature.matmul(dct_matrix) + + if cepstral_lifter != 0.0: + # size (1, num_ceps) + lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.to(device=device, dtype=dtype) + + # if use_energy then replace the last column for htk_compat == true else first column + if use_energy: + feature[:, 0] = signal_log_energy + + if htk_compat: + energy = feature[:, 0].unsqueeze(1) # size (m, 1) + feature = feature[:, 1:] # size (m, num_ceps - 1) + if not use_energy: + # scale on C0 (actually removing a scale we previously added that's + # part of one common definition of the cosine transform.) 
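+            # The DCT above weights C0 by sqrt(1/num_mel_bins) (see _get_dct_matrix),
+            # while HTK weights every coefficient, including C0, by sqrt(2/num_mel_bins),
+            # so multiplying by sqrt(2) reproduces HTK's C0 convention.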
+ energy *= math.sqrt(2) + + feature = torch.cat((feature, energy), dim=1) + + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/opencompass/models/ola/model/speech_encoder/beats/modules.py b/opencompass/models/ola/model/speech_encoder/beats/modules.py new file mode 100644 index 00000000..18e2d206 --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/beats/modules.py @@ -0,0 +1,218 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +import torch +from torch import Tensor, nn +import torch.nn.functional as F + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif 
activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module diff --git a/opencompass/models/ola/model/speech_encoder/beats/quantizer.py b/opencompass/models/ola/model/speech_encoder/beats/quantizer.py new file mode 100644 index 00000000..704be4c3 --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/beats/quantizer.py @@ -0,0 +1,215 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: 
https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on VQGAN code bases +# https://github.com/CompVis/taming-transformers +# --------------------------------------------------------' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as distributed + +try: + from einops import rearrange, repeat +except ImportError: + pass + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + self.decay = decay + self.eps = eps + if codebook_init_path == '': + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = l2norm(weight) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f"load init codebook weight from {codebook_init_path}") + codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + # self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + print("Performing Kemans init for codebook") + embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + 
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) + self.weight.data.copy_(embed_normalized) + + +def norm_ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(l2norm(moving_avg.data)) + + +class NormEMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(n_embed)) + if distributed.is_available() and distributed.is_initialized(): + print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") + self.all_reduce_fn = distributed.all_reduce + else: + self.all_reduce_fn = nn.Identity() + + def reset_cluster_size(self, device): + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + # z = rearrange(z, 'b c h w -> b h w c') + # z = z.transpose(1, 2) + z = l2norm(z) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # EMA cluster size + + bins = encodings.sum(0) + self.all_reduce_fn(bins) + + # self.embedding.cluster_size_ema_update(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) 
+ + embed_sum = z_flattened.t() @ encodings + self.all_reduce_fn(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + + embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, + embed_normalized) + norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + # z_q = rearrange(z_q, 'b h w c -> b c h w') + # z_q = z_q.transpose(1, 2) + return z_q, loss, encoding_indices \ No newline at end of file diff --git a/opencompass/models/ola/model/speech_encoder/builder.py b/opencompass/models/ola/model/speech_encoder/builder.py new file mode 100644 index 00000000..75d9ea00 --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/builder.py @@ -0,0 +1,13 @@ +from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder +import torch.nn as nn + +def build_speech_encoder(config): + speech_encoder_type = getattr(config, 'speech_encoder_type', None) + if "whisper" in speech_encoder_type.lower(): + return WhisperWrappedEncoder.load(config) + elif "dual" in speech_encoder_type.lower(): + return DualWrappedEncoder(config) + elif "none" in speech_encoder_type.lower(): + return None + + raise ValueError(f'Unknown speech encoder: {speech_encoder_type}') diff --git a/opencompass/models/ola/model/speech_encoder/speech_encoder.py b/opencompass/models/ola/model/speech_encoder/speech_encoder.py new file mode 100644 index 00000000..4d9050f1 --- /dev/null +++ b/opencompass/models/ola/model/speech_encoder/speech_encoder.py @@ -0,0 +1,74 @@ +import types +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import WhisperFeatureExtractor +import whisper + +from opencompass.models.ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs + +class WhisperWrappedEncoder: + + @classmethod + def load(cls, model_config): + + def replace_layer_norm(module): + from whisper.model import LayerNorm + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder + replace_layer_norm(encoder) + return encoder + +class DualWrappedEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.whisper_model = self.load_whisper(config) + self.beats_model = self.load_beats(config) + + def load_whisper(cls, model_config): + + def replace_layer_norm(module): + from whisper.model import LayerNorm + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder + replace_layer_norm(encoder) + return encoder + + def load_beats(cls, model_config): + beats_path = model_config.music_encoder + print("Loading 
BEATs Model") + beats_ckpt = torch.load(beats_path, map_location='cpu') + beats_cfg = BEATsConfig(beats_ckpt['cfg']) + beats = BEATs(beats_cfg) + beats.load_state_dict(beats_ckpt['model']) + return beats + + def forward(self, x, raw_wav=None, audio_padding_mask=None): + with torch.no_grad(): + self.beats_model = self.beats_model.float() + speech_embeds = self.whisper_model(x) + audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True) + if audio_embeds.size(1) < speech_embeds.size(1): + audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))) + elif audio_embeds.size(1) > speech_embeds.size(1): + speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1))) + speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1) + speech_embeds = speech_embeds.to(torch.bfloat16) + return speech_embeds \ No newline at end of file diff --git a/opencompass/models/ola/model/speech_projector/builder.py b/opencompass/models/ola/model/speech_projector/builder.py new file mode 100644 index 00000000..bf55a32d --- /dev/null +++ b/opencompass/models/ola/model/speech_projector/builder.py @@ -0,0 +1,11 @@ +from .speech_projector import EncoderProjectorConcat + + +def build_speech_projector(config): + projector_type = getattr(config, 'speech_projector_type', 'linear') + if projector_type == 'linear': + return EncoderProjectorConcat(config) + elif projector_type == 'none': + return None + + raise ValueError(f'Unknown projector type: {projector_type}') diff --git a/opencompass/models/ola/model/speech_projector/speech_projector.py b/opencompass/models/ola/model/speech_projector/speech_projector.py new file mode 100644 index 00000000..1e015606 --- /dev/null +++ b/opencompass/models/ola/model/speech_projector/speech_projector.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import math + +class EncoderProjectorConcat(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.speech_encoder_ds_rate + self.encoder_dim = config.speech_encoder_hidden_size + self.llm_dim = config.hidden_size + self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(2048, config.hidden_size) + + embed_std = 1 / math.sqrt(config.hidden_size) + self.speech_newline = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + self.speech_begin = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + self.speech_end = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + x = torch.cat([ + x, + self.speech_newline.reshape(1, 1, -1).expand(batch_size, 1, -1).to(x.dtype) + ], dim=1) + begin = self.speech_begin.reshape(1, -1).to(x.dtype) + end = self.speech_end.reshape(1, -1).to(x.dtype) + x = x.flatten(0, 1) + x = torch.cat([begin, x, end], dim=0) + # x = x.flatten(0, 1) + return x \ No newline at end of file diff --git a/opencompass/models/ola/utils.py b/opencompass/models/ola/utils.py new file mode 100644 index 00000000..0d7e6e07 --- /dev/null +++ b/opencompass/models/ola/utils.py @@ -0,0 +1,213 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. 
Below is the original copyright:
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import torch
+import logging
+import logging.handlers
+import requests  # needed by violates_moderation() below; missing from the original import list
+import transformers
+
+from opencompass.models.ola.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+    global handler
+
+    formatter = logging.Formatter(
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    # Set the format of root handlers
+    if not logging.getLogger().handlers:
+        logging.basicConfig(level=logging.INFO)
+    logging.getLogger().handlers[0].setFormatter(formatter)
+
+    # Redirect stdout and stderr to loggers
+    stdout_logger = logging.getLogger("stdout")
+    stdout_logger.setLevel(logging.INFO)
+    sl = StreamToLogger(stdout_logger, logging.INFO)
+    sys.stdout = sl
+
+    stderr_logger = logging.getLogger("stderr")
+    stderr_logger.setLevel(logging.ERROR)
+    sl = StreamToLogger(stderr_logger, logging.ERROR)
+    sys.stderr = sl
+
+    # Get logger
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(logging.INFO)
+
+    # Add a file handler for all loggers
+    if handler is None:
+        os.makedirs(LOGDIR, exist_ok=True)
+        filename = os.path.join(LOGDIR, logger_filename)
+        handler = logging.handlers.TimedRotatingFileHandler(
+            filename, when='D', utc=True, encoding='UTF-8')
+        handler.setFormatter(formatter)
+
+        for name, item in logging.root.manager.loggerDict.items():
+            if isinstance(item, logging.Logger):
+                item.addHandler(handler)
+
+    return logger
+
+
+class StreamToLogger(object):
+    """
+    Fake file-like stream object that redirects writes to a logger instance.
+    """
+    def __init__(self, logger, log_level=logging.INFO):
+        self.terminal = sys.stdout
+        self.logger = logger
+        self.log_level = log_level
+        self.linebuf = ''
+
+    def __getattr__(self, attr):
+        return getattr(self.terminal, attr)
+
+    def write(self, buf):
+        temp_linebuf = self.linebuf + buf
+        self.linebuf = ''
+        for line in temp_linebuf.splitlines(True):
+            # From the io.TextIOWrapper docs:
+            #   On output, if newline is None, any '\n' characters written
+            #   are translated to the system default line separator.
+            # By default sys.stdout.write() expects '\n' newlines and then
+            # translates them so this is still cross platform.
+ if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def lengths_to_mask(lens): + return ~lengths_to_padding_mask(lens) + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. 
+ """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" \ No newline at end of file diff --git a/opencompass/models/ola_model.py b/opencompass/models/ola_model.py new file mode 100644 index 00000000..4f628c54 --- /dev/null +++ b/opencompass/models/ola_model.py @@ -0,0 +1,141 @@ +import os +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['PAD2STRIDE'] = '1' + +from opencompass.models.base import BaseModel +from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union +import numpy as np +import torch + +from opencompass.models.base import BaseModel, LMTemplateParser +from opencompass.utils.prompt import PromptList +PromptType = Union[PromptList, str] +import sys + +import torch +import re +from PIL import Image +import numpy as np +import transformers +from typing import Dict, Optional, Sequence, List +from opencompass.models.ola.conversation import conv_templates, SeparatorStyle +from opencompass.models.ola.model.builder import load_pretrained_model +from opencompass.models.ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token +from opencompass.models.ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image +from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX +import argparse +import copy + +class OlaModel(BaseModel): + def __init__(self, + path: str, + max_seq_len: int = 2048, + tokenizer_path: Optional[str] = None, + model_config: Optional[str] = None, + meta_template: Optional[Dict] = None): + + + self.template_parser = LMTemplateParser(meta_template) + self.eos_token_id = None + if meta_template and 'eos_token_id' in meta_template: + self.eos_token_id = meta_template['eos_token_id'] + + + tokenizer, model, _, _ = load_pretrained_model(path, None) + model = model.to('cuda').eval() + model = model.bfloat16() + self.tokenizer=tokenizer + self.model=model + self.gen_kwargs = { + "max_new_tokens":1024, + "temperature":0.2, + "top_p":None, + "num_beams":1, + } + + def generate(self, inputs: List[str], max_out_len: int) -> List[str]: + assert len(inputs)==1 # batch=1 + image_path = None + audio_path = None + video_path = None + text = inputs[0] + + + images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)] + images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)] + image_sizes = [(224, 224)] + + + + USE_SPEECH=False + speechs = [] + speech_lengths = [] + 
speech_wavs = []
+        speech_chunks = []
+        speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
+        speech_lengths = [torch.LongTensor([3000]).to('cuda')]
+        speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
+        speech_chunks = [torch.LongTensor([1]).to('cuda')]
+
+        conv_mode = "qwen_1_5"
+        qs = text if text else ''
+        conv = conv_templates[conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
+
+        # Qwen-1.5 pad token id, used only to build the attention mask.
+        pad_token_ids = 151643
+        attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
+
+        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+        keywords = [stop_str]
+        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
+
+        with torch.inference_mode():
+            output_ids = self.model.generate(
+                input_ids,
+                images=images,
+                images_highres=images_highres,
+                image_sizes=image_sizes,
+                modalities=['text'],
+                speech=speechs,
+                speech_lengths=speech_lengths,
+                speech_chunks=speech_chunks,
+                speech_wav=speech_wavs,
+                attention_mask=attention_masks,
+                use_cache=True,
+                stopping_criteria=[stopping_criteria],
+                do_sample=True if self.gen_kwargs["temperature"] > 0 else False,
+                temperature=self.gen_kwargs["temperature"],
+                top_p=self.gen_kwargs["top_p"],
+                num_beams=self.gen_kwargs["num_beams"],
+                # honour the max_out_len requested by the evaluator instead of
+                # always falling back to the hard-coded default
+                max_new_tokens=max_out_len if max_out_len else self.gen_kwargs["max_new_tokens"],
+            )
+        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        out = []
+        for output in outputs:
+            output = output.strip()
+            if output.endswith(stop_str):
+                output = output[:-len(stop_str)]
+            out.append(output)
+        print('prompt---->', prompt)
+        print('out---->', out)
+        return out
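
A minimal smoke-test sketch for the new OlaModel wrapper, not taken from this patch. It assumes a CUDA GPU, the whisper and transformers dependencies installed, and a locally available Ola checkpoint; OLA_CHECKPOINT below is a hypothetical placeholder path. generate() asserts exactly one prompt per call, and max_out_len caps the number of generated tokens.

# Hedged usage sketch. Assumptions: CUDA device present, Ola dependencies installed,
# and OLA_CHECKPOINT pointing at real Ola-7b weights (placeholder, not from the patch).
from opencompass.models.ola_model import OlaModel

OLA_CHECKPOINT = "/path/to/Ola-7b"  # hypothetical placeholder

model = OlaModel(path=OLA_CHECKPOINT, max_seq_len=2048)
# The wrapper only supports a single prompt per call (see the assert in generate()).
outputs = model.generate(["Compute 17 + 25 and answer with the number only."], max_out_len=64)
print(outputs[0])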