[Model] Add new model: Ola

This commit is contained in:
Mr.Li 2025-03-04 23:10:00 +08:00
parent fff2d51440
commit 37b894d4a1
31 changed files with 5810 additions and 0 deletions

10
examples/eval_ola.py Normal file
View File

@ -0,0 +1,10 @@
from mmengine.config import read_base
with read_base():
from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import gsm8k_datasets
from opencompass.configs.datasets.demo.demo_math_chat_gen import math_datasets
from opencompass.configs.models.ola.ola import models as ola_models
datasets = math_datasets
models = ola_models

View File

@ -0,0 +1,12 @@
from opencompass.models import OlaModel
models = [
dict(
type=OlaModel,
path="THUdyh/Ola-7b",
max_seq_len=2048,
abbr='ola',
max_out_len=1024,
batch_size=1,
run_cfg=dict(num_gpus=1),
)
]

View File

@ -35,6 +35,7 @@ from .openai_api import OpenAI # noqa: F401
from .openai_api import OpenAISDK # noqa: F401
from .pangu_api import PanGu # noqa: F401
from .qwen_api import Qwen # noqa: F401
from .ola_model import OlaModel # noqa: F401
from .rendu_api import Rendu # noqa: F401
from .sensetime_api import SenseTime # noqa: F401
from .stepfun_api import StepFun # noqa: F401

View File

@ -0,0 +1,65 @@
import transformers
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=False)
tune_speech_projector: bool = field(default=False)
tune_speech_encoder: bool = field(default=False)
tune_speech_generator_only: bool = field(default=False)
speech_encoder_type: Optional[str] = field(default=None)
speech_encoder: Optional[str] = field(default=None)
pretrain_speech_projector: Optional[str] = field(default=None)
speech_projector_type: Optional[str] = field(default='linear')
speech_encoder_ds_rate: int = 5
speech_encoder_hidden_size: int = 1280
@dataclass
class DataArguments:
data_path: str = field(default=None,
metadata={"help": "Path to the training data."})
is_multimodal: bool = False
input_type: str = field(default="mel")
speech_normalize: bool = False
mel_size: int = 128
has_tgt_units: bool = False
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
freeze_speech_projector: bool = field(default=False)
model_max_length: int = field(
default=512,
metadata={
"help":
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
)
bits: int = field(
default=16,
metadata={"help": "How many bits to use."}
)
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
speech_projector_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)

View File

@ -0,0 +1,14 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
SPEECH_TOKEN_INDEX = -200
DEFAULT_SPEECH_TOKEN = "<speech>"
IMAGE_TOKEN_INDEX= -300
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

View File

@ -0,0 +1,254 @@
import dataclasses
from enum import auto, Enum
from typing import List, Any, Union, Tuple
import base64
from io import BytesIO
from PIL import Image
class SeparatorStyle(Enum):
"""Different separator style."""
TWO = auto()
PLAIN = auto()
CHATML = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()
QWEN2 = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.PLAIN
sep: str = "###"
sep2: str = None
version: str = "Unknown"
tokenizer_id: str = ""
tokenizer: Any = None
# Stop criteria (the default one is EOS token)
stop_str: Union[str, List[str]] = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message = message[0]
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.LLAMA_3:
wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
ret = "<|begin_of_text|>" + wrap_sys(self.system)
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message = message[0]
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
ret += message.strip() + self.sep2
else:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
return ret
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0:
message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
elif self.sep_style == SeparatorStyle.CHATML:
ret = "" if self.system == "" else self.system + self.sep + "\n"
for role, message in messages:
if message:
if type(message) is tuple:
raise ValueError("Tuple not supported in CHATML")
message, images = message
message = "<speech>" * len(images) + message
ret += role + "\n" + message + self.sep + "\n"
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.QWEN2:
start = '<|im_start|>'
end = '<|im_end|>\n'
ret = start + 'system\n' + self.system + end
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
if message.endswith('<|endoftext|>'):
message = message.replace('<|endoftext|>', '')
ret += start + role + "\n" + message + end + '<|endoftext|>'
else:
assert not '<|endoftext|>' in message, f"Invalid message: {message}"
ret += start + role + "\n" + message + end
else:
ret += start + role + "\n"
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
msg, speech = msg
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llama_2 = Conversation(
system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_llama_3 = Conversation(
system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
roles=("user", "assistant"),
version="llama_v3",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_3,
sep="",
sep2="<|eot_id|>"
)
conv_qwen_v1 = Conversation(
system="You are a helpful assistant.",
roles=("user", "assistant"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.QWEN2,
)
conv_plain = Conversation(
system="",
roles=("", ""),
messages=(
),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="</s>",
)
conv_qwen = Conversation(
system="""<|im_start|>system
You are a helpful assistant.""",
roles=("<|im_start|>user", "<|im_start|>assistant"),
version="qwen",
messages=[],
offset=0,
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
default_conversation = conv_llama_3
conv_templates = {
"v1": conv_vicuna_v1,
"plain": conv_plain,
"llama_2": conv_llama_2,
"llama_3": conv_llama_3,
'v1_qwen2': conv_qwen_v1,
"qwen_1_5": conv_qwen,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())

View File

@ -0,0 +1,231 @@
import copy
import torch
import transformers
import tokenizers
from typing import Dict, Sequence
from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX
from opencompass.models.ola import conversation as conversation_lib
from opencompass.models.ola.model import *
from opencompass.models.ola.arguments import DataArguments
from opencompass.models.ola.constants import SPEECH_TOKEN_INDEX
from packaging import version
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech><image>')]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>\nUser's question in speech: <speech>\n")]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
nl_tokens = tokenizer("\n").input_ids
special_chunks = [image_token_index, nl_tokens, tokenizer("User's question in speech: ").input_ids, speech_token_idx, nl_tokens]
for x in insert_separator(prompt_chunks, [special_chunks] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
def preprocess_v1(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_speech: bool = False
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_speech:
input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_speech:
round_len = len(tokenizer_speech_token(rou, tokenizer))
instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
# FIXME: tokenizer bug
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
round_len -= 1
instruction_len -= 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_SPEECH_TOKEN in source[0]['value']
source[0]['value'] = DEFAULT_SPEECH_TOKEN
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_speech: bool = False
) -> Dict:
"""
Given a list of sources, each is a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.version.startswith("v1"):
return preprocess_v1(sources, tokenizer, has_speech=has_speech)
raise NotImplementedError

View File

@ -0,0 +1,271 @@
from PIL import Image
import base64
import math
import ast
import torch
from transformers import StoppingCriteria
import os
import io
if 'VIDEO_RESIZE' in os.environ:
# highresxpatch
VIDEO_RESIZE = os.environ['VIDEO_RESIZE']
video_base, video_ps = VIDEO_RESIZE.split('x')
video_base = int(video_base)
video_ps = int(video_ps)
print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}")
else:
HIGHRES_BASE = None
if 'HIGHRES_BASE' in os.environ:
# highresxpatch
HIGHRES_BASE = os.environ['HIGHRES_BASE']
highres_base, highres_ps = HIGHRES_BASE.split('x')
highres_base = int(highres_base)
highres_ps = int(highres_ps)
print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}")
else:
HIGHRES_BASE = None
if 'MAXRES' in os.environ:
# highresxpatch
MAXRES = int(os.environ['MAXRES'])
print(f"MAXRES is set as {MAXRES}")
else:
MAXRES = 1536
if 'MINRES' in os.environ:
# highresxpatch
MINRES = int(os.environ['MINRES'])
print(f"MINRES is set as {MINRES}")
else:
MINRES = 0
if 'VIDEO_MAXRES' in os.environ:
# highresxpatch
VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES'])
print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}")
else:
VIDEO_MAXRES = 1536
if 'VIDEO_MINRES' in os.environ:
# highresxpatch
VIDEO_MINRES = int(os.environ['VIDEO_MINRES'])
print(f"VIDEO_MINRES is set as {VIDEO_MINRES}")
else:
MINRES = 0
if 'PAD2STRIDE' in os.environ:
# highresxpatch
PAD2STRIDE = True
print(f"PAD2STRIDE is set")
else:
PAD2STRIDE = False
if 'LOWRES_RESIZE' in os.environ:
LOWRES_RESIZE = os.environ['LOWRES_RESIZE']
print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}")
if 'x' in LOWRES_RESIZE:
size, ps = LOWRES_RESIZE.split('x')
size = int(size)
ps = int(ps)
LOWRES_RESIZE = (size, ps)
else:
LOWRES_RESIZE = int(LOWRES_RESIZE)
else:
LOWRES_RESIZE = None
def pad_image(image, target_resolution, value=0):
"""
Resize and pad an image to a target resolution while maintaining aspect ratio.
Args:
image (PIL.Image.Image): The input image.
target_resolution (tuple): The target resolution (width, height) of the image.
Returns:
PIL.Image.Image: The resized and padded image.
"""
original_width, original_height = image.size
target_width, target_height = target_resolution
# Create a new image with the target size and paste the resized image onto it
new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
paste_x = (target_width - original_width) // 2
paste_y = (target_height - original_height) // 2
new_image.paste(image, (paste_x, paste_y))
return new_image
def resize_images(image, patch_size=14, base_size=896):
h, w = image.size
if base_size == 0:
if h * w > MAXRES * MAXRES:
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
scale = MAXRES * MAXRES / (h * w)
scale = math.sqrt(scale)
elif h * w < MINRES * MINRES:
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
scale = MINRES * MINRES / (h * w)
scale = math.sqrt(scale)
else:
scale = None
else:
scale = base_size * base_size / (h * w)
scale = math.sqrt(scale)
if scale is not None:
new_h = int(h * scale / patch_size) * patch_size
new_w = int(w * scale / patch_size) * patch_size
new_h = max(new_h, patch_size)
new_w = max(new_w, patch_size)
image = image.resize((new_h, new_w))
elif PAD2STRIDE:
if h % patch_size == 0:
new_h = h
else:
new_h = (h // patch_size + 1) * patch_size
if w % patch_size == 0:
new_w = w
else:
new_w = (w // patch_size + 1) * patch_size
image = pad_image(image, (new_h, new_w), value=127)
else:
scale = 1.0
new_h = int(h * scale / patch_size) * patch_size
new_w = int(w * scale / patch_size) * patch_size
new_h = max(new_h, patch_size)
new_w = max(new_w, patch_size)
image = image.resize((new_h, new_w))
return image
def resize_video(image, patch_size=14, base_size=896):
h, w = image.size
if base_size == 0:
if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
scale = math.sqrt(scale)
elif h * w < VIDEO_MINRES * VIDEO_MINRES:
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
scale = math.sqrt(scale)
else:
scale = None
else:
scale = base_size * base_size / (h * w)
scale = math.sqrt(scale)
if scale is not None:
new_h = int(h * scale / patch_size) * patch_size
new_w = int(w * scale / patch_size) * patch_size
image = image.resize((new_h, new_w))
elif PAD2STRIDE:
if h % patch_size == 0:
new_h = h
else:
new_h = (h // patch_size + 1) * patch_size
if w % patch_size == 0:
new_w = w
else:
new_w = (w // patch_size + 1) * patch_size
image = pad_image(image, (new_h, new_w), value=127)
else:
scale = 1.0
new_h = int(h * scale / patch_size) * patch_size
new_w = int(w * scale / patch_size) * patch_size
image = image.resize((new_h, new_w))
return image
def process_anyres_video(image, processor):
if VIDEO_RESIZE is not None:
image = resize_video(image, patch_size=video_ps, base_size=video_base)
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
return image.unsqueeze(0)
else:
raise ValueError("VIDEO_RESIZE is not set")
def process_anyres_highres_image(image, processor):
processor2 = None
if type(processor) is tuple:
processor, processor2 = processor[0], processor[1]
if HIGHRES_BASE is not None:
image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
if processor2 is not None:
image_original_resize = image.resize((processor2.size['shortest_edge'], processor.size['shortest_edge']))
image_patches = [image_original_resize] + [image_original_resize]
image_patches = [processor2.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
for image_patch in image_patches]
else:
if LOWRES_RESIZE is not None:
if type(LOWRES_RESIZE) is int:
image_original_resize = resize_images(image, patch_size=14, base_size=LOWRES_RESIZE)
else:
image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
else:
image_original_resize = image.resize((336, 336))
image_patches = [image_original_resize]
image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
for image_patch in image_patches]
image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
def read_image_patch(patch_info):
if 'img_path' in patch_info.keys():
image = Image.open(patch_info['img_path']).convert('RGB')
else:
if 'image_encoing' in patch_info.keys():
patch_info['image_encoding'] = patch_info['image_encoing']
image_file_name = patch_info['patch']
start_bytes = int(patch_info['start_num'])
file_size = int(patch_info['size'])
with open(image_file_name, 'rb') as f:
f.seek(start_bytes)
if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
else:
image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
return image
def get_model_name_from_path(model_path):
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith('checkpoint-'):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
offset = min(output_ids.shape[1] - self.start_len, 3)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False

View File

@ -0,0 +1 @@
from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen

View File

@ -0,0 +1,91 @@
import os
import warnings
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from opencompass.models.ola.model import *
from opencompass.models.ola.model.speech_encoder.builder import build_speech_encoder
def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
if load_8bit:
kwargs['load_in_8bit'] = True
elif load_4bit:
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
else:
kwargs['torch_dtype'] = torch.bfloat16
if use_flash_attn:
kwargs['attn_implementation'] = 'flash_attention_2'
model_cls = OlaQwenForCausalLM
# Load Ola model
if is_lora:
assert model_base is not None, "model_base is required for LoRA models."
from ola.model.language_model.ola_qwen import OlaConfigQwen
lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
print('Loading Ola from base model...')
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
print('Loading additional Ola weights...')
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
if any(k.startswith('model.model.') for k in non_lora_trainables):
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
model.load_state_dict(non_lora_trainables, strict=False)
from peft import PeftModel
print('Loading LoRA weights...')
model = PeftModel.from_pretrained(model, model_path)
print('Merging LoRA weights...')
model = model.merge_and_unload()
print('Model is loaded...')
elif model_base is not None:
print('Loading Ola from base model...')
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
model.load_state_dict(speech_projector_weights, strict=False)
model = model.to(device=device)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = model_cls.from_pretrained(
model_path,
low_cpu_mem_usage=False,
**kwargs
)
model = model.to(device=device)
model.get_model().speech_encoder = build_speech_encoder(model.config)
model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
image_processor = None
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
print("Loading vision tower...")
if not vision_tower.is_loaded:
vision_tower.load_model(device_map=device)
if device != "auto":
vision_tower.to(device="cuda", dtype=torch.bfloat16)
else:
vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
image_processor = vision_tower.image_processor
print("Loading vision tower succeeded.")
if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 16384
return tokenizer, model, image_processor, context_len

View File

@ -0,0 +1,237 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
class OlaConfigQwen(Qwen2Config):
model_type = "ola_qwen"
class OlaQwenModel(OlaMetaModel, Qwen2Model):
config_class = OlaConfigQwen
def __init__(self, config: Qwen2Config):
super(OlaQwenModel, self).__init__(config)
class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM):
config_class = OlaConfigQwen
def __init__(self, config):
super(Qwen2ForCausalLM, self).__init__(config)
config.rope_scaling = None
self.model = OlaQwenModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
speech: Optional[torch.FloatTensor] = None,
speech_lengths: Optional[torch.LongTensor] = None,
speech_chunks: Optional[torch.LongTensor] = None,
speech_wav: Optional[torch.FloatTensor] = None,
images: Optional[torch.FloatTensor] = None,
images_highres: Optional[List[torch.FloatTensor]] = None,
image_sizes: Optional[List[List[int]]] = None,
modalities: Optional[List[str]] = ["image"],
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_speech_vision_text(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
speech,
speech_lengths,
speech_chunks,
speech_wav,
images,
modalities,
image_sizes,
images_highres
)
if labels is None:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
else:
return self.forward_llm_efficient(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_dim = hidden_states.size(-1)
shift_labels = labels[..., 1:].contiguous().reshape(-1)
shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
assert shift_labels.size(0) == shift_hidden_states.size(0)
mask = shift_labels > -1
assert mask.float().sum() > 0
shift_labels = shift_labels[mask]
shift_hidden_states = shift_hidden_states[mask, :]
logits = self.lm_head(shift_hidden_states)
logits = logits.float()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
speech: Optional[torch.Tensor] = None,
speech_lengths: Optional[torch.Tensor] = None,
speech_chunks: Optional[torch.Tensor] = None,
speech_wav: Optional[torch.FloatTensor] = None,
images: Optional[torch.Tensor] = None,
images_highres: Optional[List[torch.FloatTensor]] = None,
image_sizes: Optional[torch.Tensor] = None,
modalities: Optional[List[str]] = ["image"],
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
_
) = self.prepare_inputs_labels_for_speech_vision_text(
inputs,
position_ids,
attention_mask,
None,
None,
speech,
speech_lengths,
speech_chunks,
speech_wav,
images,
modalities,
image_sizes,
images_highres
)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
speech = kwargs.pop("speech", None)
speech_lengths = kwargs.pop("speech_lengths", None)
speech_chunks = kwargs.pop("speech_chunks", None)
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if speech is not None:
inputs['speech'] = speech
inputs['speech_lengths'] = speech_lengths
inputs['speech_chunks'] = speech_chunks
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
AutoConfig.register("ola_qwen", OlaConfigQwen)
AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM)

View File

@ -0,0 +1,9 @@
import os
from .oryx_vit import SigLIPViTAnysizeWrapper
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None))
is_absolute_path_exists = os.path.exists(vision_tower)
print(f"Buiding OryxViTWrapper from {vision_tower}...")
# path = vision_tower.split(":")[1]
return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,172 @@
import torch
import torch.nn as nn
import re
import math
from .pooler_projector import NormalizedDwPooler
import os
import math
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": 'identity'}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
class OlaMLP(nn.Module):
def __init__(self, in_channels, out_channels, twoview=False):
super().__init__()
self.proj1 = nn.Linear(in_channels, out_channels)
self.proj2 = nn.Linear(out_channels, out_channels)
self.act = nn.GELU()
self.pooler = NormalizedDwPooler(out_channels)
embed_std = 1 / math.sqrt(out_channels)
self.image_newline = nn.Parameter(
torch.randn(out_channels) * embed_std
)
self.image_begin = nn.Parameter(
torch.randn(out_channels) * embed_std
)
self.image_end = nn.Parameter(
torch.randn(out_channels) * embed_std
)
if twoview:
self.image_sep = nn.Parameter(
torch.randn(out_channels) * embed_std
)
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
if modalities in ['image', 'text']:
h, w = size
dtype = x.dtype
x = x.reshape(x.shape[0], h, w, -1)
x = self.proj1(x)
x = self.pooler(x, forward_type='2x')
x = self.act(x)
x = self.proj2(x)
b, h, w, c = x.shape
x = torch.cat([
x,
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
], dim=2)
x = x.reshape(b, -1, c)
if x2 is not None:
h2, w2 = size2
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
x2 = self.proj1(x2)
x2 = self.pooler(x2, forward_type='2x')
x2 = self.act(x2)
x2 = self.proj2(x2)
b2, h2, w2, c2 = x2.shape
x2 = torch.cat([
x2,
self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
], dim=2)
x2 = x2.reshape(b, -1, c)
sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
x = torch.cat([x, sep, x2], dim=1)
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
x = torch.cat([begin, x, end], dim=1)
return x
elif modalities in ['video']:
# x2 is the true feature, ignore x
h, w = size
dtype = x.dtype
x = x.reshape(x.shape[0], h, w, -1)
x1 = self.proj1(x)
x1 = self.pooler(x1, forward_type='2x')
x1 = self.proj2(x1).mean() * 0.0
h2, w2 = size2
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
x2 = self.proj1(x2)
x2 = self.pooler(x2, forward_type='2x')
x2 = self.act(x2)
x2 = self.proj2(x2)
b2, h2, w2, c = x2.shape
x2 = torch.cat([
x2,
self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype)
], dim=2)
x2 = x2.reshape(b2, -1, c)
sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
x2 = torch.cat([x2, sep], dim=1)
x2 = x2.flatten(0, 1)
begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
x2 = torch.cat([begin, x2, end], dim=0)
x2 = x2.unsqueeze(0)
return x2
else:
raise ValueError(f'Unknown modalities: {modalities}')
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'linear':
return nn.Linear(config.mm_hidden_size, config.hidden_size)
elif projector_type == 'ola_mlp':
return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True)
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type)
if mlp_gelu_resnet_match:
mlp_depth = int(mlp_gelu_resnet_match.group(1))
res_depth = int(mlp_gelu_resnet_match.group(2))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
for _ in range(res_depth):
modules.append(SimpleResBlock(config.hidden_size))
return nn.Sequential(*modules)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')

View File

@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.models.clip.modeling_clip import CLIPVisionModel
import os
class PoolerProjector(nn.Module):
def __init__(self, config, vision_cfg):
super().__init__()
self._config = config
self.hw = vision_cfg.image_size // vision_cfg.patch_size
self.conv_pool = nn.Conv2d(
config.mm_hidden_size, config.hidden_size,
kernel_size=2, stride=2
)
self.proj = nn.Sequential(
nn.GELU(),
nn.Linear(config.hidden_size, config.hidden_size),
)
def forward(self, x, *args, **kwargs):
height = width = self.hw
assert height * width == x.shape[1]
x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
x = self.conv_pool(x)
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
@property
def config(self):
return {"mm_projector_type": 'pooler'}
class NormalizedDwPooler(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.predictor = nn.Sequential(
nn.Linear(dim*2, dim),
nn.GELU(),
nn.Linear(dim, dim),
)
def forward(self, x, forward_type='2x'):
B, H, W, C = x.shape
if forward_type == '2x':
new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
fused_x = torch.cat([new_x, pooled_x], dim=-1)
elif forward_type == '1x':
new_x = x.reshape(B, H, W, 1, C)
fused_x = torch.cat([new_x, new_x], dim=-1)
elif forward_type == '4x':
new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
fused_x = torch.cat([new_x, pooled_x], dim=-1)
score = self.predictor(fused_x)
normalized_score = F.softmax(score, dim=-2)
new_x = (new_x * normalized_score).sum(dim=-2)
return new_x

View File

@ -0,0 +1,20 @@
import torch
class IdentityMap(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_resampler_type": None}
def build_vision_resampler(model_args, delay_load=False, **kwargs):
# import pdb;pdb.set_trace()
resampler_type = getattr(model_args, 'mm_resampler_type', None)
if resampler_type is None:
return IdentityMap()
else:
raise ValueError(f'Unknown resampler type: {resampler_type}')

View File

@ -0,0 +1,397 @@
from abc import ABC, abstractmethod
import torch
from .speech_encoder.builder import build_speech_encoder
from .speech_projector.builder import build_speech_projector
from opencompass.models.ola.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX
from opencompass.models.ola.utils import lengths_to_padding_mask
from .multimodal_encoder.builder import build_vision_tower
from .multimodal_resampler.builder import build_vision_resampler
from .multimodal_projector.builder import build_vision_projector
from opencompass.models.ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
class OlaMetaModel:
def __init__(self, config):
super(OlaMetaModel, self).__init__(config)
if hasattr(config, "speech_encoder"):
self.speech_encoder = build_speech_encoder(config)
self.speech_projector = build_speech_projector(config)
if hasattr(config, "mm_vision_tower"):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
def get_speech_encoder(self):
speech_encoder = getattr(self, 'speech_encoder', None)
if type(speech_encoder) is list:
speech_encoder = speech_encoder[0]
return speech_encoder
def get_vision_tower(self):
vision_tower = getattr(self, 'vision_tower', None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def initialize_speech_modules(self, model_args, fsdp=None):
self.config.speech_encoder = getattr(model_args, "speech_encoder", None)
self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None)
self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear')
self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5)
self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280)
self.config.music_encoder = getattr(model_args, 'music_encoder', None)
if self.get_speech_encoder() is None:
speech_encoder = build_speech_encoder(self.config)
if fsdp is not None and len(fsdp) > 0:
self.speech_encoder = [speech_encoder]
else:
self.speech_encoder = speech_encoder
if getattr(self, 'speech_projector', None) is None:
self.speech_projector = build_speech_projector(self.config)
else:
# In case it is frozen by LoRA
for p in self.speech_projector.parameters():
p.requires_grad = True
if model_args.pretrain_speech_projector is not None:
pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu')
def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
print('Loading pretrain speech projector weights')
msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False)
print(msg)
def initialize_vision_modules(self, model_args, fsdp=None):
vision_tower = model_args.vision_tower
mm_vision_select_layer = model_args.mm_vision_select_layer
mm_vision_select_feature = model_args.mm_vision_select_feature
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
self.config.mm_vision_tower = vision_tower
if self.get_vision_tower() is None:
vision_tower = build_vision_tower(model_args)
vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
## Get the mm_spatial_pool_mode and mm_spatial_pool_stride
for k, v in vision_resampler.config.items():
setattr(self.config, k, v)
if fsdp is not None and len(fsdp) > 0:
self.vision_tower = [vision_tower]
self.vision_resampler = [vision_resampler]
else:
self.vision_tower = vision_tower
self.vision_resampler = vision_resampler
else:
if fsdp is not None and len(fsdp) > 0:
vision_resampler = self.vision_resampler[0]
vision_tower = self.vision_tower[0]
else:
vision_resampler = self.vision_resampler
vision_tower = self.vision_tower
vision_tower.load_model() # 已加载权重,故跳过
# In case it is frozen by LoRA
for p in self.vision_resampler.parameters():
p.requires_grad = True
self.config.use_mm_proj = True
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size)
self.config.mm_vision_select_layer = mm_vision_select_layer
self.config.mm_vision_select_feature = mm_vision_select_feature
if getattr(self, 'mm_projector', None) is None:
self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
else:
for p in self.mm_projector.parameters():
p.requires_grad = True
if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
print('Loading pretrain mm projector weights')
incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False)
print(incompatible_keys)
class OlaMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_speech_encoder(self):
return self.get_model().get_speech_encoder()
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def get_speech_projector(self):
return self.get_model().speech_projector
def encode_speech(self, speech, speech_lengths, speech_wav):
# import pdb; pdb.set_trace()
speech_encoder_type = self.config.speech_encoder_type
speech_encoder = self.get_speech_encoder()
if "whisper" in speech_encoder_type.lower():
encoder_outs = speech_encoder(speech.permute(0, 2, 1))
speech_lengths = (speech_lengths + 1) // 2
else:
encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=speech_wav)
speech_lengths = (speech_lengths + 1) // 2
speech_projector_type = self.config.speech_projector_type
speech_projector = self.get_speech_projector()
if speech_projector_type == "linear":
encoder_outs = speech_projector(encoder_outs)
speech_lengths = speech_lengths // speech_projector.k
else:
raise ValueError(f'Unknown speech projector: {speech_projector_type}')
# speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))]
return encoder_outs
def prepare_inputs_labels_for_speech_vision_text(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, images_highres=None
):
speech_encoder = self.get_speech_encoder()
vision_tower = self.get_vision_tower()
if speech_encoder is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if vision_tower is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
# encode speech
if not isinstance(speech, list):
speech = torch.split(speech, speech_chunks.tolist(), dim=0)
speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0)
speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0)
speech_features = []
for idx in range(len(speech)):
speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx]))
# encode vision
if isinstance(modalities, str):
modalities = [modalities]
video_idx_in_batch = []
for modal in range(len(modalities)):
if 'video' in modalities[modal]:
video_idx_in_batch.append(modal)
aimg = images[-1]
lowres_img = []
for idx, img_feat in enumerate(images):
if idx in video_idx_in_batch:
img_feat = aimg.new(1, 3, 128, 128).fill_(0)
lowres_img.append(img_feat)
lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img)
highres_img_features = []
highres_img_sizes = []
for idx, img_feat in enumerate(images_highres):
if img_feat.ndim == 5:
img_feat = img_feat.squeeze(1)
highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat)
highres_img_features.append(highres_img_feature)
highres_img_sizes.append(highres_img_size)
image_features = []
for idx in range(len(modalities)):
img_feat = self.get_model().mm_projector(lowres_img_features[idx],
lowres_img_sizes[idx],
highres_img_features[idx],
highres_img_sizes[idx],
modalities[idx])
image_features.append(img_feat.flatten(0, 1))
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_speech_idx = 0
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
if num_speech_images == 0:
cur_speech_features = speech_features[cur_speech_idx]
cur_images_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_speech_idx += 1
cur_image_idx += 1
continue
speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_nospeech_image = []
cur_labels = labels[batch_idx]
cur_labels_nospeech_image = []
for i in range(len(speech_image_token_indices) - 1):
cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
split_sizes = [x.shape[0] for x in cur_labels_nospeech_image]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image))
cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_speech_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i])
cur_new_labels.append(cur_labels_nospeech_image[i])
if i < num_speech_images:
if i < num_images:
cur_images_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_images_features)
cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
else:
cur_speech_features = speech_features[cur_speech_idx]
cur_speech_idx += 1
cur_new_input_embeds.append(cur_speech_features)
cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
if num_images == 0:
cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0)
cur_image_idx += 1
if num_speech == 0:
cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0)
cur_speech_idx += 1
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as speech features can make the sequence longer
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
new_input_embeds_padded.append(torch.cat((
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
cur_new_embed
), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(torch.cat((
cur_new_embed,
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
def initialize_vision_tokenizer(self, model_args, tokenizer):
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False

View File

@ -0,0 +1,182 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import LayerNorm
# import torchaudio.compliance.kaldi as ta_kaldi
from .kaldi import fbank as kaldi_fbank
from .backbone import (
TransformerEncoder,
)
import logging
from typing import Optional
logger = logging.getLogger(__name__)
class BEATsConfig:
def __init__(self, cfg=None):
self.input_patch_size: int = -1 # path size of patch embedding
self.embed_dim: int = 512 # patch embedding dimension
self.conv_bias: bool = False # include bias in conv encoder
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.deep_norm: bool = False # apply deep_norm first in the transformer
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.gru_rel_pos: bool = False # apply gated relative position embedding
# label predictor
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
self.predictor_class: int = 527 # target class number for the predictor
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
class BEATs(nn.Module):
def __init__(
self,
cfg: BEATsConfig,
) -> None:
super().__init__()
logger.info(f"BEATs Config: {cfg.__dict__}")
self.cfg = cfg
self.embed = cfg.embed_dim
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.input_patch_size = cfg.input_patch_size
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
bias=cfg.conv_bias)
self.dropout_input = nn.Dropout(cfg.dropout_input)
assert not cfg.deep_norm or not cfg.layer_norm_first
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
if cfg.finetuned_model:
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
else:
self.predictor = None
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.all(-1)
return padding_mask
def preprocess(
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
) -> torch.Tensor:
fbanks = []
for waveform in source:
waveform = waveform.unsqueeze(0) * 2 ** 15
fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
fbanks.append(fbank)
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
feature_only=False,
):
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(fbank, padding_mask)
fbank = fbank.unsqueeze(1)
features = self.patch_embedding(fbank)
features = features.reshape(features.shape[0], features.shape[1], -1)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
x = self.dropout_input(features)
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
)
if not feature_only and self.predictor is not None:
x = self.predictor_dropout(x)
logits = self.predictor(x)
if padding_mask is not None and padding_mask.any():
logits[padding_mask] = 0
logits = logits.sum(dim=1)
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
else:
logits = logits.mean(dim=1)
lprobs = torch.sigmoid(logits)
return lprobs, padding_mask
else:
return x, padding_mask

View File

@ -0,0 +1,174 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import LayerNorm
# import torchaudio.compliance.kaldi as ta_kaldi
from .kaldi import fbank as kaldi_fbank
from .backbone import (
TransformerEncoder,
)
from .quantizer import (
NormEMAVectorQuantizer,
)
import logging
from typing import Optional
logger = logging.getLogger(__name__)
class TokenizersConfig:
def __init__(self, cfg=None):
self.input_patch_size: int = -1 # path size of patch embedding
self.embed_dim: int = 512 # patch embedding dimension
self.conv_bias: bool = False # include bias in conv encoder
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.deep_norm: bool = False # apply deep_norm first in the transformer
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.gru_rel_pos: bool = False # apply gated relative position embedding
# quantizer
self.quant_n: int = 1024 # codebook number in quantizer
self.quant_dim: int = 256 # codebook dimension in quantizer
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
class Tokenizers(nn.Module):
def __init__(
self,
cfg: TokenizersConfig,
) -> None:
super().__init__()
logger.info(f"Tokenizers Config: {cfg.__dict__}")
self.cfg = cfg
self.embed = cfg.embed_dim
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.input_patch_size = cfg.input_patch_size
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
bias=cfg.conv_bias)
self.dropout_input = nn.Dropout(cfg.dropout_input)
assert not cfg.deep_norm or not cfg.layer_norm_first
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
self.quantize = NormEMAVectorQuantizer(
n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
)
self.quant_n = cfg.quant_n
self.quantize_layer = nn.Sequential(
nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
nn.Tanh(),
nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
)
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.all(-1)
return padding_mask
def preprocess(
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
) -> torch.Tensor:
fbanks = []
for waveform in source:
waveform = waveform.unsqueeze(0) * 2 ** 15
fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
fbanks.append(fbank)
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank
def extract_labels(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
):
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(fbank, padding_mask)
fbank = fbank.unsqueeze(1)
features = self.patch_embedding(fbank)
features = features.reshape(features.shape[0], features.shape[1], -1)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
x = self.dropout_input(features)
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
)
quantize_input = self.quantize_layer(x)
quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
return embed_ind

View File

@ -0,0 +1,782 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import numpy as np
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn import LayerNorm, Parameter
from .modules import (
GradMultiply,
SamePad,
get_activation_fn,
GLU_Linear,
quant_noise,
)
class TransformerEncoder(nn.Module):
def __init__(self, args):
super().__init__()
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"):
self.relative_position_embedding = args.relative_position_embedding
self.num_buckets = args.num_buckets
self.max_distance = args.max_distance
else:
self.relative_position_embedding = False
self.num_buckets = 0
self.max_distance = 0
self.layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
deep_norm=args.deep_norm,
has_relative_attention_bias=self.relative_position_embedding,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
gru_rel_pos=args.gru_rel_pos,
encoder_layers=args.encoder_layers,
)
for i in range(args.encoder_layers)
]
)
if self.relative_position_embedding:
for i in range(1, args.encoder_layers):
del self.layers[i].self_attn.relative_attention_bias
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
self.apply(init_bert_params)
if args.deep_norm:
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
for i in range(args.encoder_layers):
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None):
if padding_mask is not None:
x[padding_mask] = 0
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
layer_results = []
z = None
if tgt_layer is not None:
layer_results.append((x, z))
r = None
pos_bias = None
for i, layer in enumerate(self.layers):
if self.layer_wise_gradient_decay_ratio != 1.0:
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
if tgt_layer is not None:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
deep_norm: bool = False,
has_relative_attention_bias: bool = False,
num_buckets: int = 0,
max_distance: int = 0,
rescale_init: bool = False,
gru_rel_pos: bool = False,
encoder_layers: int = 0,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
self.activation_name = activation_fn
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
has_relative_attention_bias=has_relative_attention_bias,
num_buckets=num_buckets,
max_distance=max_distance,
rescale_init=rescale_init,
gru_rel_pos=gru_rel_pos,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
if self.activation_name == "glu":
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
else:
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
self.final_layer_norm = LayerNorm(self.embedding_dim)
self.deep_norm = deep_norm
if self.deep_norm:
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
else:
self.deep_norm_alpha = 1
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
pos_bias=None
):
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual * self.deep_norm_alpha + x
x = self.self_attn_layer_norm(x)
residual = x
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual * self.deep_norm_alpha + x
x = self.final_layer_norm(x)
return x, attn, pos_bias
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
has_relative_attention_bias=False,
num_buckets=32,
max_distance=128,
gru_rel_pos=False,
rescale_init=False,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = nn.Dropout(dropout)
self.has_relative_attention_bias = has_relative_attention_bias
self.num_buckets = num_buckets
self.max_distance = max_distance
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
self.head_dim = embed_dim // num_heads
self.q_head_dim = self.head_dim
self.k_head_dim = self.head_dim
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
k_bias = True
if rescale_init:
k_bias = False
k_embed_dim = embed_dim
q_embed_dim = embed_dim
self.k_proj = quant_noise(
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.gru_rel_pos = gru_rel_pos
if self.gru_rel_pos:
self.grep_linear = nn.Linear(self.q_head_dim, 8)
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
if self.has_relative_attention_bias:
nn.init.xavier_normal_(self.relative_attention_bias.weight)
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
num_buckets = self.num_buckets
max_distance = self.max_distance
relative_buckets = 0
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
else:
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_positions.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_postion_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length):
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(
relative_position,
bidirectional=True
)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1])
return values
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
assert src_len, bsz == value.shape[:2]
if self.has_relative_attention_bias and position_bias is None:
position_bias = self.compute_bias(tgt_len, src_len)
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
alpha = 32
q *= 1 / alpha
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.k_head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v, position_bias
if position_bias is not None:
attn_mask_rel_pos = position_bias
if self.gru_rel_pos == 1:
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
attn_weights = attn_weights + attn_mask_rel_pos
attn_weights_float = F.softmax(
attn_weights, dim=-1
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights, position_bias
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bais will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then weights of
in_project_weight for MultiHeadAttention initialized using
the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)

View File

@ -0,0 +1,813 @@
import math
from typing import Tuple
import torch
# import torchaudio
from torch import Tensor
__all__ = [
"get_mel_banks",
"inverse_mel_scale",
"inverse_mel_scale_scalar",
"mel_scale",
"mel_scale_scalar",
"spectrogram",
"fbank",
"mfcc",
"vtln_warp_freq",
"vtln_warp_mel_freq",
]
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
# 1 milliseconds = 0.001 seconds
MILLISECONDS_TO_SECONDS = 0.001
# window types
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
def _get_epsilon(device, dtype):
return EPSILON.to(device=device, dtype=dtype)
def _next_power_of_2(x: int) -> int:
r"""Returns the smallest power of 2 that is greater than x"""
return 1 if x == 0 else 2 ** (x - 1).bit_length()
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
representing how the window is shifted along the waveform. Each row is a frame.
Args:
waveform (Tensor): Tensor of size ``num_samples``
window_size (int): Frame length
window_shift (int): Frame shift
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends.
Returns:
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
"""
assert waveform.dim() == 1
num_samples = waveform.size(0)
strides = (window_shift * waveform.stride(0), waveform.stride(0))
if snip_edges:
if num_samples < window_size:
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
else:
m = 1 + (num_samples - window_size) // window_shift
else:
reversed_waveform = torch.flip(waveform, [0])
m = (num_samples + (window_shift // 2)) // window_shift
pad = window_size // 2 - window_shift // 2
pad_right = reversed_waveform
if pad > 0:
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
# but we want [2, 1, 0, 0, 1, 2]
pad_left = reversed_waveform[-pad:]
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
else:
# pad is negative so we want to trim the waveform at the front
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
sizes = (m, window_size)
return waveform.as_strided(sizes, strides)
def _feature_window_function(
window_type: str,
window_size: int,
blackman_coeff: float,
device: torch.device,
dtype: int,
) -> Tensor:
r"""Returns a window function with the given type and size"""
if window_type == HANNING:
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
elif window_type == HAMMING:
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
elif window_type == POVEY:
# like hanning but goes to zero at edges
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR:
return torch.ones(window_size, device=device, dtype=dtype)
elif window_type == BLACKMAN:
a = 2 * math.pi / (window_size - 1)
window_function = torch.arange(window_size, device=device, dtype=dtype)
# can't use torch.blackman_window as they use different coefficients
return (
blackman_coeff
- 0.5 * torch.cos(a * window_function)
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
).to(device=device, dtype=dtype)
else:
raise Exception("Invalid window type " + window_type)
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
if energy_floor == 0.0:
return log_energy
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
def _get_waveform_and_window_properties(
waveform: Tensor,
channel: int,
sample_frequency: float,
frame_shift: float,
frame_length: float,
round_to_power_of_two: bool,
preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
r"""Gets the waveform and window properties"""
channel = max(channel, 0)
assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
window_size, len(waveform)
)
assert 0 < window_shift, "`window_shift` must be greater than 0"
assert padded_window_size % 2 == 0, (
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
)
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
return waveform, window_shift, window_size, padded_window_size
def _get_window(
waveform: Tensor,
padded_window_size: int,
window_size: int,
window_shift: int,
window_type: str,
blackman_coeff: float,
snip_edges: bool,
raw_energy: bool,
energy_floor: float,
dither: float,
remove_dc_offset: bool,
preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
r"""Gets a window and its log energy
Returns:
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
# size (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
if dither != 0.0:
rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
strided_input = strided_input + rand_gauss * dither
if remove_dc_offset:
# Subtract each row/frame by its mean
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
strided_input = strided_input - row_means
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
0
) # size (m, window_size + 1)
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
# Apply window_function to each row/frame
window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
0
) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
if padded_window_size != window_size:
padding_right = padded_window_size - window_size
strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
).squeeze(0)
# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
return strided_input, signal_log_energy
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
# it returns size (m, n)
if subtract_mean:
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
tensor = tensor - col_means
return tensor
def spectrogram(
waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_duration: float = 0.0,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
window_type: str = POVEY,
) -> Tensor:
r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
compute-spectrogram-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``'povey'``)
Returns:
Tensor: A spectrogram identical to what Kaldi would output. The shape is
(m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0)
strided_input, signal_log_energy = _get_window(
waveform,
padded_window_size,
window_size,
window_shift,
window_type,
blackman_coeff,
snip_edges,
raw_energy,
energy_floor,
dither,
remove_dc_offset,
preemphasis_coefficient,
)
# size (m, padded_window_size // 2 + 1, 2)
fft = torch.fft.rfft(strided_input)
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
return power_spectrum
def inverse_mel_scale_scalar(mel_freq: float) -> float:
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
def mel_scale_scalar(freq: float) -> float:
return 1127.0 * math.log(1.0 + freq / 700.0)
def mel_scale(freq: Tensor) -> Tensor:
return 1127.0 * (1.0 + freq / 700.0).log()
def vtln_warp_freq(
vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq: float,
high_freq: float,
vtln_warp_factor: float,
freq: Tensor,
) -> Tensor:
r"""This computes a VTLN warping function that is not the same as HTK's one,
but has similar inputs (this function has the advantage of never producing
empty bins).
This function computes a warp function F(freq), defined between low_freq
and high_freq inclusive, with the following properties:
F(low_freq) == low_freq
F(high_freq) == high_freq
The function is continuous and piecewise linear with two inflection
points.
The lower inflection point (measured in terms of the unwarped
frequency) is at frequency l, determined as described below.
The higher inflection point is at a frequency h, determined as
described below.
If l <= f <= h, then F(f) = f/vtln_warp_factor.
If the higher inflection point (measured in terms of the unwarped
frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
Since (by the last point) F(h) == h/vtln_warp_factor, then
max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
= vtln_high_cutoff * min(1, vtln_warp_factor).
If the lower inflection point (measured in terms of the unwarped
frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
= vtln_low_cutoff * max(1, vtln_warp_factor)
Args:
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
low_freq (float): Lower frequency cutoffs in mel computation
high_freq (float): Upper frequency cutoffs in mel computation
vtln_warp_factor (float): Vtln warp factor
freq (Tensor): given frequency in Hz
Returns:
Tensor: Freq after vtln warp
"""
assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
scale = 1.0 / vtln_warp_factor
Fl = scale * l # F(l)
Fh = scale * h # F(h)
assert l > low_freq and h < high_freq
# slope of left part of the 3-piece linear function
scale_left = (Fl - low_freq) / (l - low_freq)
# [slope of center part is just "scale"]
# slope of right part of the 3-piece linear function
scale_right = (high_freq - Fh) / (high_freq - h)
res = torch.empty_like(freq)
outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
before_l = torch.lt(freq, l) # freq < l
before_h = torch.lt(freq, h) # freq < h
after_h = torch.ge(freq, h) # freq >= h
# order of operations matter here (since there is overlapping frequency regions)
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
res[before_h] = scale * freq[before_h]
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
res[outside_low_high_freq] = freq[outside_low_high_freq]
return res
def vtln_warp_mel_freq(
vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq,
high_freq: float,
vtln_warp_factor: float,
mel_freq: Tensor,
) -> Tensor:
r"""
Args:
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
low_freq (float): Lower frequency cutoffs in mel computation
high_freq (float): Upper frequency cutoffs in mel computation
vtln_warp_factor (float): Vtln warp factor
mel_freq (Tensor): Given frequency in Mel
Returns:
Tensor: ``mel_freq`` after vtln warp
"""
return mel_scale(
vtln_warp_freq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
)
)
def get_mel_banks(
num_bins: int,
window_length_padded: int,
sample_freq: float,
low_freq: float,
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float,
) -> Tuple[Tensor, Tensor]:
"""
Returns:
(Tensor, Tensor): The tuple consists of ``bins`` (which is
melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
center frequencies of bins of size (``num_bins``)).
"""
assert num_bins > 3, "Must have at least 3 mel bins"
assert window_length_padded % 2 == 0
num_fft_bins = window_length_padded / 2
nyquist = 0.5 * sample_freq
if high_freq <= 0.0:
high_freq += nyquist
assert (
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
mel_low_freq = mel_scale_scalar(low_freq)
mel_high_freq = mel_scale_scalar(high_freq)
# divide by num_bins+1 in next line because of end-effects where the bins
# spread out to the sides.
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
if vtln_high < 0.0:
vtln_high += nyquist
assert vtln_warp_factor == 1.0 or (
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
vtln_low, vtln_high, low_freq, high_freq
)
bin = torch.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
if vtln_warp_factor != 1.0:
left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
# size(1, num_fft_bins)
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
# size (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel)
down_slope = (right_mel - mel) / (right_mel - center_mel)
if vtln_warp_factor == 1.0:
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
else:
# warping can move the order of left_mel, center_mel, right_mel anywhere
bins = torch.zeros_like(up_slope)
up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
bins[up_idx] = up_slope[up_idx]
bins[down_idx] = down_slope[down_idx]
return bins, center_freqs
def fbank(
waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
use_log_fbank: bool = True,
use_power: bool = True,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY,
) -> Tensor:
r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
compute-fbank-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
(Default: ``0.0``)
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
(need to change other parameters). (Default: ``False``)
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
negative, offset from high-mel-freq (Default: ``-500.0``)
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``'povey'``)
Returns:
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0, device=device, dtype=dtype)
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
waveform,
padded_window_size,
window_size,
window_shift,
window_type,
blackman_coeff,
snip_edges,
raw_energy,
energy_floor,
dither,
remove_dc_offset,
preemphasis_coefficient,
)
# size (m, padded_window_size // 2 + 1)
spectrum = torch.fft.rfft(strided_input).abs()
if use_power:
spectrum = spectrum.pow(2.0)
# size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = get_mel_banks(
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
)
mel_energies = mel_energies.to(device=device, dtype=dtype)
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = torch.mm(spectrum, mel_energies.T)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
# if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1)
if htk_compat:
mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
else:
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
return mel_energies
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
# returns a dct matrix of size (num_mel_bins, num_ceps)
# size (num_mel_bins, num_mel_bins)
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
# this would be the first column in the dct_matrix for torchaudio as it expects a
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
# expects a left multiply e.g. dct_matrix * vector).
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
dct_matrix = dct_matrix[:, :num_ceps]
return dct_matrix
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = torch.arange(num_ceps)
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
def mfcc(
waveform: Tensor,
blackman_coeff: float = 0.42,
cepstral_lifter: float = 22.0,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
num_ceps: int = 13,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY,
) -> Tensor:
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
compute-mfcc-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
(Default: ``0.0``)
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
features (need to change other parameters). (Default: ``False``)
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
negative, offset from high-mel-freq (Default: ``-500.0``)
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``"povey"``)
Returns:
Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
where m is calculated in _get_strided
"""
assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
device, dtype = waveform.device, waveform.dtype
# The mel_energies should not be squared (use_power=True), not have mean subtracted
# (subtract_mean=False), and use log (use_log_fbank=True).
# size (m, num_mel_bins + use_energy)
feature = fbank(
waveform=waveform,
blackman_coeff=blackman_coeff,
channel=channel,
dither=dither,
energy_floor=energy_floor,
frame_length=frame_length,
frame_shift=frame_shift,
high_freq=high_freq,
htk_compat=htk_compat,
low_freq=low_freq,
min_duration=min_duration,
num_mel_bins=num_mel_bins,
preemphasis_coefficient=preemphasis_coefficient,
raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset,
round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency,
snip_edges=snip_edges,
subtract_mean=False,
use_energy=use_energy,
use_log_fbank=True,
use_power=True,
vtln_high=vtln_high,
vtln_low=vtln_low,
vtln_warp=vtln_warp,
window_type=window_type,
)
if use_energy:
# size (m)
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
# offset is 0 if htk_compat==True else 1
mel_offset = int(not htk_compat)
feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
# size (num_mel_bins, num_ceps)
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
# size (m, num_ceps)
feature = feature.matmul(dct_matrix)
if cepstral_lifter != 0.0:
# size (1, num_ceps)
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
feature *= lifter_coeffs.to(device=device, dtype=dtype)
# if use_energy then replace the last column for htk_compat == true else first column
if use_energy:
feature[:, 0] = signal_log_energy
if htk_compat:
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
feature = feature[:, 1:] # size (m, num_ceps - 1)
if not use_energy:
# scale on C0 (actually removing a scale we previously added that's
# part of one common definition of the cosine transform.)
energy *= math.sqrt(2)
feature = torch.cat((feature, energy), dim=1)
feature = _subtract_column_mean(feature, subtract_mean)
return feature

View File

@ -0,0 +1,218 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import warnings
import torch
from torch import Tensor, nn
import torch.nn.functional as F
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
self.act = torch.nn.Sigmoid()
def forward(self, x):
return x * self.act(x)
class GLU_Linear(nn.Module):
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
super(GLU_Linear, self).__init__()
self.glu_type = glu_type
self.output_dim = output_dim
if glu_type == "sigmoid":
self.glu_act = torch.nn.Sigmoid()
elif glu_type == "swish":
self.glu_act = Swish()
elif glu_type == "relu":
self.glu_act = torch.nn.ReLU()
elif glu_type == "gelu":
self.glu_act = torch.nn.GELU()
if bias_in_glu:
self.linear = nn.Linear(input_dim, output_dim * 2, True)
else:
self.linear = nn.Linear(input_dim, output_dim * 2, False)
def forward(self, x):
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
x = self.linear(x)
if self.glu_type == "bilinear":
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
else:
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
return x
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str):
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
warnings.warn(
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
)
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "glu":
return lambda x: x
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module

View File

@ -0,0 +1,215 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on VQGAN code bases
# https://github.com/CompVis/taming-transformers
# --------------------------------------------------------'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as distributed
try:
from einops import rearrange, repeat
except ImportError:
pass
def l2norm(t):
return F.normalize(t, p=2, dim=-1)
def ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ means.t()
else:
diffs = rearrange(samples, 'n d -> n () d') \
- rearrange(means, 'c d -> () c d')
dists = -(diffs ** 2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
if use_cosine_sim:
new_means = l2norm(new_means)
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
class EmbeddingEMA(nn.Module):
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
super().__init__()
self.num_tokens = num_tokens
self.codebook_dim = codebook_dim
self.decay = decay
self.eps = eps
if codebook_init_path == '':
if not kmeans_init:
weight = torch.randn(num_tokens, codebook_dim)
weight = l2norm(weight)
else:
weight = torch.zeros(num_tokens, codebook_dim)
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
else:
print(f"load init codebook weight from {codebook_init_path}")
codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
weight = codebook_ckpt_weight.clone()
self.register_buffer('initted', torch.Tensor([True]))
self.weight = nn.Parameter(weight, requires_grad=False)
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
# self.register_buffer('initted', torch.Tensor([not kmeans_init]))
self.update = True
@torch.jit.ignore
def init_embed_(self, data):
if self.initted:
return
print("Performing Kemans init for codebook")
embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
self.weight.data.copy_(embed)
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(torch.Tensor([True]))
def forward(self, embed_id):
return F.embedding(embed_id, self.weight)
def cluster_size_ema_update(self, new_cluster_size):
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
def embed_avg_ema_update(self, new_embed_avg):
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
def weight_update(self, num_tokens):
n = self.cluster_size.sum()
smoothed_cluster_size = (
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
)
# normalize embedding average with smoothed cluster size
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
# embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
self.weight.data.copy_(embed_normalized)
def norm_ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
moving_avg.data.copy_(l2norm(moving_avg.data))
class NormEMAVectorQuantizer(nn.Module):
def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
super().__init__()
self.codebook_dim = embedding_dim
self.num_tokens = n_embed
self.beta = beta
self.decay = decay
# learnable = True if orthogonal_reg_weight > 0 else False
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
self.statistic_code_usage = statistic_code_usage
if statistic_code_usage:
self.register_buffer('cluster_size', torch.zeros(n_embed))
if distributed.is_available() and distributed.is_initialized():
print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
self.all_reduce_fn = distributed.all_reduce
else:
self.all_reduce_fn = nn.Identity()
def reset_cluster_size(self, device):
if self.statistic_code_usage:
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
self.cluster_size = self.cluster_size.to(device)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
# z, 'b c h w -> b h w c'
# z = rearrange(z, 'b c h w -> b h w c')
# z = z.transpose(1, 2)
z = l2norm(z)
z_flattened = z.reshape(-1, self.codebook_dim)
self.embedding.init_embed_(z_flattened)
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
if not self.training:
with torch.no_grad():
cluster_size = encodings.sum(0)
self.all_reduce_fn(cluster_size)
ema_inplace(self.cluster_size, cluster_size, self.decay)
if self.training and self.embedding.update:
# EMA cluster size
bins = encodings.sum(0)
self.all_reduce_fn(bins)
# self.embedding.cluster_size_ema_update(bins)
ema_inplace(self.cluster_size, bins, self.decay)
zero_mask = (bins == 0)
bins = bins.masked_fill(zero_mask, 1.)
embed_sum = z_flattened.t() @ encodings
self.all_reduce_fn(embed_sum)
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
embed_normalized = l2norm(embed_normalized)
embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
embed_normalized)
norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
# z_q, 'b h w c -> b c h w'
# z_q = rearrange(z_q, 'b h w c -> b c h w')
# z_q = z_q.transpose(1, 2)
return z_q, loss, encoding_indices

View File

@ -0,0 +1,13 @@
from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder
import torch.nn as nn
def build_speech_encoder(config):
speech_encoder_type = getattr(config, 'speech_encoder_type', None)
if "whisper" in speech_encoder_type.lower():
return WhisperWrappedEncoder.load(config)
elif "dual" in speech_encoder_type.lower():
return DualWrappedEncoder(config)
elif "none" in speech_encoder_type.lower():
return None
raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')

View File

@ -0,0 +1,74 @@
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
import whisper
from opencompass.models.ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs
class WhisperWrappedEncoder:
@classmethod
def load(cls, model_config):
def replace_layer_norm(module):
from whisper.model import LayerNorm
for name, child in module.named_children():
if isinstance(child, LayerNorm):
old_params = child.state_dict()
new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
new_layer_norm.load_state_dict(old_params)
setattr(module, name, new_layer_norm)
else:
replace_layer_norm(child)
encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
replace_layer_norm(encoder)
return encoder
class DualWrappedEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.whisper_model = self.load_whisper(config)
self.beats_model = self.load_beats(config)
def load_whisper(cls, model_config):
def replace_layer_norm(module):
from whisper.model import LayerNorm
for name, child in module.named_children():
if isinstance(child, LayerNorm):
old_params = child.state_dict()
new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
new_layer_norm.load_state_dict(old_params)
setattr(module, name, new_layer_norm)
else:
replace_layer_norm(child)
encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
replace_layer_norm(encoder)
return encoder
def load_beats(cls, model_config):
beats_path = model_config.music_encoder
print("Loading BEATs Model")
beats_ckpt = torch.load(beats_path, map_location='cpu')
beats_cfg = BEATsConfig(beats_ckpt['cfg'])
beats = BEATs(beats_cfg)
beats.load_state_dict(beats_ckpt['model'])
return beats
def forward(self, x, raw_wav=None, audio_padding_mask=None):
with torch.no_grad():
self.beats_model = self.beats_model.float()
speech_embeds = self.whisper_model(x)
audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True)
if audio_embeds.size(1) < speech_embeds.size(1):
audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
elif audio_embeds.size(1) > speech_embeds.size(1):
speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
speech_embeds = speech_embeds.to(torch.bfloat16)
return speech_embeds

View File

@ -0,0 +1,11 @@
from .speech_projector import EncoderProjectorConcat
def build_speech_projector(config):
projector_type = getattr(config, 'speech_projector_type', 'linear')
if projector_type == 'linear':
return EncoderProjectorConcat(config)
elif projector_type == 'none':
return None
raise ValueError(f'Unknown projector type: {projector_type}')

View File

@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import math
class EncoderProjectorConcat(nn.Module):
def __init__(self, config):
super().__init__()
self.k = config.speech_encoder_ds_rate
self.encoder_dim = config.speech_encoder_hidden_size
self.llm_dim = config.hidden_size
self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2048, config.hidden_size)
embed_std = 1 / math.sqrt(config.hidden_size)
self.speech_newline = nn.Parameter(
torch.randn(config.hidden_size) * embed_std
)
self.speech_begin = nn.Parameter(
torch.randn(config.hidden_size) * embed_std
)
self.speech_end = nn.Parameter(
torch.randn(config.hidden_size) * embed_std
)
def forward(self, x):
batch_size, seq_len, dim = x.size()
num_frames_to_discard = seq_len % self.k
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, seq_len // self.k, dim * self.k)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = torch.cat([
x,
self.speech_newline.reshape(1, 1, -1).expand(batch_size, 1, -1).to(x.dtype)
], dim=1)
begin = self.speech_begin.reshape(1, -1).to(x.dtype)
end = self.speech_end.reshape(1, -1).to(x.dtype)
x = x.flatten(0, 1)
x = torch.cat([begin, x, end], dim=0)
# x = x.flatten(0, 1)
return x

View File

@ -0,0 +1,213 @@
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torch
import logging
import logging.handlers
import transformers
from opencompass.models.ola.constants import LOGDIR
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
handler = None
def build_logger(logger_name, logger_filename):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when='D', utc=True, encoding='UTF-8')
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ''
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ''
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == '\n':
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != '':
self.logger.log(self.log_level, self.linebuf.rstrip())
self.linebuf = ''
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias:
if bias_name in lora_bias_names:
to_return[bias_name] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def lengths_to_padding_mask(lens):
bsz, max_lens = lens.size(0), torch.max(lens).item()
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
return mask
def lengths_to_mask(lens):
return ~lengths_to_padding_mask(lens)
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def get_model_name_from_path(model_path):
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith('checkpoint-'):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
def violates_moderation(text):
"""
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
headers = {"Content-Type": "application/json",
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException as e:
flagged = False
except KeyError as e:
flagged = False
return flagged
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"

View File

@ -0,0 +1,141 @@
import os
os.environ['LOWRES_RESIZE'] = '384x32'
os.environ['HIGHRES_BASE'] = '0x32'
os.environ['VIDEO_RESIZE'] = "0x64"
os.environ['VIDEO_MAXRES'] = "480"
os.environ['VIDEO_MINRES'] = "288"
os.environ['MAXRES'] = '1536'
os.environ['MINRES'] = '0'
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
os.environ['LOAD_VISION_EARLY'] = '1'
os.environ['PAD2STRIDE'] = '1'
from opencompass.models.base import BaseModel
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.utils.prompt import PromptList
PromptType = Union[PromptList, str]
import sys
import torch
import re
from PIL import Image
import numpy as np
import transformers
from typing import Dict, Optional, Sequence, List
from opencompass.models.ola.conversation import conv_templates, SeparatorStyle
from opencompass.models.ola.model.builder import load_pretrained_model
from opencompass.models.ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
from opencompass.models.ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
import argparse
import copy
class OlaModel(BaseModel):
def __init__(self,
path: str,
max_seq_len: int = 2048,
tokenizer_path: Optional[str] = None,
model_config: Optional[str] = None,
meta_template: Optional[Dict] = None):
self.template_parser = LMTemplateParser(meta_template)
self.eos_token_id = None
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
tokenizer, model, _, _ = load_pretrained_model(path, None)
model = model.to('cuda').eval()
model = model.bfloat16()
self.tokenizer=tokenizer
self.model=model
self.gen_kwargs = {
"max_new_tokens":1024,
"temperature":0.2,
"top_p":None,
"num_beams":1,
}
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
assert len(inputs)==1 # batch=1
image_path = None
audio_path = None
video_path = None
text = inputs[0]
images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
image_sizes = [(224, 224)]
USE_SPEECH=False
speechs = []
speech_lengths = []
speech_wavs = []
speech_chunks = []
speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
speech_lengths = [torch.LongTensor([3000]).to('cuda')]
speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
speech_chunks = [torch.LongTensor([1]).to('cuda')]
conv_mode = "qwen_1_5"
if text:
qs = text
else:
qs = ''
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
pad_token_ids = 151643
attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=images,
images_highres=images_highres,
image_sizes=image_sizes,
modalities=['text'],
speech=speechs,
speech_lengths=speech_lengths,
speech_chunks=speech_chunks,
speech_wav=speech_wavs,
attention_mask=attention_masks,
use_cache=True,
stopping_criteria=[stopping_criteria],
do_sample=True if self.gen_kwargs["temperature"] > 0 else False,
temperature=self.gen_kwargs["temperature"],
top_p=self.gen_kwargs["top_p"],
num_beams=self.gen_kwargs["num_beams"],
max_new_tokens=self.gen_kwargs["max_new_tokens"],
)
outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
out=[]
for output in outputs:
output = output.strip()
if output.endswith(stop_str):
output = output[:-len(stop_str)]
out.append(output)
print(f"prompt---->",prompt)
print(f"out---->",out)
print(f"\n")
return out