Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00.
Merge 37b894d4a1 into d572761cef (commit 5582aeb9f9).

examples/eval_ola.py (new file, 10 lines)
@@ -0,0 +1,10 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import gsm8k_datasets
|
||||
from opencompass.configs.datasets.demo.demo_math_chat_gen import math_datasets
|
||||
from opencompass.configs.models.ola.ola import models as ola_models
|
||||
|
||||
|
||||
datasets = math_datasets
|
||||
models = ola_models
|
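Note that the example imports both demo dataset configs but only evaluates the MATH demo; the config would typically be launched through OpenCompass's entry point, e.g. python run.py examples/eval_ola.py. A minimal variant that also covers GSM8K (an illustrative assumption, not part of this diff) would concatenate the two lists:

# Hypothetical variant of the assignments above, evaluating both demo sets.
datasets = [*gsm8k_datasets, *math_datasets]
models = ola_models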
opencompass/configs/models/ola/ola.py (new file, 12 lines)
@@ -0,0 +1,12 @@
|
||||
from opencompass.models import OlaModel
|
||||
models = [
|
||||
dict(
|
||||
type=OlaModel,
|
||||
path="THUdyh/Ola-7b",
|
||||
max_seq_len=2048,
|
||||
abbr='ola',
|
||||
max_out_len=1024,
|
||||
batch_size=1,
|
||||
run_cfg=dict(num_gpus=1),
|
||||
)
|
||||
]
|
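Because models is a plain list of dicts, downstream configs can import the entry and tweak individual fields; a minimal sketch with illustrative override values (not part of this diff):

# Hypothetical per-run overrides after importing the shared entry.
from opencompass.configs.models.ola.ola import models as ola_models
ola_models[0]['max_out_len'] = 256  # shorter generations for a quick smoke test
ola_models[0]['batch_size'] = 4     # illustrative; feasibility depends on GPU memory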
opencompass/models/__init__.py
@@ -35,6 +35,7 @@ from .openai_api import OpenAI # noqa: F401
|
||||
from .openai_api import OpenAISDK # noqa: F401
|
||||
from .pangu_api import PanGu # noqa: F401
|
||||
from .qwen_api import Qwen # noqa: F401
|
||||
from .ola_model import OlaModel # noqa: F401
|
||||
from .rendu_api import Rendu # noqa: F401
|
||||
from .sensetime_api import SenseTime # noqa: F401
|
||||
from .stepfun_api import StepFun # noqa: F401
|
||||
|
opencompass/models/ola/arguments.py (new file, 65 lines)
@@ -0,0 +1,65 @@
|
||||
import transformers
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
||||
version: Optional[str] = field(default="v0")
|
||||
freeze_backbone: bool = field(default=False)
|
||||
tune_speech_projector: bool = field(default=False)
|
||||
tune_speech_encoder: bool = field(default=False)
|
||||
tune_speech_generator_only: bool = field(default=False)
|
||||
speech_encoder_type: Optional[str] = field(default=None)
|
||||
speech_encoder: Optional[str] = field(default=None)
|
||||
pretrain_speech_projector: Optional[str] = field(default=None)
|
||||
speech_projector_type: Optional[str] = field(default='linear')
|
||||
speech_encoder_ds_rate: int = 5
|
||||
speech_encoder_hidden_size: int = 1280
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
data_path: str = field(default=None,
|
||||
metadata={"help": "Path to the training data."})
|
||||
is_multimodal: bool = False
|
||||
input_type: str = field(default="mel")
|
||||
speech_normalize: bool = False
|
||||
mel_size: int = 128
|
||||
has_tgt_units: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(transformers.TrainingArguments):
|
||||
cache_dir: Optional[str] = field(default=None)
|
||||
optim: str = field(default="adamw_torch")
|
||||
freeze_speech_projector: bool = field(default=False)
|
||||
model_max_length: int = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help":
|
||||
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
||||
},
|
||||
)
|
||||
double_quant: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Compress the quantization statistics through double quantization."}
|
||||
)
|
||||
quant_type: str = field(
|
||||
default="nf4",
|
||||
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
||||
)
|
||||
bits: int = field(
|
||||
default=16,
|
||||
metadata={"help": "How many bits to use."}
|
||||
)
|
||||
lora_enable: bool = False
|
||||
lora_r: int = 64
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.05
|
||||
lora_weight_path: str = ""
|
||||
lora_bias: str = "none"
|
||||
speech_projector_lr: Optional[float] = None
|
||||
group_by_modality_length: bool = field(default=False)
|
opencompass/models/ola/constants.py (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
||||
WORKER_HEART_BEAT_INTERVAL = 15
|
||||
|
||||
LOGDIR = "."
|
||||
|
||||
# Model Constants
|
||||
IGNORE_INDEX = -100
|
||||
SPEECH_TOKEN_INDEX = -200
|
||||
DEFAULT_SPEECH_TOKEN = "<speech>"
|
||||
IMAGE_TOKEN_INDEX = -300
|
||||
DEFAULT_IMAGE_TOKEN = "<image>"
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
opencompass/models/ola/conversation.py (new file, 254 lines)
@@ -0,0 +1,254 @@
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any, Union, Tuple
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
TWO = auto()
|
||||
PLAIN = auto()
|
||||
CHATML = auto()
|
||||
LLAMA_2 = auto()
|
||||
LLAMA_3 = auto()
|
||||
QWEN2 = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.PLAIN
|
||||
sep: str = "###"
|
||||
sep2: str = None
|
||||
version: str = "Unknown"
|
||||
|
||||
tokenizer_id: str = ""
|
||||
tokenizer: Any = None
|
||||
# Stop criteria (the default one is EOS token)
|
||||
stop_str: Union[str, List[str]] = None
|
||||
# Stops generation if meeting any token in this list
|
||||
stop_token_ids: List[int] = None
|
||||
|
||||
skip_next: bool = False
|
||||
|
||||
def get_prompt(self):
|
||||
messages = self.messages
|
||||
|
||||
if self.sep_style == SeparatorStyle.TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system + seps[0]
|
||||
for i, (role, message) in enumerate(messages):
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message = message[0]
|
||||
ret += role + ": " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ":"
|
||||
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
||||
wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
|
||||
ret = "<|begin_of_text|>" + wrap_sys(self.system)
|
||||
for i, (role, message) in enumerate(messages):
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message = message[0]
|
||||
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
||||
ret += message.strip() + self.sep2
|
||||
else:
|
||||
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
||||
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
||||
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
||||
ret = ""
|
||||
|
||||
for i, (role, message) in enumerate(messages):
|
||||
if i == 0:
|
||||
assert message, "first message should not be none"
|
||||
assert role == self.roles[0], "first message should come from user"
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
if i == 0:
|
||||
message = wrap_sys(self.system) + message
|
||||
if i % 2 == 0:
|
||||
message = wrap_inst(message)
|
||||
ret += self.sep + message
|
||||
else:
|
||||
ret += " " + message + " " + self.sep2
|
||||
else:
|
||||
ret += ""
|
||||
ret = ret.lstrip(self.sep)
|
||||
elif self.sep_style == SeparatorStyle.PLAIN:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system
|
||||
for i, (role, message) in enumerate(messages):
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
ret += message + seps[i % 2]
|
||||
else:
|
||||
ret += ""
|
||||
|
||||
elif self.sep_style == SeparatorStyle.CHATML:
|
||||
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
||||
for role, message in messages:
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
raise ValueError("Tuple not supported in CHATML")
|
||||
message, images = message
|
||||
message = "<speech>" * len(images) + message
|
||||
ret += role + "\n" + message + self.sep + "\n"
|
||||
else:
|
||||
ret += role + "\n"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.QWEN2:
|
||||
start = '<|im_start|>'
|
||||
end = '<|im_end|>\n'
|
||||
ret = start + 'system\n' + self.system + end
|
||||
for i, (role, message) in enumerate(messages):
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
|
||||
if message.endswith('<|endoftext|>'):
|
||||
message = message.replace('<|endoftext|>', '')
|
||||
ret += start + role + "\n" + message + end + '<|endoftext|>'
|
||||
else:
|
||||
assert not '<|endoftext|>' in message, f"Invalid message: {message}"
|
||||
ret += start + role + "\n" + message + end
|
||||
else:
|
||||
ret += start + role + "\n"
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
return ret
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||
if i % 2 == 0:
|
||||
if type(msg) is tuple:
|
||||
msg, speech = msg
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret[-1][-1] = msg
|
||||
return ret
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
version=self.version)
|
||||
|
||||
def dict(self):
|
||||
if len(self.get_images()) > 0:
|
||||
return {
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
}
|
||||
return {
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
}
|
||||
|
||||
conv_vicuna_v1 = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
version="v1",
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
conv_llama_2 = Conversation(
|
||||
system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
||||
roles=("USER", "ASSISTANT"),
|
||||
version="llama_v2",
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.LLAMA_2,
|
||||
sep="<s>",
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
conv_llama_3 = Conversation(
|
||||
system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
||||
roles=("user", "assistant"),
|
||||
version="llama_v3",
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.LLAMA_3,
|
||||
sep="",
|
||||
sep2="<|eot_id|>"
|
||||
)
|
||||
|
||||
|
||||
conv_qwen_v1 = Conversation(
|
||||
system="You are a helpful assistant.",
|
||||
roles=("user", "assistant"),
|
||||
version="v1",
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.QWEN2,
|
||||
)
|
||||
|
||||
conv_plain = Conversation(
|
||||
system="",
|
||||
roles=("", ""),
|
||||
messages=(
|
||||
),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.PLAIN,
|
||||
sep="</s>",
|
||||
)
|
||||
|
||||
conv_qwen = Conversation(
|
||||
system="""<|im_start|>system
|
||||
You are a helpful assistant.""",
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||
version="qwen",
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
|
||||
default_conversation = conv_llama_3
|
||||
conv_templates = {
|
||||
"v1": conv_vicuna_v1,
|
||||
"plain": conv_plain,
|
||||
"llama_2": conv_llama_2,
|
||||
"llama_3": conv_llama_3,
|
||||
'v1_qwen2': conv_qwen_v1,
|
||||
"qwen_1_5": conv_qwen,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(default_conversation.get_prompt())
|
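A hedged usage sketch of the templates defined above; only conv_templates, copy(), append_message(), roles, and get_prompt() come from this file, and the user text is illustrative:

# Hypothetical prompt construction with the qwen_1_5 (CHATML) template.
conv = conv_templates["qwen_1_5"].copy()
conv.append_message(conv.roles[0], "<speech>\nPlease transcribe the audio.")
conv.append_message(conv.roles[1], None)  # placeholder for the assistant turn
prompt = conv.get_prompt()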
opencompass/models/ola/datasets/__init__.py (new file, empty)

opencompass/models/ola/datasets/preprocess.py (new file, 231 lines)
@@ -0,0 +1,231 @@
|
||||
import copy
|
||||
import torch
|
||||
import transformers
|
||||
import tokenizers
|
||||
|
||||
from typing import Dict, Sequence
|
||||
|
||||
from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX
|
||||
from opencompass.models.ola import conversation as conversation_lib
|
||||
from opencompass.models.ola.model import *
|
||||
from opencompass.models.ola.arguments import DataArguments
|
||||
from opencompass.models.ola.constants import SPEECH_TOKEN_INDEX
|
||||
|
||||
from packaging import version
|
||||
|
||||
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
|
||||
|
||||
|
||||
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
|
||||
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]
|
||||
|
||||
def insert_separator(X, sep):
|
||||
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
||||
|
||||
input_ids = []
|
||||
offset = 0
|
||||
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
||||
offset = 1
|
||||
input_ids.append(prompt_chunks[0][0])
|
||||
|
||||
for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
|
||||
input_ids.extend(x[offset:])
|
||||
|
||||
if return_tensors is not None:
|
||||
if return_tensors == 'pt':
|
||||
return torch.tensor(input_ids, dtype=torch.long)
|
||||
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
||||
return input_ids
|
||||
|
||||
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
||||
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
||||
|
||||
def insert_separator(X, sep):
|
||||
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
||||
|
||||
input_ids = []
|
||||
offset = 0
|
||||
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
||||
offset = 1
|
||||
input_ids.append(prompt_chunks[0][0])
|
||||
|
||||
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
||||
input_ids.extend(x[offset:])
|
||||
|
||||
if return_tensors is not None:
|
||||
if return_tensors == 'pt':
|
||||
return torch.tensor(input_ids, dtype=torch.long)
|
||||
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
||||
return input_ids
|
||||
|
||||
def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
|
||||
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech><image>')]
|
||||
|
||||
def insert_separator(X, sep):
|
||||
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
||||
|
||||
input_ids = []
|
||||
offset = 0
|
||||
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
||||
offset = 1
|
||||
input_ids.append(prompt_chunks[0][0])
|
||||
|
||||
for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)):
|
||||
input_ids.extend(x[offset:])
|
||||
|
||||
if return_tensors is not None:
|
||||
if return_tensors == 'pt':
|
||||
return torch.tensor(input_ids, dtype=torch.long)
|
||||
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
||||
return input_ids
|
||||
|
||||
def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
|
||||
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>\nUser's question in speech: <speech>\n")]
|
||||
|
||||
def insert_separator(X, sep):
|
||||
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
||||
|
||||
input_ids = []
|
||||
offset = 0
|
||||
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
||||
offset = 1
|
||||
input_ids.append(prompt_chunks[0][0])
|
||||
|
||||
nl_tokens = tokenizer("\n").input_ids
|
||||
special_chunks = [image_token_index, nl_tokens, tokenizer("User's question in speech: ").input_ids, speech_token_idx, nl_tokens]
|
||||
|
||||
for x in insert_separator(prompt_chunks, [special_chunks] * (offset + 1)):
|
||||
input_ids.extend(x[offset:])
|
||||
|
||||
if return_tensors is not None:
|
||||
if return_tensors == 'pt':
|
||||
return torch.tensor(input_ids, dtype=torch.long)
|
||||
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
||||
return input_ids
|
||||
|
||||
def preprocess_v1(
|
||||
sources,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
has_speech: bool = False
|
||||
) -> Dict:
|
||||
conv = conversation_lib.default_conversation.copy()
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
# Apply prompt templates
|
||||
conversations = []
|
||||
for i, source in enumerate(sources):
|
||||
if roles[source[0]["from"]] != conv.roles[0]:
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
|
||||
conv.messages = []
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], f"{i}"
|
||||
conv.append_message(role, sentence["value"])
|
||||
conversations.append(conv.get_prompt())
|
||||
|
||||
# Tokenize conversations
|
||||
|
||||
if has_speech:
|
||||
input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
||||
else:
|
||||
input_ids = tokenizer(
|
||||
conversations,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
|
||||
targets = input_ids.clone()
|
||||
|
||||
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
||||
|
||||
# Mask targets
|
||||
sep = conv.sep + conv.roles[1] + ": "
|
||||
for conversation, target in zip(conversations, targets):
|
||||
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
||||
|
||||
rounds = conversation.split(conv.sep2)
|
||||
cur_len = 1
|
||||
target[:cur_len] = IGNORE_INDEX
|
||||
for i, rou in enumerate(rounds):
|
||||
if rou == "":
|
||||
break
|
||||
|
||||
parts = rou.split(sep)
|
||||
if len(parts) != 2:
|
||||
break
|
||||
parts[0] += sep
|
||||
|
||||
if has_speech:
|
||||
round_len = len(tokenizer_speech_token(rou, tokenizer))
|
||||
instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
|
||||
else:
|
||||
round_len = len(tokenizer(rou).input_ids)
|
||||
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
||||
|
||||
# FIXME: tokenizer bug
|
||||
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
|
||||
round_len -= 1
|
||||
instruction_len -= 1
|
||||
|
||||
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
||||
|
||||
cur_len += round_len
|
||||
target[cur_len:] = IGNORE_INDEX
|
||||
|
||||
if cur_len < tokenizer.model_max_length:
|
||||
if cur_len != total_len:
|
||||
target[:] = IGNORE_INDEX
|
||||
print(
|
||||
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
||||
f" (ignored)"
|
||||
)
|
||||
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=targets,
|
||||
)
|
||||
|
||||
|
||||
def preprocess_plain(
|
||||
sources: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
) -> Dict:
|
||||
# add end signal and concatenate together
|
||||
conversations = []
|
||||
for source in sources:
|
||||
assert len(source) == 2
|
||||
assert DEFAULT_SPEECH_TOKEN in source[0]['value']
|
||||
source[0]['value'] = DEFAULT_SPEECH_TOKEN
|
||||
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
|
||||
conversations.append(conversation)
|
||||
# tokenize conversations
|
||||
input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
||||
targets = copy.deepcopy(input_ids)
|
||||
for target, source in zip(targets, sources):
|
||||
tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
|
||||
target[:tokenized_len] = IGNORE_INDEX
|
||||
|
||||
return dict(input_ids=input_ids, labels=targets)
|
||||
|
||||
|
||||
def preprocess(
|
||||
sources: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
has_speech: bool = False
|
||||
) -> Dict:
|
||||
"""
|
||||
Given a list of sources, each is a conversation list. This transform:
|
||||
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
||||
2. Concatenate conversations together;
|
||||
3. Tokenize the concatenated conversation;
|
||||
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
||||
"""
|
||||
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
|
||||
return preprocess_plain(sources, tokenizer)
|
||||
if conversation_lib.default_conversation.version.startswith("v1"):
|
||||
return preprocess_v1(sources, tokenizer, has_speech=has_speech)
|
||||
raise NotImplementedError
|
opencompass/models/ola/mm_utils.py (new file, 271 lines)
@@ -0,0 +1,271 @@
|
||||
from PIL import Image
|
||||
import base64
|
||||
import math
|
||||
import ast
|
||||
|
||||
import torch
|
||||
from transformers import StoppingCriteria
|
||||
import os
|
||||
import io
|
||||
|
||||
if 'VIDEO_RESIZE' in os.environ:
|
||||
# highresxpatch
|
||||
VIDEO_RESIZE = os.environ['VIDEO_RESIZE']
|
||||
video_base, video_ps = VIDEO_RESIZE.split('x')
|
||||
video_base = int(video_base)
|
||||
video_ps = int(video_ps)
|
||||
print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}")
|
||||
else:
|
||||
VIDEO_RESIZE = None
|
||||
|
||||
if 'HIGHRES_BASE' in os.environ:
|
||||
# highresxpatch
|
||||
HIGHRES_BASE = os.environ['HIGHRES_BASE']
|
||||
highres_base, highres_ps = HIGHRES_BASE.split('x')
|
||||
highres_base = int(highres_base)
|
||||
highres_ps = int(highres_ps)
|
||||
print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}")
|
||||
else:
|
||||
HIGHRES_BASE = None
|
||||
|
||||
if 'MAXRES' in os.environ:
|
||||
# highresxpatch
|
||||
MAXRES = int(os.environ['MAXRES'])
|
||||
print(f"MAXRES is set as {MAXRES}")
|
||||
else:
|
||||
MAXRES = 1536
|
||||
|
||||
if 'MINRES' in os.environ:
|
||||
# highresxpatch
|
||||
MINRES = int(os.environ['MINRES'])
|
||||
print(f"MINRES is set as {MINRES}")
|
||||
else:
|
||||
MINRES = 0
|
||||
|
||||
if 'VIDEO_MAXRES' in os.environ:
|
||||
# highresxpatch
|
||||
VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES'])
|
||||
print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}")
|
||||
else:
|
||||
VIDEO_MAXRES = 1536
|
||||
|
||||
if 'VIDEO_MINRES' in os.environ:
|
||||
# highresxpatch
|
||||
VIDEO_MINRES = int(os.environ['VIDEO_MINRES'])
|
||||
print(f"VIDEO_MINRES is set as {VIDEO_MINRES}")
|
||||
else:
|
||||
VIDEO_MINRES = 0
|
||||
|
||||
if 'PAD2STRIDE' in os.environ:
|
||||
# highresxpatch
|
||||
PAD2STRIDE = True
|
||||
print(f"PAD2STRIDE is set")
|
||||
else:
|
||||
PAD2STRIDE = False
|
||||
|
||||
if 'LOWRES_RESIZE' in os.environ:
|
||||
LOWRES_RESIZE = os.environ['LOWRES_RESIZE']
|
||||
print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}")
|
||||
if 'x' in LOWRES_RESIZE:
|
||||
size, ps = LOWRES_RESIZE.split('x')
|
||||
size = int(size)
|
||||
ps = int(ps)
|
||||
LOWRES_RESIZE = (size, ps)
|
||||
else:
|
||||
LOWRES_RESIZE = int(LOWRES_RESIZE)
|
||||
else:
|
||||
LOWRES_RESIZE = None
|
||||
|
||||
|
||||
def pad_image(image, target_resolution, value=0):
|
||||
"""
|
||||
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image (PIL.Image.Image): The input image.
|
||||
target_resolution (tuple): The target resolution (width, height) of the image.
|
||||
|
||||
Returns:
|
||||
PIL.Image.Image: The resized and padded image.
|
||||
"""
|
||||
original_width, original_height = image.size
|
||||
target_width, target_height = target_resolution
|
||||
# Create a new image with the target size and paste the resized image onto it
|
||||
new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
|
||||
paste_x = (target_width - original_width) // 2
|
||||
paste_y = (target_height - original_height) // 2
|
||||
new_image.paste(image, (paste_x, paste_y))
|
||||
return new_image
|
||||
|
||||
def resize_images(image, patch_size=14, base_size=896):
|
||||
h, w = image.size
|
||||
if base_size == 0:
|
||||
if h * w > MAXRES * MAXRES:
|
||||
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
|
||||
scale = MAXRES * MAXRES / (h * w)
|
||||
scale = math.sqrt(scale)
|
||||
elif h * w < MINRES * MINRES:
|
||||
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
|
||||
scale = MINRES * MINRES / (h * w)
|
||||
scale = math.sqrt(scale)
|
||||
else:
|
||||
scale = None
|
||||
else:
|
||||
scale = base_size * base_size / (h * w)
|
||||
scale = math.sqrt(scale)
|
||||
|
||||
|
||||
if scale is not None:
|
||||
new_h = int(h * scale / patch_size) * patch_size
|
||||
new_w = int(w * scale / patch_size) * patch_size
|
||||
new_h = max(new_h, patch_size)
|
||||
new_w = max(new_w, patch_size)
|
||||
image = image.resize((new_h, new_w))
|
||||
elif PAD2STRIDE:
|
||||
if h % patch_size == 0:
|
||||
new_h = h
|
||||
else:
|
||||
new_h = (h // patch_size + 1) * patch_size
|
||||
|
||||
if w % patch_size == 0:
|
||||
new_w = w
|
||||
else:
|
||||
new_w = (w // patch_size + 1) * patch_size
|
||||
image = pad_image(image, (new_h, new_w), value=127)
|
||||
else:
|
||||
scale = 1.0
|
||||
new_h = int(h * scale / patch_size) * patch_size
|
||||
new_w = int(w * scale / patch_size) * patch_size
|
||||
new_h = max(new_h, patch_size)
|
||||
new_w = max(new_w, patch_size)
|
||||
image = image.resize((new_h, new_w))
|
||||
|
||||
return image
|
||||
|
||||
def resize_video(image, patch_size=14, base_size=896):
|
||||
h, w = image.size
|
||||
if base_size == 0:
|
||||
if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
|
||||
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
|
||||
scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
|
||||
scale = math.sqrt(scale)
|
||||
elif h * w < VIDEO_MINRES * VIDEO_MINRES:
|
||||
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
|
||||
scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
|
||||
scale = math.sqrt(scale)
|
||||
else:
|
||||
scale = None
|
||||
else:
|
||||
scale = base_size * base_size / (h * w)
|
||||
scale = math.sqrt(scale)
|
||||
|
||||
if scale is not None:
|
||||
new_h = int(h * scale / patch_size) * patch_size
|
||||
new_w = int(w * scale / patch_size) * patch_size
|
||||
image = image.resize((new_h, new_w))
|
||||
elif PAD2STRIDE:
|
||||
if h % patch_size == 0:
|
||||
new_h = h
|
||||
else:
|
||||
new_h = (h // patch_size + 1) * patch_size
|
||||
|
||||
if w % patch_size == 0:
|
||||
new_w = w
|
||||
else:
|
||||
new_w = (w // patch_size + 1) * patch_size
|
||||
image = pad_image(image, (new_h, new_w), value=127)
|
||||
else:
|
||||
scale = 1.0
|
||||
new_h = int(h * scale / patch_size) * patch_size
|
||||
new_w = int(w * scale / patch_size) * patch_size
|
||||
image = image.resize((new_h, new_w))
|
||||
|
||||
return image
|
||||
|
||||
def process_anyres_video(image, processor):
|
||||
if VIDEO_RESIZE is not None:
|
||||
image = resize_video(image, patch_size=video_ps, base_size=video_base)
|
||||
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||
return image.unsqueeze(0)
|
||||
else:
|
||||
raise ValueError("VIDEO_RESIZE is not set")
|
||||
|
||||
def process_anyres_highres_image(image, processor):
|
||||
processor2 = None
|
||||
if type(processor) is tuple:
|
||||
processor, processor2 = processor[0], processor[1]
|
||||
|
||||
if HIGHRES_BASE is not None:
|
||||
image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
|
||||
|
||||
if processor2 is not None:
|
||||
image_original_resize = image.resize((processor2.size['shortest_edge'], processor.size['shortest_edge']))
|
||||
image_patches = [image_original_resize] + [image_original_resize]
|
||||
image_patches = [processor2.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
|
||||
for image_patch in image_patches]
|
||||
else:
|
||||
if LOWRES_RESIZE is not None:
|
||||
if type(LOWRES_RESIZE) is int:
|
||||
image_original_resize = resize_images(image, patch_size=14, base_size=LOWRES_RESIZE)
|
||||
else:
|
||||
image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
|
||||
else:
|
||||
image_original_resize = image.resize((336, 336))
|
||||
image_patches = [image_original_resize]
|
||||
image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
|
||||
for image_patch in image_patches]
|
||||
image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||
return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
|
||||
|
||||
def read_image_patch(patch_info):
|
||||
if 'img_path' in patch_info.keys():
|
||||
image = Image.open(patch_info['img_path']).convert('RGB')
|
||||
else:
|
||||
if 'image_encoing' in patch_info.keys():
|
||||
patch_info['image_encoding'] = patch_info['image_encoing']
|
||||
image_file_name = patch_info['patch']
|
||||
start_bytes = int(patch_info['start_num'])
|
||||
file_size = int(patch_info['size'])
|
||||
|
||||
with open(image_file_name, 'rb') as f:
|
||||
f.seek(start_bytes)
|
||||
if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
|
||||
image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
|
||||
else:
|
||||
image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
def get_model_name_from_path(model_path):
|
||||
model_path = model_path.strip("/")
|
||||
model_paths = model_path.split("/")
|
||||
if model_paths[-1].startswith('checkpoint-'):
|
||||
return model_paths[-2] + "_" + model_paths[-1]
|
||||
else:
|
||||
return model_paths[-1]
|
||||
|
||||
|
||||
class KeywordsStoppingCriteria(StoppingCriteria):
|
||||
def __init__(self, keywords, tokenizer, input_ids):
|
||||
self.keywords = keywords
|
||||
self.keyword_ids = []
|
||||
for keyword in keywords:
|
||||
cur_keyword_ids = tokenizer(keyword).input_ids
|
||||
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
||||
cur_keyword_ids = cur_keyword_ids[1:]
|
||||
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
||||
self.tokenizer = tokenizer
|
||||
self.start_len = input_ids.shape[1]
|
||||
|
||||
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
|
||||
offset = min(output_ids.shape[1] - self.start_len, 3)
|
||||
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
||||
for keyword_id in self.keyword_ids:
|
||||
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
||||
return True
|
||||
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
||||
for keyword in self.keywords:
|
||||
if keyword in outputs:
|
||||
return True
|
||||
return False
|
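A hedged sketch of how KeywordsStoppingCriteria would plug into Hugging Face generation, assuming a tokenizer, model, and input_ids are already in scope; the stop string and generation arguments are illustrative:

# Hypothetical wiring into transformers generate().
from transformers import StoppingCriteriaList

stopping = KeywordsStoppingCriteria(['<|im_end|>'], tokenizer, input_ids)
output_ids = model.generate(
    input_ids,
    stopping_criteria=StoppingCriteriaList([stopping]),
    max_new_tokens=64,
)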
opencompass/models/ola/model/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
|
||||
from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen
|
opencompass/models/ola/model/builder.py (new file, 91 lines)
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
import warnings
|
||||
import shutil
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
||||
import torch
|
||||
from opencompass.models.ola.model import *
|
||||
from opencompass.models.ola.model.speech_encoder.builder import build_speech_encoder
|
||||
|
||||
def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
|
||||
if load_8bit:
|
||||
kwargs['load_in_8bit'] = True
|
||||
elif load_4bit:
|
||||
kwargs['load_in_4bit'] = True
|
||||
kwargs['quantization_config'] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type='nf4'
|
||||
)
|
||||
else:
|
||||
kwargs['torch_dtype'] = torch.bfloat16
|
||||
|
||||
if use_flash_attn:
|
||||
kwargs['attn_implementation'] = 'flash_attention_2'
|
||||
|
||||
model_cls = OlaQwenForCausalLM
|
||||
|
||||
# Load Ola model
|
||||
if is_lora:
|
||||
assert model_base is not None, "model_base is required for LoRA models."
|
||||
from ola.model.language_model.ola_qwen import OlaConfigQwen
|
||||
lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
||||
print('Loading Ola from base model...')
|
||||
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
|
||||
print('Loading additional Ola weights...')
|
||||
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
||||
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
||||
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
||||
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
||||
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
||||
model.load_state_dict(non_lora_trainables, strict=False)
|
||||
|
||||
from peft import PeftModel
|
||||
print('Loading LoRA weights...')
|
||||
model = PeftModel.from_pretrained(model, model_path)
|
||||
print('Merging LoRA weights...')
|
||||
model = model.merge_and_unload()
|
||||
print('Model is loaded...')
|
||||
elif model_base is not None:
|
||||
print('Loading Ola from base model...')
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
||||
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
||||
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
|
||||
|
||||
speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
|
||||
speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
|
||||
model.load_state_dict(speech_projector_weights, strict=False)
|
||||
model = model.to(device=device)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||
model = model_cls.from_pretrained(
|
||||
model_path,
|
||||
low_cpu_mem_usage=False,
|
||||
**kwargs
|
||||
)
|
||||
model = model.to(device=device)
|
||||
|
||||
model.get_model().speech_encoder = build_speech_encoder(model.config)
|
||||
model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
|
||||
|
||||
image_processor = None
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
vision_tower = model.get_vision_tower()
|
||||
print("Loading vision tower...")
|
||||
if not vision_tower.is_loaded:
|
||||
vision_tower.load_model(device_map=device)
|
||||
if device != "auto":
|
||||
vision_tower.to(device="cuda", dtype=torch.bfloat16)
|
||||
else:
|
||||
vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
|
||||
image_processor = vision_tower.image_processor
|
||||
print("Loading vision tower succeeded.")
|
||||
|
||||
if hasattr(model.config, "max_sequence_length"):
|
||||
context_len = model.config.max_sequence_length
|
||||
else:
|
||||
context_len = 16384
|
||||
|
||||
return tokenizer, model, image_processor, context_len
|
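A hedged usage sketch for load_pretrained_model; the checkpoint path mirrors the OpenCompass config above, and the call relies on the defaults shown in the signature:

# Hypothetical full-model load (no LoRA, no separate base model).
tokenizer, model, image_processor, context_len = load_pretrained_model(
    'THUdyh/Ola-7b', model_base=None, device='cuda'
)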
opencompass/models/ola/model/language_model/ola_qwen.py (new file, 237 lines)
@@ -0,0 +1,237 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import transformers
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.generation.utils import GenerateOutput
|
||||
|
||||
from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
|
||||
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
|
||||
|
||||
|
||||
class OlaConfigQwen(Qwen2Config):
|
||||
model_type = "ola_qwen"
|
||||
|
||||
|
||||
class OlaQwenModel(OlaMetaModel, Qwen2Model):
|
||||
config_class = OlaConfigQwen
|
||||
|
||||
def __init__(self, config: Qwen2Config):
|
||||
super(OlaQwenModel, self).__init__(config)
|
||||
|
||||
|
||||
class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM):
|
||||
config_class = OlaConfigQwen
|
||||
|
||||
def __init__(self, config):
|
||||
super(Qwen2ForCausalLM, self).__init__(config)
|
||||
|
||||
config.rope_scaling = None
|
||||
self.model = OlaQwenModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_model(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
speech: Optional[torch.FloatTensor] = None,
|
||||
speech_lengths: Optional[torch.LongTensor] = None,
|
||||
speech_chunks: Optional[torch.LongTensor] = None,
|
||||
speech_wav: Optional[torch.FloatTensor] = None,
|
||||
images: Optional[torch.FloatTensor] = None,
|
||||
images_highres: Optional[List[torch.FloatTensor]] = None,
|
||||
image_sizes: Optional[List[List[int]]] = None,
|
||||
modalities: Optional[List[str]] = ["image"],
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
if inputs_embeds is None:
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
past_key_values,
|
||||
inputs_embeds,
|
||||
labels
|
||||
) = self.prepare_inputs_labels_for_speech_vision_text(
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
past_key_values,
|
||||
labels,
|
||||
speech,
|
||||
speech_lengths,
|
||||
speech_chunks,
|
||||
speech_wav,
|
||||
images,
|
||||
modalities,
|
||||
image_sizes,
|
||||
images_highres
|
||||
)
|
||||
|
||||
if labels is None:
|
||||
return super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict
|
||||
)
|
||||
else:
|
||||
return self.forward_llm_efficient(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict
|
||||
)
|
||||
|
||||
|
||||
def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
hidden_dim = hidden_states.size(-1)
|
||||
shift_labels = labels[..., 1:].contiguous().reshape(-1)
|
||||
shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
|
||||
assert shift_labels.size(0) == shift_hidden_states.size(0)
|
||||
mask = shift_labels > -1
|
||||
assert mask.float().sum() > 0
|
||||
shift_labels = shift_labels[mask]
|
||||
shift_hidden_states = shift_hidden_states[mask, :]
|
||||
logits = self.lm_head(shift_hidden_states)
|
||||
logits = logits.float()
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(logits, shift_labels)
|
||||
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
speech: Optional[torch.Tensor] = None,
|
||||
speech_lengths: Optional[torch.Tensor] = None,
|
||||
speech_chunks: Optional[torch.Tensor] = None,
|
||||
speech_wav: Optional[torch.FloatTensor] = None,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
images_highres: Optional[List[torch.FloatTensor]] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
modalities: Optional[List[str]] = ["image"],
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
position_ids = kwargs.pop("position_ids", None)
|
||||
attention_mask = kwargs.pop("attention_mask", None)
|
||||
if "inputs_embeds" in kwargs:
|
||||
raise NotImplementedError("`inputs_embeds` is not supported")
|
||||
|
||||
(
|
||||
inputs,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
_,
|
||||
inputs_embeds,
|
||||
_
|
||||
) = self.prepare_inputs_labels_for_speech_vision_text(
|
||||
inputs,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
None,
|
||||
None,
|
||||
speech,
|
||||
speech_lengths,
|
||||
speech_chunks,
|
||||
speech_wav,
|
||||
images,
|
||||
modalities,
|
||||
image_sizes,
|
||||
images_highres
|
||||
)
|
||||
|
||||
return super().generate(
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
||||
inputs_embeds=None, **kwargs):
|
||||
speech = kwargs.pop("speech", None)
|
||||
speech_lengths = kwargs.pop("speech_lengths", None)
|
||||
speech_chunks = kwargs.pop("speech_chunks", None)
|
||||
images = kwargs.pop("images", None)
|
||||
image_sizes = kwargs.pop("image_sizes", None)
|
||||
inputs = super().prepare_inputs_for_generation(
|
||||
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
||||
)
|
||||
if speech is not None:
|
||||
inputs['speech'] = speech
|
||||
inputs['speech_lengths'] = speech_lengths
|
||||
inputs['speech_chunks'] = speech_chunks
|
||||
if images is not None:
|
||||
inputs["images"] = images
|
||||
if image_sizes is not None:
|
||||
inputs["image_sizes"] = image_sizes
|
||||
return inputs
|
||||
|
||||
AutoConfig.register("ola_qwen", OlaConfigQwen)
|
||||
AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM)
|
opencompass/models/ola/model/multimodal_encoder/builder.py (new file, 9 lines)
@@ -0,0 +1,9 @@
|
||||
import os
|
||||
from .oryx_vit import SigLIPViTAnysizeWrapper
|
||||
|
||||
def build_vision_tower(vision_tower_cfg, **kwargs):
|
||||
vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None))
|
||||
is_absolute_path_exists = os.path.exists(vision_tower)
|
||||
print(f"Buiding OryxViTWrapper from {vision_tower}...")
|
||||
# path = vision_tower.split(":")[1]
|
||||
return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs)
|
opencompass/models/ola/model/multimodal_encoder/oryx_vit.py (new file, 1075 lines)
Diff suppressed because the file is too large.
opencompass/models/ola/model/multimodal_projector/builder.py (new file, 172 lines)
@@ -0,0 +1,172 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import re
|
||||
|
||||
import math
|
||||
|
||||
from .pooler_projector import NormalizedDwPooler
|
||||
import os
|
||||
import math
|
||||
|
||||
class IdentityMap(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return {"mm_projector_type": 'identity'}
|
||||
|
||||
|
||||
class SimpleResBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.pre_norm = nn.LayerNorm(channels)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.Linear(channels, channels),
|
||||
nn.GELU(),
|
||||
nn.Linear(channels, channels)
|
||||
)
|
||||
def forward(self, x):
|
||||
x = self.pre_norm(x)
|
||||
return x + self.proj(x)
|
||||
|
||||
class OlaMLP(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, twoview=False):
|
||||
super().__init__()
|
||||
|
||||
self.proj1 = nn.Linear(in_channels, out_channels)
|
||||
self.proj2 = nn.Linear(out_channels, out_channels)
|
||||
self.act = nn.GELU()
|
||||
self.pooler = NormalizedDwPooler(out_channels)
|
||||
|
||||
embed_std = 1 / math.sqrt(out_channels)
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.randn(out_channels) * embed_std
|
||||
)
|
||||
self.image_begin = nn.Parameter(
|
||||
torch.randn(out_channels) * embed_std
|
||||
)
|
||||
self.image_end = nn.Parameter(
|
||||
torch.randn(out_channels) * embed_std
|
||||
)
|
||||
|
||||
if twoview:
|
||||
self.image_sep = nn.Parameter(
|
||||
torch.randn(out_channels) * embed_std
|
||||
)
|
||||
|
||||
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
|
||||
|
||||
if modalities in ['image', 'text']:
|
||||
h, w = size
|
||||
dtype = x.dtype
|
||||
x = x.reshape(x.shape[0], h, w, -1)
|
||||
x = self.proj1(x)
|
||||
x = self.pooler(x, forward_type='2x')
|
||||
x = self.act(x)
|
||||
x = self.proj2(x)
|
||||
|
||||
|
||||
b, h, w, c = x.shape
|
||||
x = torch.cat([
|
||||
x,
|
||||
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
|
||||
], dim=2)
|
||||
x = x.reshape(b, -1, c)
|
||||
|
||||
if x2 is not None:
|
||||
h2, w2 = size2
|
||||
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
||||
x2 = self.proj1(x2)
|
||||
x2 = self.pooler(x2, forward_type='2x')
|
||||
x2 = self.act(x2)
|
||||
x2 = self.proj2(x2)
|
||||
|
||||
b2, h2, w2, c2 = x2.shape
|
||||
x2 = torch.cat([
|
||||
x2,
|
||||
self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
|
||||
], dim=2)
|
||||
x2 = x2.reshape(b, -1, c)
|
||||
sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
|
||||
x = torch.cat([x, sep, x2], dim=1)
|
||||
|
||||
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
||||
end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
||||
x = torch.cat([begin, x, end], dim=1)
|
||||
return x
|
||||
elif modalities in ['video']:
|
||||
# x2 is the true feature, ignore x
|
||||
h, w = size
|
||||
dtype = x.dtype
|
||||
x = x.reshape(x.shape[0], h, w, -1)
|
||||
x1 = self.proj1(x)
|
||||
x1 = self.pooler(x1, forward_type='2x')
|
||||
x1 = self.proj2(x1).mean() * 0.0
|
||||
|
||||
h2, w2 = size2
|
||||
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
||||
x2 = self.proj1(x2)
|
||||
x2 = self.pooler(x2, forward_type='2x')
|
||||
x2 = self.act(x2)
|
||||
x2 = self.proj2(x2)
|
||||
|
||||
b2, h2, w2, c = x2.shape
|
||||
x2 = torch.cat([
|
||||
x2,
|
||||
self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype)
|
||||
], dim=2)
|
||||
|
||||
x2 = x2.reshape(b2, -1, c)
|
||||
|
||||
sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
|
||||
x2 = torch.cat([x2, sep], dim=1)
|
||||
|
||||
x2 = x2.flatten(0, 1)
|
||||
|
||||
begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
|
||||
end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
|
||||
x2 = torch.cat([begin, x2, end], dim=0)
|
||||
x2 = x2.unsqueeze(0)
|
||||
return x2
|
||||
else:
|
||||
raise ValueError(f'Unknown modalities: {modalities}')
|
||||
|
||||
def build_vision_projector(config, delay_load=False, **kwargs):
|
||||
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
||||
|
||||
if projector_type == 'linear':
|
||||
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
||||
|
||||
elif projector_type == 'ola_mlp':
|
||||
return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True)
|
||||
|
||||
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
||||
if mlp_gelu_match:
|
||||
mlp_depth = int(mlp_gelu_match.group(1))
|
||||
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
||||
for _ in range(1, mlp_depth):
|
||||
modules.append(nn.GELU())
|
||||
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type)
|
||||
if mlp_gelu_resnet_match:
|
||||
mlp_depth = int(mlp_gelu_resnet_match.group(1))
|
||||
res_depth = int(mlp_gelu_resnet_match.group(2))
|
||||
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
||||
for _ in range(1, mlp_depth):
|
||||
modules.append(nn.GELU())
|
||||
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
||||
for _ in range(res_depth):
|
||||
modules.append(SimpleResBlock(config.hidden_size))
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
if projector_type == 'identity':
|
||||
return IdentityMap()
|
||||
|
||||
raise ValueError(f'Unknown projector type: {projector_type}')
|
opencompass/models/ola/model/multimodal_projector/pooler_projector.py (new file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from transformers.models.clip.modeling_clip import CLIPVisionModel
|
||||
import os
|
||||
|
||||
class PoolerProjector(nn.Module):
|
||||
def __init__(self, config, vision_cfg):
|
||||
super().__init__()
|
||||
self._config = config
|
||||
self.hw = vision_cfg.image_size // vision_cfg.patch_size
|
||||
|
||||
self.conv_pool = nn.Conv2d(
|
||||
config.mm_hidden_size, config.hidden_size,
|
||||
kernel_size=2, stride=2
|
||||
)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Linear(config.hidden_size, config.hidden_size),
|
||||
)
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
height = width = self.hw
|
||||
assert height * width == x.shape[1]
|
||||
x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
|
||||
x = self.conv_pool(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return {"mm_projector_type": 'pooler'}
|
||||
|
||||
|
||||
class NormalizedDwPooler(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.predictor = nn.Sequential(
|
||||
nn.Linear(dim*2, dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, dim),
|
||||
)
|
||||
|
||||
def forward(self, x, forward_type='2x'):
|
||||
B, H, W, C = x.shape
|
||||
|
||||
if forward_type == '2x':
|
||||
new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
|
||||
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
|
||||
fused_x = torch.cat([new_x, pooled_x], dim=-1)
|
||||
elif forward_type == '1x':
|
||||
new_x = x.reshape(B, H, W, 1, C)
|
||||
fused_x = torch.cat([new_x, new_x], dim=-1)
|
||||
elif forward_type == '4x':
|
||||
new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
|
||||
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
|
||||
fused_x = torch.cat([new_x, pooled_x], dim=-1)
|
||||
|
||||
score = self.predictor(fused_x)
|
||||
normalized_score = F.softmax(score, dim=-2)
|
||||
new_x = (new_x * normalized_score).sum(dim=-2)
|
||||
return new_x
|
opencompass/models/ola/model/multimodal_resampler/builder.py (new file, 20 lines)
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
|
||||
class IdentityMap(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return {"mm_resampler_type": None}
|
||||
|
||||
def build_vision_resampler(model_args, delay_load=False, **kwargs):
|
||||
# import pdb;pdb.set_trace()
|
||||
resampler_type = getattr(model_args, 'mm_resampler_type', None)
|
||||
if resampler_type is None:
|
||||
return IdentityMap()
|
||||
else:
|
||||
raise ValueError(f'Unknown resampler type: {resampler_type}')
|
opencompass/models/ola/model/ola_arch.py (new file, 397 lines)
@@ -0,0 +1,397 @@
|
||||
from abc import ABC, abstractmethod

import torch

from .speech_encoder.builder import build_speech_encoder
from .speech_projector.builder import build_speech_projector
from opencompass.models.ola.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX
from opencompass.models.ola.utils import lengths_to_padding_mask

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_resampler.builder import build_vision_resampler
from .multimodal_projector.builder import build_vision_projector

from opencompass.models.ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class OlaMetaModel:

    def __init__(self, config):
        super(OlaMetaModel, self).__init__(config)
        if hasattr(config, "speech_encoder"):
            self.speech_encoder = build_speech_encoder(config)
            self.speech_projector = build_speech_projector(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

    def get_speech_encoder(self):
        speech_encoder = getattr(self, 'speech_encoder', None)
        if type(speech_encoder) is list:
            speech_encoder = speech_encoder[0]
        return speech_encoder

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_speech_modules(self, model_args, fsdp=None):
        self.config.speech_encoder = getattr(model_args, "speech_encoder", None)
        self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None)
        self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear')
        self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5)
        self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280)
        self.config.music_encoder = getattr(model_args, 'music_encoder', None)

        if self.get_speech_encoder() is None:
            speech_encoder = build_speech_encoder(self.config)
            if fsdp is not None and len(fsdp) > 0:
                self.speech_encoder = [speech_encoder]
            else:
                self.speech_encoder = speech_encoder

        if getattr(self, 'speech_projector', None) is None:
            self.speech_projector = build_speech_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.speech_projector.parameters():
                p.requires_grad = True

        if model_args.pretrain_speech_projector is not None:
            pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu')
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
            print('Loading pretrain speech projector weights')

            msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False)
            print(msg)

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
            ## Get the mm_spatial_pool_mode and mm_spatial_pool_stride
            for k, v in vision_resampler.config.items():
                setattr(self.config, k, v)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
                self.vision_resampler = [vision_resampler]
            else:
                self.vision_tower = vision_tower
                self.vision_resampler = vision_resampler
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_resampler = self.vision_resampler[0]
                vision_tower = self.vision_tower[0]
            else:
                vision_resampler = self.vision_resampler
                vision_tower = self.vision_tower
            vision_tower.load_model()  # weights are already loaded, so this call is skipped internally

            # In case it is frozen by LoRA
            for p in self.vision_resampler.parameters():
                p.requires_grad = True

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size)

        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
        else:
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
            print('Loading pretrain mm projector weights')
            incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False)
            print(incompatible_keys)

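# The nested get_w helper above strips a sub-module prefix from checkpoint keys
# so that a projector can be loaded in isolation. A minimal sketch of that key
# renaming (hypothetical key names, not taken from a real checkpoint):
def _strip_prefix_demo():
    weights = {'model.speech_projector.linear.weight': 0, 'model.other.bias': 1}
    keyword = 'speech_projector'
    stripped = {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
    assert stripped == {'linear.weight': 0}
    return stripped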
class OlaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_speech_encoder(self):
        return self.get_model().get_speech_encoder()

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def get_speech_projector(self):
        return self.get_model().speech_projector

    def encode_speech(self, speech, speech_lengths, speech_wav):
        # import pdb; pdb.set_trace()
        speech_encoder_type = self.config.speech_encoder_type
        speech_encoder = self.get_speech_encoder()
        if "whisper" in speech_encoder_type.lower():
            encoder_outs = speech_encoder(speech.permute(0, 2, 1))
            speech_lengths = (speech_lengths + 1) // 2
        else:
            encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=speech_wav)
            speech_lengths = (speech_lengths + 1) // 2
        speech_projector_type = self.config.speech_projector_type
        speech_projector = self.get_speech_projector()
        if speech_projector_type == "linear":
            encoder_outs = speech_projector(encoder_outs)
            speech_lengths = speech_lengths // speech_projector.k
        else:
            raise ValueError(f'Unknown speech projector: {speech_projector_type}')
        # speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))]
        return encoder_outs

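    # Length bookkeeping in encode_speech, illustrated with toy numbers: the
    # Whisper-style encoder halves the frame count with rounding up, and the
    # linear projector then merges every k frames, so 100-frame and 37-frame
    # clips become (100 + 1) // 2 = 50 and (37 + 1) // 2 = 19 encoder frames,
    # then 50 // k and 19 // k projected frames (k = speech_projector.k, e.g.
    # k = 5 gives 10 and 3).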
    def prepare_inputs_labels_for_speech_vision_text(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, images_highres=None
    ):
        speech_encoder = self.get_speech_encoder()
        vision_tower = self.get_vision_tower()

        if speech_encoder is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if vision_tower is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels
        # encode speech
        if not isinstance(speech, list):
            speech = torch.split(speech, speech_chunks.tolist(), dim=0)
            speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0)
            speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0)
        speech_features = []
        for idx in range(len(speech)):
            speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx]))

        # encode vision
        if isinstance(modalities, str):
            modalities = [modalities]

        video_idx_in_batch = []
        for modal in range(len(modalities)):
            if 'video' in modalities[modal]:
                video_idx_in_batch.append(modal)

        aimg = images[-1]
        lowres_img = []
        for idx, img_feat in enumerate(images):
            if idx in video_idx_in_batch:
                img_feat = aimg.new(1, 3, 128, 128).fill_(0)
            lowres_img.append(img_feat)

        lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img)
        highres_img_features = []
        highres_img_sizes = []
        for idx, img_feat in enumerate(images_highres):
            if img_feat.ndim == 5:
                img_feat = img_feat.squeeze(1)
            highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat)
            highres_img_features.append(highres_img_feature)
            highres_img_sizes.append(highres_img_size)
        image_features = []
        for idx in range(len(modalities)):
            img_feat = self.get_model().mm_projector(lowres_img_features[idx],
                                                     lowres_img_sizes[idx],
                                                     highres_img_features[idx],
                                                     highres_img_sizes[idx],
                                                     modalities[idx])
            image_features.append(img_feat.flatten(0, 1))

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_speech_idx = 0
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):

            num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()

            num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum()

            if num_speech_images == 0:
                cur_speech_features = speech_features[cur_speech_idx]
                cur_images_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_speech_idx += 1
                cur_image_idx += 1
                continue
            speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]]

            cur_input_ids_nospeech_image = []
            cur_labels = labels[batch_idx]
            cur_labels_nospeech_image = []
            for i in range(len(speech_image_token_indices) - 1):
                cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
                cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_nospeech_image]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image))
            cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_speech_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i])
                cur_new_labels.append(cur_labels_nospeech_image[i])
                if i < num_speech_images:
                    if i < num_images:
                        cur_images_features = image_features[cur_image_idx]
                        cur_image_idx += 1
                        cur_new_input_embeds.append(cur_images_features)
                        cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                    else:
                        cur_speech_features = speech_features[cur_speech_idx]
                        cur_speech_idx += 1
                        cur_new_input_embeds.append(cur_speech_features)
                        cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            if num_images == 0:
                cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0)
                cur_image_idx += 1

            if num_speech == 0:
                cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0)
                cur_speech_idx += 1

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as speech features can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

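    # Splicing sketch for the method above, with toy ids in which -200 and -300
    # stand in for IMAGE_TOKEN_INDEX and SPEECH_TOKEN_INDEX: for ids
    # [1, 2, -200, 3, -300, 4] the cut points are [-1, 2, 4, 6], the text
    # segments are [1, 2], [3] and [4], and the image / speech feature sequences
    # are inserted between them (with IGNORE_INDEX labels) before the batch is
    # re-padded to a common length.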
    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
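# Sketch of the mean-initialization used in initialize_vision_tokenizer: newly
# added special-token embeddings start at the average of the existing rows,
# which keeps the output distribution roughly unchanged. Toy sizes only.
def _mean_init_demo(num_new_tokens=2):
    embeddings = torch.randn(10, 4)
    embeddings[-num_new_tokens:] = embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
    return embeddings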
182
opencompass/models/ola/model/speech_encoder/beats/BEATs.py
Normal file
@ -0,0 +1,182 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------


import torch
import torch.nn as nn
from torch.nn import LayerNorm
# import torchaudio.compliance.kaldi as ta_kaldi

from .kaldi import fbank as kaldi_fbank

from .backbone import (
    TransformerEncoder,
)

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class BEATsConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class BEATs(nn.Module):
    def __init__(
        self,
        cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        feature_only=False,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if not feature_only and self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask
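# Usage sketch for the encoder above (assumptions: a 16 kHz mono batch and an
# input_patch_size of 16; real checkpoints ship their own cfg dict).
def _beats_demo():
    cfg = BEATsConfig({"input_patch_size": 16})
    model = BEATs(cfg).eval()
    wav = torch.randn(1, 16000)                      # one second of fake audio
    with torch.no_grad():
        feats, _ = model.extract_features(wav, feature_only=True)
    return feats.shape                               # (1, num_patches, encoder_embed_dim)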
174
opencompass/models/ola/model/speech_encoder/beats/Tokenizers.py
Normal file
@ -0,0 +1,174 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------


import torch
import torch.nn as nn
from torch.nn import LayerNorm
# import torchaudio.compliance.kaldi as ta_kaldi

from .kaldi import fbank as kaldi_fbank

from .backbone import (
    TransformerEncoder,
)
from .quantizer import (
    NormEMAVectorQuantizer,
)

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class TokenizersConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # quantizer
        self.quant_n: int = 1024  # codebook number in quantizer
        self.quant_dim: int = 256  # codebook dimension in quantizer

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class Tokenizers(nn.Module):
    def __init__(
        self,
        cfg: TokenizersConfig,
    ) -> None:
        super().__init__()
        logger.info(f"Tokenizers Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
        )
        self.quant_n = cfg.quant_n
        self.quantize_layer = nn.Sequential(
            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
            nn.Tanh(),
            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim)  # for quantize
        )

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_labels(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        quantize_input = self.quantize_layer(x)
        quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)

        return embed_ind
782
opencompass/models/ola/model/speech_encoder/beats/backbone.py
Normal file
@ -0,0 +1,782 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
import numpy as np
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn import LayerNorm, Parameter
from .modules import (
    GradMultiply,
    SamePad,
    get_activation_fn,
    GLU_Linear,
    quant_noise,
)


class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    deep_norm=args.deep_norm,
                    has_relative_attention_bias=self.relative_position_embedding,
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    encoder_layers=args.encoder_layers,
                )
                for i in range(args.encoder_layers)
            ]
        )
        if self.relative_position_embedding:
            for i in range(1, args.encoder_layers):
                del self.layers[i].self_attn.relative_attention_bias
                self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for i in range(args.encoder_layers):
                nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)

        self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):

        if padding_mask is not None:
            x[padding_mask] = 0
        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results


class TransformerSentenceEncoderLayer(nn.Module):
    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        deep_norm: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
        encoder_layers: int = 0,
    ) -> None:

        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None
    ):
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias


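# The post-norm branch above follows DeepNet-style scaling: the residual is
# multiplied by deep_norm_alpha = (2 * encoder_layers) ** 0.25 before the
# sub-layer output is added, paired with the deep_norm_beta init gain applied
# in TransformerEncoder. A sketch of the constants for a 12-layer encoder:
def _deep_norm_constants_demo(encoder_layers=12):
    alpha = math.pow(2 * encoder_layers, 1 / 4)   # ~2.21 for 12 layers
    beta = math.pow(8 * encoder_layers, -1 / 4)   # ~0.32 for 12 layers
    return alpha, beta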
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(
            relative_position,
            bidirectional=True
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

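    # Worked example for the bucketing above (bidirectional, num_buckets=32,
    # max_distance=128): the 32 buckets are split by sign into two halves of 16;
    # offsets below max_exact = 8 map one-to-one, larger offsets are binned
    # logarithmically, e.g. a relative offset of +40 lands in bucket
    # 16 + 8 + floor(log(40 / 8) / log(128 / 8) * 8) = 28.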
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert (src_len, bsz) == value.shape[:2]

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        alpha = 32
        q *= 1 / alpha

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.k_head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos == 1:
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

            attn_weights = attn_weights + attn_mask_rel_pos

        attn_weights_float = F.softmax(
            attn_weights, dim=-1
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

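    # Note on the alpha = 32 rescaling in forward above: queries are divided by
    # 32 after the usual 1/sqrt(head_dim) scaling, and the raw scores are
    # multiplied back by 32 after subtracting the row-wise max. Softmax is
    # shift-invariant, so the result matches unscaled attention while keeping
    # the pre-softmax logits in a numerically safer (fp16-friendly) range.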
    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights


def init_bert_params(module):
|
||||
"""
|
||||
Initialize the weights specific to the BERT Model.
|
||||
This overrides the default initializations depending on the specified arguments.
|
||||
1. If normal_init_linear_weights is set then weights of linear
|
||||
layer will be initialized using the normal distribution and
|
||||
bias will be set to the specified value.
|
||||
2. If normal_init_embed_weights is set then weights of embedding
|
||||
layer will be initialized using the normal distribution.
|
||||
3. If normal_init_proj_weights is set then weights of
|
||||
in_project_weight for MultiHeadAttention initialized using
|
||||
the normal distribution (to be validated).
|
||||
"""
|
||||
|
||||
def normal_(data):
|
||||
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||
# so that the RNG is consistent with and without FSDP
|
||||
data.copy_(
|
||||
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
||||
)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
normal_(module.weight.data)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
normal_(module.weight.data)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if isinstance(module, MultiheadAttention):
|
||||
normal_(module.q_proj.weight.data)
|
||||
normal_(module.k_proj.weight.data)
|
||||
normal_(module.v_proj.weight.data)
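# --- Illustrative sketch (not part of the original patch) --------------------
# init_bert_params is meant to be passed to nn.Module.apply(); a minimal,
# hedged example re-initializing a small hypothetical MLP:
def _demo_init_bert_params() -> nn.Module:
    mlp = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
    mlp.apply(init_bert_params)  # each nn.Linear gets N(0, 0.02) weights and zero bias
    return mlp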
|
813
opencompass/models/ola/model/speech_encoder/beats/kaldi.py
Normal file
@ -0,0 +1,813 @@
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
# import torchaudio
|
||||
from torch import Tensor
|
||||
|
||||
__all__ = [
|
||||
"get_mel_banks",
|
||||
"inverse_mel_scale",
|
||||
"inverse_mel_scale_scalar",
|
||||
"mel_scale",
|
||||
"mel_scale_scalar",
|
||||
"spectrogram",
|
||||
"fbank",
|
||||
"mfcc",
|
||||
"vtln_warp_freq",
|
||||
"vtln_warp_mel_freq",
|
||||
]
|
||||
|
||||
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
|
||||
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
|
||||
# 1 milliseconds = 0.001 seconds
|
||||
MILLISECONDS_TO_SECONDS = 0.001
|
||||
|
||||
# window types
|
||||
HAMMING = "hamming"
|
||||
HANNING = "hanning"
|
||||
POVEY = "povey"
|
||||
RECTANGULAR = "rectangular"
|
||||
BLACKMAN = "blackman"
|
||||
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
|
||||
|
||||
|
||||
def _get_epsilon(device, dtype):
|
||||
return EPSILON.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def _next_power_of_2(x: int) -> int:
|
||||
r"""Returns the smallest power of 2 that is greater than x"""
|
||||
return 1 if x == 0 else 2 ** (x - 1).bit_length()
|
||||
|
||||
|
||||
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
|
||||
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
|
||||
representing how the window is shifted along the waveform. Each row is a frame.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of size ``num_samples``
|
||||
window_size (int): Frame length
|
||||
window_shift (int): Frame shift
|
||||
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends.
|
||||
|
||||
Returns:
|
||||
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
|
||||
"""
|
||||
assert waveform.dim() == 1
|
||||
num_samples = waveform.size(0)
|
||||
strides = (window_shift * waveform.stride(0), waveform.stride(0))
|
||||
|
||||
if snip_edges:
|
||||
if num_samples < window_size:
|
||||
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
|
||||
else:
|
||||
m = 1 + (num_samples - window_size) // window_shift
|
||||
else:
|
||||
reversed_waveform = torch.flip(waveform, [0])
|
||||
m = (num_samples + (window_shift // 2)) // window_shift
|
||||
pad = window_size // 2 - window_shift // 2
|
||||
pad_right = reversed_waveform
|
||||
if pad > 0:
|
||||
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
|
||||
# but we want [2, 1, 0, 0, 1, 2]
|
||||
pad_left = reversed_waveform[-pad:]
|
||||
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
|
||||
else:
|
||||
# pad is negative so we want to trim the waveform at the front
|
||||
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
|
||||
|
||||
sizes = (m, window_size)
|
||||
return waveform.as_strided(sizes, strides)
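# --- Illustrative sketch (not part of the original patch) --------------------
# How _get_strided frames a waveform: with Kaldi defaults at 16 kHz a frame is
# 400 samples (25 ms) shifted by 160 samples (10 ms); the numbers below are
# assumptions chosen for the demo.
def _demo_get_strided() -> Tensor:
    wav = torch.arange(16000, dtype=torch.float32)  # 1 s of dummy samples
    frames = _get_strided(wav, window_size=400, window_shift=160, snip_edges=True)
    # frames.shape == (1 + (16000 - 400) // 160, 400) == (98, 400)
    return frames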
|
||||
|
||||
|
||||
def _feature_window_function(
|
||||
window_type: str,
|
||||
window_size: int,
|
||||
blackman_coeff: float,
|
||||
device: torch.device,
|
||||
dtype: int,
|
||||
) -> Tensor:
|
||||
r"""Returns a window function with the given type and size"""
|
||||
if window_type == HANNING:
|
||||
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
|
||||
elif window_type == HAMMING:
|
||||
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
|
||||
elif window_type == POVEY:
|
||||
# like hanning but goes to zero at edges
|
||||
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
|
||||
elif window_type == RECTANGULAR:
|
||||
return torch.ones(window_size, device=device, dtype=dtype)
|
||||
elif window_type == BLACKMAN:
|
||||
a = 2 * math.pi / (window_size - 1)
|
||||
window_function = torch.arange(window_size, device=device, dtype=dtype)
|
||||
# can't use torch.blackman_window as they use different coefficients
|
||||
return (
|
||||
blackman_coeff
|
||||
- 0.5 * torch.cos(a * window_function)
|
||||
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
|
||||
).to(device=device, dtype=dtype)
|
||||
else:
|
||||
raise Exception("Invalid window type " + window_type)
|
||||
|
||||
|
||||
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
|
||||
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
|
||||
device, dtype = strided_input.device, strided_input.dtype
|
||||
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
|
||||
if energy_floor == 0.0:
|
||||
return log_energy
|
||||
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
|
||||
|
||||
|
||||
def _get_waveform_and_window_properties(
|
||||
waveform: Tensor,
|
||||
channel: int,
|
||||
sample_frequency: float,
|
||||
frame_shift: float,
|
||||
frame_length: float,
|
||||
round_to_power_of_two: bool,
|
||||
preemphasis_coefficient: float,
|
||||
) -> Tuple[Tensor, int, int, int]:
|
||||
r"""Gets the waveform and window properties"""
|
||||
channel = max(channel, 0)
|
||||
assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
|
||||
waveform = waveform[channel, :] # size (n)
|
||||
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
|
||||
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
|
||||
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
|
||||
|
||||
assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
|
||||
window_size, len(waveform)
|
||||
)
|
||||
assert 0 < window_shift, "`window_shift` must be greater than 0"
|
||||
assert padded_window_size % 2 == 0, (
|
||||
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
|
||||
)
|
||||
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
|
||||
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
|
||||
return waveform, window_shift, window_size, padded_window_size
|
||||
|
||||
|
||||
def _get_window(
|
||||
waveform: Tensor,
|
||||
padded_window_size: int,
|
||||
window_size: int,
|
||||
window_shift: int,
|
||||
window_type: str,
|
||||
blackman_coeff: float,
|
||||
snip_edges: bool,
|
||||
raw_energy: bool,
|
||||
energy_floor: float,
|
||||
dither: float,
|
||||
remove_dc_offset: bool,
|
||||
preemphasis_coefficient: float,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
r"""Gets a window and its log energy
|
||||
|
||||
Returns:
|
||||
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
|
||||
"""
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
epsilon = _get_epsilon(device, dtype)
|
||||
|
||||
# size (m, window_size)
|
||||
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
|
||||
|
||||
if dither != 0.0:
|
||||
rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
|
||||
strided_input = strided_input + rand_gauss * dither
|
||||
|
||||
if remove_dc_offset:
|
||||
# Subtract each row/frame by its mean
|
||||
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
|
||||
strided_input = strided_input - row_means
|
||||
|
||||
if raw_energy:
|
||||
# Compute the log energy of each row/frame before applying preemphasis and
|
||||
# window function
|
||||
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
|
||||
|
||||
if preemphasis_coefficient != 0.0:
|
||||
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
|
||||
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
|
||||
0
|
||||
) # size (m, window_size + 1)
|
||||
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
|
||||
|
||||
# Apply window_function to each row/frame
|
||||
window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
|
||||
0
|
||||
) # size (1, window_size)
|
||||
strided_input = strided_input * window_function # size (m, window_size)
|
||||
|
||||
# Pad columns with zero until we reach size (m, padded_window_size)
|
||||
if padded_window_size != window_size:
|
||||
padding_right = padded_window_size - window_size
|
||||
strided_input = torch.nn.functional.pad(
|
||||
strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
|
||||
).squeeze(0)
|
||||
|
||||
# Compute energy after window function (not the raw one)
|
||||
if not raw_energy:
|
||||
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
|
||||
|
||||
return strided_input, signal_log_energy
|
||||
|
||||
|
||||
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
|
||||
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
|
||||
# it returns size (m, n)
|
||||
if subtract_mean:
|
||||
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
|
||||
tensor = tensor - col_means
|
||||
return tensor
|
||||
|
||||
|
||||
def spectrogram(
|
||||
waveform: Tensor,
|
||||
blackman_coeff: float = 0.42,
|
||||
channel: int = -1,
|
||||
dither: float = 0.0,
|
||||
energy_floor: float = 1.0,
|
||||
frame_length: float = 25.0,
|
||||
frame_shift: float = 10.0,
|
||||
min_duration: float = 0.0,
|
||||
preemphasis_coefficient: float = 0.97,
|
||||
raw_energy: bool = True,
|
||||
remove_dc_offset: bool = True,
|
||||
round_to_power_of_two: bool = True,
|
||||
sample_frequency: float = 16000.0,
|
||||
snip_edges: bool = True,
|
||||
subtract_mean: bool = False,
|
||||
window_type: str = POVEY,
|
||||
) -> Tensor:
|
||||
r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
|
||||
compute-spectrogram-feats.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. (Default: ``True``)
|
||||
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||
specified there) (Default: ``16000.0``)
|
||||
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||
it this way. (Default: ``False``)
|
||||
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||
(Default: ``'povey'``)
|
||||
|
||||
Returns:
|
||||
Tensor: A spectrogram identical to what Kaldi would output. The shape is
|
||||
(m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
|
||||
"""
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
epsilon = _get_epsilon(device, dtype)
|
||||
|
||||
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
|
||||
)
|
||||
|
||||
if len(waveform) < min_duration * sample_frequency:
|
||||
# signal is too short
|
||||
return torch.empty(0)
|
||||
|
||||
strided_input, signal_log_energy = _get_window(
|
||||
waveform,
|
||||
padded_window_size,
|
||||
window_size,
|
||||
window_shift,
|
||||
window_type,
|
||||
blackman_coeff,
|
||||
snip_edges,
|
||||
raw_energy,
|
||||
energy_floor,
|
||||
dither,
|
||||
remove_dc_offset,
|
||||
preemphasis_coefficient,
|
||||
)
|
||||
|
||||
# complex tensor of size (m, padded_window_size // 2 + 1)
|
||||
fft = torch.fft.rfft(strided_input)
|
||||
|
||||
# Convert the FFT into a power spectrum
|
||||
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
|
||||
power_spectrum[:, 0] = signal_log_energy
|
||||
|
||||
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
|
||||
return power_spectrum
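# --- Illustrative sketch (not part of the original patch) --------------------
# Hedged usage example for spectrogram(); the waveform is random and the
# parameter values below simply spell out the defaults for clarity.
def _demo_spectrogram() -> Tensor:
    wav = torch.randn(1, 16000)  # (channels, samples): 1 s of 16 kHz mono audio
    spec = spectrogram(wav, sample_frequency=16000.0, frame_length=25.0, frame_shift=10.0)
    # spec.shape == (98, 257): 512-point FFT -> 257 log-power bins per frame
    return spec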
|
||||
|
||||
|
||||
def inverse_mel_scale_scalar(mel_freq: float) -> float:
|
||||
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
|
||||
|
||||
|
||||
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
|
||||
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
|
||||
|
||||
|
||||
def mel_scale_scalar(freq: float) -> float:
|
||||
return 1127.0 * math.log(1.0 + freq / 700.0)
|
||||
|
||||
|
||||
def mel_scale(freq: Tensor) -> Tensor:
|
||||
return 1127.0 * (1.0 + freq / 700.0).log()
|
||||
|
||||
|
||||
def vtln_warp_freq(
|
||||
vtln_low_cutoff: float,
|
||||
vtln_high_cutoff: float,
|
||||
low_freq: float,
|
||||
high_freq: float,
|
||||
vtln_warp_factor: float,
|
||||
freq: Tensor,
|
||||
) -> Tensor:
|
||||
r"""This computes a VTLN warping function that is not the same as HTK's one,
|
||||
but has similar inputs (this function has the advantage of never producing
|
||||
empty bins).
|
||||
|
||||
This function computes a warp function F(freq), defined between low_freq
|
||||
and high_freq inclusive, with the following properties:
|
||||
F(low_freq) == low_freq
|
||||
F(high_freq) == high_freq
|
||||
The function is continuous and piecewise linear with two inflection
|
||||
points.
|
||||
The lower inflection point (measured in terms of the unwarped
|
||||
frequency) is at frequency l, determined as described below.
|
||||
The higher inflection point is at a frequency h, determined as
|
||||
described below.
|
||||
If l <= f <= h, then F(f) = f/vtln_warp_factor.
|
||||
If the higher inflection point (measured in terms of the unwarped
|
||||
frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
|
||||
Since (by the last point) F(h) == h/vtln_warp_factor, then
|
||||
max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
|
||||
h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
|
||||
= vtln_high_cutoff * min(1, vtln_warp_factor).
|
||||
If the lower inflection point (measured in terms of the unwarped
|
||||
frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
|
||||
This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
|
||||
= vtln_low_cutoff * max(1, vtln_warp_factor)
|
||||
Args:
|
||||
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
|
||||
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
|
||||
low_freq (float): Lower frequency cutoffs in mel computation
|
||||
high_freq (float): Upper frequency cutoffs in mel computation
|
||||
vtln_warp_factor (float): Vtln warp factor
|
||||
freq (Tensor): given frequency in Hz
|
||||
|
||||
Returns:
|
||||
Tensor: Freq after vtln warp
|
||||
"""
|
||||
assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
|
||||
assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
|
||||
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
|
||||
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
|
||||
scale = 1.0 / vtln_warp_factor
|
||||
Fl = scale * l # F(l)
|
||||
Fh = scale * h # F(h)
|
||||
assert l > low_freq and h < high_freq
|
||||
# slope of left part of the 3-piece linear function
|
||||
scale_left = (Fl - low_freq) / (l - low_freq)
|
||||
# [slope of center part is just "scale"]
|
||||
|
||||
# slope of right part of the 3-piece linear function
|
||||
scale_right = (high_freq - Fh) / (high_freq - h)
|
||||
|
||||
res = torch.empty_like(freq)
|
||||
|
||||
outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
|
||||
before_l = torch.lt(freq, l) # freq < l
|
||||
before_h = torch.lt(freq, h) # freq < h
|
||||
after_h = torch.ge(freq, h) # freq >= h
|
||||
|
||||
# order of operations matters here (since there are overlapping frequency regions)
|
||||
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
|
||||
res[before_h] = scale * freq[before_h]
|
||||
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
|
||||
res[outside_low_high_freq] = freq[outside_low_high_freq]
|
||||
|
||||
return res
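# --- Illustrative sketch (not part of the original patch) --------------------
# vtln_warp_freq on a grid of frequencies; the cutoffs and the 1.1 warp factor
# are assumptions for the demo. F(high_freq) == high_freq, and frequencies
# outside [low_freq, high_freq] pass through unchanged.
def _demo_vtln_warp_freq() -> Tensor:
    freqs = torch.linspace(0.0, 8000.0, steps=9)
    return vtln_warp_freq(100.0, 7500.0, 20.0, 8000.0, 1.1, freqs)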
|
||||
|
||||
|
||||
def vtln_warp_mel_freq(
|
||||
vtln_low_cutoff: float,
|
||||
vtln_high_cutoff: float,
|
||||
low_freq: float,
|
||||
high_freq: float,
|
||||
vtln_warp_factor: float,
|
||||
mel_freq: Tensor,
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Args:
|
||||
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
|
||||
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
|
||||
low_freq (float): Lower frequency cutoffs in mel computation
|
||||
high_freq (float): Upper frequency cutoffs in mel computation
|
||||
vtln_warp_factor (float): Vtln warp factor
|
||||
mel_freq (Tensor): Given frequency in Mel
|
||||
|
||||
Returns:
|
||||
Tensor: ``mel_freq`` after vtln warp
|
||||
"""
|
||||
return mel_scale(
|
||||
vtln_warp_freq(
|
||||
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_mel_banks(
|
||||
num_bins: int,
|
||||
window_length_padded: int,
|
||||
sample_freq: float,
|
||||
low_freq: float,
|
||||
high_freq: float,
|
||||
vtln_low: float,
|
||||
vtln_high: float,
|
||||
vtln_warp_factor: float,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Returns:
|
||||
(Tensor, Tensor): The tuple consists of ``bins`` (which is
|
||||
melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
|
||||
center frequencies of bins of size (``num_bins``)).
|
||||
"""
|
||||
assert num_bins > 3, "Must have at least 3 mel bins"
|
||||
assert window_length_padded % 2 == 0
|
||||
num_fft_bins = window_length_padded / 2
|
||||
nyquist = 0.5 * sample_freq
|
||||
|
||||
if high_freq <= 0.0:
|
||||
high_freq += nyquist
|
||||
|
||||
assert (
|
||||
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
|
||||
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
|
||||
|
||||
# fft-bin width [think of it as Nyquist-freq / half-window-length]
|
||||
fft_bin_width = sample_freq / window_length_padded
|
||||
mel_low_freq = mel_scale_scalar(low_freq)
|
||||
mel_high_freq = mel_scale_scalar(high_freq)
|
||||
|
||||
# divide by num_bins+1 in next line because of end-effects where the bins
|
||||
# spread out to the sides.
|
||||
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
|
||||
|
||||
if vtln_high < 0.0:
|
||||
vtln_high += nyquist
|
||||
|
||||
assert vtln_warp_factor == 1.0 or (
|
||||
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
|
||||
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
|
||||
vtln_low, vtln_high, low_freq, high_freq
|
||||
)
|
||||
|
||||
bin = torch.arange(num_bins).unsqueeze(1)
|
||||
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
|
||||
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
|
||||
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
|
||||
|
||||
if vtln_warp_factor != 1.0:
|
||||
left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
|
||||
center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
|
||||
right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
|
||||
|
||||
center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
|
||||
# size(1, num_fft_bins)
|
||||
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
|
||||
|
||||
# size (num_bins, num_fft_bins)
|
||||
up_slope = (mel - left_mel) / (center_mel - left_mel)
|
||||
down_slope = (right_mel - mel) / (right_mel - center_mel)
|
||||
|
||||
if vtln_warp_factor == 1.0:
|
||||
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
|
||||
bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
|
||||
else:
|
||||
# warping can move the order of left_mel, center_mel, right_mel anywhere
|
||||
bins = torch.zeros_like(up_slope)
|
||||
up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
|
||||
down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
|
||||
bins[up_idx] = up_slope[up_idx]
|
||||
bins[down_idx] = down_slope[down_idx]
|
||||
|
||||
return bins, center_freqs
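# --- Illustrative sketch (not part of the original patch) --------------------
# Building a small mel filterbank: 23 filters over a 512-point FFT at 16 kHz,
# no VTLN warping. high_freq=0.0 and vtln_high=-500.0 are treated as offsets
# from the Nyquist frequency inside get_mel_banks.
def _demo_get_mel_banks() -> Tuple[Tensor, Tensor]:
    bins, center_freqs = get_mel_banks(23, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0)
    # bins.shape == (23, 256), center_freqs.shape == (23,)
    return bins, center_freqs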
|
||||
|
||||
|
||||
def fbank(
|
||||
waveform: Tensor,
|
||||
blackman_coeff: float = 0.42,
|
||||
channel: int = -1,
|
||||
dither: float = 0.0,
|
||||
energy_floor: float = 1.0,
|
||||
frame_length: float = 25.0,
|
||||
frame_shift: float = 10.0,
|
||||
high_freq: float = 0.0,
|
||||
htk_compat: bool = False,
|
||||
low_freq: float = 20.0,
|
||||
min_duration: float = 0.0,
|
||||
num_mel_bins: int = 23,
|
||||
preemphasis_coefficient: float = 0.97,
|
||||
raw_energy: bool = True,
|
||||
remove_dc_offset: bool = True,
|
||||
round_to_power_of_two: bool = True,
|
||||
sample_frequency: float = 16000.0,
|
||||
snip_edges: bool = True,
|
||||
subtract_mean: bool = False,
|
||||
use_energy: bool = False,
|
||||
use_log_fbank: bool = True,
|
||||
use_power: bool = True,
|
||||
vtln_high: float = -500.0,
|
||||
vtln_low: float = 100.0,
|
||||
vtln_warp: float = 1.0,
|
||||
window_type: str = POVEY,
|
||||
) -> Tensor:
|
||||
r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
|
||||
compute-fbank-feats.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
|
||||
(Default: ``0.0``)
|
||||
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
|
||||
(need to change other parameters). (Default: ``False``)
|
||||
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
|
||||
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
|
||||
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. (Default: ``True``)
|
||||
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||
specified there) (Default: ``16000.0``)
|
||||
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||
it this way. (Default: ``False``)
|
||||
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
|
||||
use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
|
||||
use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
|
||||
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
|
||||
negative, offset from high-mel-freq (Default: ``-500.0``)
|
||||
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
|
||||
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
|
||||
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||
(Default: ``'povey'``)
|
||||
|
||||
Returns:
|
||||
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
|
||||
where m is calculated in _get_strided
|
||||
"""
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
|
||||
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
|
||||
)
|
||||
|
||||
if len(waveform) < min_duration * sample_frequency:
|
||||
# signal is too short
|
||||
return torch.empty(0, device=device, dtype=dtype)
|
||||
|
||||
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
|
||||
strided_input, signal_log_energy = _get_window(
|
||||
waveform,
|
||||
padded_window_size,
|
||||
window_size,
|
||||
window_shift,
|
||||
window_type,
|
||||
blackman_coeff,
|
||||
snip_edges,
|
||||
raw_energy,
|
||||
energy_floor,
|
||||
dither,
|
||||
remove_dc_offset,
|
||||
preemphasis_coefficient,
|
||||
)
|
||||
|
||||
# size (m, padded_window_size // 2 + 1)
|
||||
spectrum = torch.fft.rfft(strided_input).abs()
|
||||
if use_power:
|
||||
spectrum = spectrum.pow(2.0)
|
||||
|
||||
# size (num_mel_bins, padded_window_size // 2)
|
||||
mel_energies, _ = get_mel_banks(
|
||||
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
|
||||
)
|
||||
mel_energies = mel_energies.to(device=device, dtype=dtype)
|
||||
|
||||
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
|
||||
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
|
||||
|
||||
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
|
||||
mel_energies = torch.mm(spectrum, mel_energies.T)
|
||||
if use_log_fbank:
|
||||
# avoid log of zero (which should be prevented anyway by dithering)
|
||||
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
|
||||
|
||||
# if use_energy then add it as the last column for htk_compat == true else first column
|
||||
if use_energy:
|
||||
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
|
||||
# returns size (m, num_mel_bins + 1)
|
||||
if htk_compat:
|
||||
mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
|
||||
else:
|
||||
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
|
||||
|
||||
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
|
||||
return mel_energies
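# --- Illustrative sketch (not part of the original patch) --------------------
# Hedged usage example for fbank(); 80 mel bins is an assumption picked for the
# demo, not a value mandated anywhere in this file.
def _demo_fbank() -> Tensor:
    wav = torch.randn(1, 16000)  # (channels, samples): 1 s of 16 kHz mono audio
    feats = fbank(wav, num_mel_bins=80, sample_frequency=16000.0, dither=0.0)
    # feats.shape == (98, 80): Kaldi-compatible log-mel filterbank energies
    return feats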
|
||||
|
||||
|
||||
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
|
||||
# returns a dct matrix of size (num_mel_bins, num_ceps)
|
||||
# size (num_mel_bins, num_mel_bins)
|
||||
# torchaudio is only needed here for the DCT matrix; the module-level import is
# commented out, so import it lazily to keep spectrogram() and fbank() usable
# without torchaudio installed.
import torchaudio
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
|
||||
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
|
||||
# this would be the first column in the dct_matrix for torchaudio as it expects a
|
||||
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
|
||||
# expects a left multiply e.g. dct_matrix * vector).
|
||||
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
|
||||
dct_matrix = dct_matrix[:, :num_ceps]
|
||||
return dct_matrix
|
||||
|
||||
|
||||
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
|
||||
# returns size (num_ceps)
|
||||
# Compute liftering coefficients (scaling on cepstral coeffs)
|
||||
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
|
||||
i = torch.arange(num_ceps)
|
||||
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
|
||||
|
||||
|
||||
def mfcc(
|
||||
waveform: Tensor,
|
||||
blackman_coeff: float = 0.42,
|
||||
cepstral_lifter: float = 22.0,
|
||||
channel: int = -1,
|
||||
dither: float = 0.0,
|
||||
energy_floor: float = 1.0,
|
||||
frame_length: float = 25.0,
|
||||
frame_shift: float = 10.0,
|
||||
high_freq: float = 0.0,
|
||||
htk_compat: bool = False,
|
||||
low_freq: float = 20.0,
|
||||
num_ceps: int = 13,
|
||||
min_duration: float = 0.0,
|
||||
num_mel_bins: int = 23,
|
||||
preemphasis_coefficient: float = 0.97,
|
||||
raw_energy: bool = True,
|
||||
remove_dc_offset: bool = True,
|
||||
round_to_power_of_two: bool = True,
|
||||
sample_frequency: float = 16000.0,
|
||||
snip_edges: bool = True,
|
||||
subtract_mean: bool = False,
|
||||
use_energy: bool = False,
|
||||
vtln_high: float = -500.0,
|
||||
vtln_low: float = 100.0,
|
||||
vtln_warp: float = 1.0,
|
||||
window_type: str = POVEY,
|
||||
) -> Tensor:
|
||||
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
|
||||
compute-mfcc-feats.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||
cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
|
||||
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
|
||||
(Default: ``0.0``)
|
||||
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
|
||||
features (need to change other parameters). (Default: ``False``)
|
||||
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
|
||||
num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
|
||||
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
|
||||
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. (Default: ``True``)
|
||||
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||
specified there) (Default: ``16000.0``)
|
||||
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||
it this way. (Default: ``False``)
|
||||
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
|
||||
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
|
||||
negative, offset from high-mel-freq (Default: ``-500.0``)
|
||||
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
|
||||
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
|
||||
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||
(Default: ``"povey"``)
|
||||
|
||||
Returns:
|
||||
Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
|
||||
where m is calculated in _get_strided
|
||||
"""
|
||||
assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
|
||||
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
|
||||
# The mel_energies should not be squared (use_power=True), not have mean subtracted
|
||||
# (subtract_mean=False), and use log (use_log_fbank=True).
|
||||
# size (m, num_mel_bins + use_energy)
|
||||
feature = fbank(
|
||||
waveform=waveform,
|
||||
blackman_coeff=blackman_coeff,
|
||||
channel=channel,
|
||||
dither=dither,
|
||||
energy_floor=energy_floor,
|
||||
frame_length=frame_length,
|
||||
frame_shift=frame_shift,
|
||||
high_freq=high_freq,
|
||||
htk_compat=htk_compat,
|
||||
low_freq=low_freq,
|
||||
min_duration=min_duration,
|
||||
num_mel_bins=num_mel_bins,
|
||||
preemphasis_coefficient=preemphasis_coefficient,
|
||||
raw_energy=raw_energy,
|
||||
remove_dc_offset=remove_dc_offset,
|
||||
round_to_power_of_two=round_to_power_of_two,
|
||||
sample_frequency=sample_frequency,
|
||||
snip_edges=snip_edges,
|
||||
subtract_mean=False,
|
||||
use_energy=use_energy,
|
||||
use_log_fbank=True,
|
||||
use_power=True,
|
||||
vtln_high=vtln_high,
|
||||
vtln_low=vtln_low,
|
||||
vtln_warp=vtln_warp,
|
||||
window_type=window_type,
|
||||
)
|
||||
|
||||
if use_energy:
|
||||
# size (m)
|
||||
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
|
||||
# offset is 0 if htk_compat==True else 1
|
||||
mel_offset = int(not htk_compat)
|
||||
feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
|
||||
|
||||
# size (num_mel_bins, num_ceps)
|
||||
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
|
||||
|
||||
# size (m, num_ceps)
|
||||
feature = feature.matmul(dct_matrix)
|
||||
|
||||
if cepstral_lifter != 0.0:
|
||||
# size (1, num_ceps)
|
||||
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
|
||||
feature *= lifter_coeffs.to(device=device, dtype=dtype)
|
||||
|
||||
# if use_energy then replace the last column for htk_compat == true else first column
|
||||
if use_energy:
|
||||
feature[:, 0] = signal_log_energy
|
||||
|
||||
if htk_compat:
|
||||
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
|
||||
feature = feature[:, 1:] # size (m, num_ceps - 1)
|
||||
if not use_energy:
|
||||
# scale on C0 (actually removing a scale we previously added that's
|
||||
# part of one common definition of the cosine transform.)
|
||||
energy *= math.sqrt(2)
|
||||
|
||||
feature = torch.cat((feature, energy), dim=1)
|
||||
|
||||
feature = _subtract_column_mean(feature, subtract_mean)
|
||||
return feature
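# --- Illustrative sketch (not part of the original patch) --------------------
# Hedged usage example for mfcc(); note that _get_dct_matrix (and hence mfcc)
# requires torchaudio to be installed, unlike spectrogram() and fbank().
def _demo_mfcc() -> Tensor:
    wav = torch.randn(1, 16000)
    ceps = mfcc(wav, num_ceps=13, num_mel_bins=23, sample_frequency=16000.0)
    # ceps.shape == (98, 13)
    return ceps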
|
218
opencompass/models/ola/model/speech_encoder/beats/modules.py
Normal file
@ -0,0 +1,218 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.new(x)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad * ctx.scale, None
|
||||
|
||||
|
||||
class SamePad(nn.Module):
|
||||
def __init__(self, kernel_size, causal=False):
|
||||
super().__init__()
|
||||
if causal:
|
||||
self.remove = kernel_size - 1
|
||||
else:
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove]
|
||||
return x
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self):
|
||||
super(Swish, self).__init__()
|
||||
self.act = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.act(x)
|
||||
|
||||
|
||||
class GLU_Linear(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||
super(GLU_Linear, self).__init__()
|
||||
|
||||
self.glu_type = glu_type
|
||||
self.output_dim = output_dim
|
||||
|
||||
if glu_type == "sigmoid":
|
||||
self.glu_act = torch.nn.Sigmoid()
|
||||
elif glu_type == "swish":
|
||||
self.glu_act = Swish()
|
||||
elif glu_type == "relu":
|
||||
self.glu_act = torch.nn.ReLU()
|
||||
elif glu_type == "gelu":
|
||||
self.glu_act = torch.nn.GELU()
|
||||
|
||||
if bias_in_glu:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||
else:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||
|
||||
def forward(self, x):
|
||||
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
||||
x = self.linear(x)
|
||||
|
||||
if self.glu_type == "bilinear":
|
||||
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
||||
else:
|
||||
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gelu_accurate(x):
|
||||
if not hasattr(gelu_accurate, "_a"):
|
||||
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||
return (
|
||||
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
|
||||
|
||||
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||
|
||||
|
||||
def get_activation_fn(activation: str):
|
||||
"""Returns the activation function corresponding to `activation`"""
|
||||
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "gelu":
|
||||
return gelu
|
||||
elif activation == "gelu_fast":
|
||||
warnings.warn(
|
||||
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
||||
)
|
||||
return gelu_accurate
|
||||
elif activation == "gelu_accurate":
|
||||
return gelu_accurate
|
||||
elif activation == "tanh":
|
||||
return torch.tanh
|
||||
elif activation == "linear":
|
||||
return lambda x: x
|
||||
elif activation == "glu":
|
||||
return lambda x: x
|
||||
else:
|
||||
raise RuntimeError("--activation-fn {} not supported".format(activation))
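# --- Illustrative sketch (not part of the original patch) --------------------
# get_activation_fn returns a plain callable that can be applied to tensors:
def _demo_get_activation_fn() -> Tensor:
    act = get_activation_fn("gelu_accurate")
    return act(torch.linspace(-2.0, 2.0, steps=5))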
|
||||
|
||||
|
||||
def quant_noise(module, p, block_size):
|
||||
"""
|
||||
Wraps modules and applies quantization noise to the weights for
|
||||
subsequent quantization with Iterative Product Quantization as
|
||||
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||
|
||||
Args:
|
||||
- module: nn.Module
|
||||
- p: amount of Quantization Noise
|
||||
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||
|
||||
Remarks:
|
||||
- Module weights must have the right sizes wrt the block size
|
||||
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||
- For more detail on how to quantize by blocks with convolutional weights,
|
||||
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||
- We implement the simplest form of noise here as stated in the paper
|
||||
which consists in randomly dropping blocks
|
||||
"""
|
||||
|
||||
# if no quantization noise, don't register hook
|
||||
if p <= 0:
|
||||
return module
|
||||
|
||||
# supported modules
|
||||
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||
|
||||
# test whether module.weight has the right sizes wrt block_size
|
||||
is_conv = module.weight.ndim == 4
|
||||
|
||||
# 2D matrix
|
||||
if not is_conv:
|
||||
assert (
|
||||
module.weight.size(1) % block_size == 0
|
||||
), "Input features must be a multiple of block sizes"
|
||||
|
||||
# 4D matrix
|
||||
else:
|
||||
# 1x1 convolutions
|
||||
if module.kernel_size == (1, 1):
|
||||
assert (
|
||||
module.in_channels % block_size == 0
|
||||
), "Input channels must be a multiple of block sizes"
|
||||
# regular convolutions
|
||||
else:
|
||||
k = module.kernel_size[0] * module.kernel_size[1]
|
||||
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||
|
||||
def _forward_pre_hook(mod, input):
|
||||
# no noise for evaluation
|
||||
if mod.training:
|
||||
if not is_conv:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_features = weight.size(1)
|
||||
out_features = weight.size(0)
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
mask = torch.zeros(
|
||||
in_features // block_size * out_features, device=weight.device
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||
|
||||
else:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_channels = mod.in_channels
|
||||
out_channels = mod.out_channels
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
if mod.kernel_size == (1, 1):
|
||||
mask = torch.zeros(
|
||||
int(in_channels // block_size * out_channels),
|
||||
device=weight.device,
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||
else:
|
||||
mask = torch.zeros(
|
||||
weight.size(0), weight.size(1), device=weight.device
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = (
|
||||
mask.unsqueeze(2)
|
||||
.unsqueeze(3)
|
||||
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||
)
|
||||
|
||||
# scale weights and apply mask
|
||||
mask = mask.to(
|
||||
torch.bool
|
||||
) # x.bool() is not currently supported in TorchScript
|
||||
s = 1 / (1 - p)
|
||||
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||
|
||||
module.register_forward_pre_hook(_forward_pre_hook)
|
||||
return module
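# --- Illustrative sketch (not part of the original patch) --------------------
# quant_noise registers a forward pre-hook that randomly zeroes weight blocks
# while the module is in training mode; the layer sizes and p/block_size below
# are assumptions for the demo (in_features must be a multiple of block_size).
def _demo_quant_noise() -> nn.Module:
    layer = quant_noise(nn.Linear(64, 64), p=0.1, block_size=8)
    _ = layer(torch.randn(4, 64))  # blocks are dropped only while layer.training is True
    return layer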
|
215
opencompass/models/ola/model/speech_encoder/beats/quantizer.py
Normal file
@ -0,0 +1,215 @@
|
||||
# --------------------------------------------------------
|
||||
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on VQGAN code bases
|
||||
# https://github.com/CompVis/taming-transformers
|
||||
# --------------------------------------------------------'
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as distributed
|
||||
|
||||
try:
|
||||
from einops import rearrange, repeat
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, p=2, dim=-1)
|
||||
|
||||
|
||||
def ema_inplace(moving_avg, new, decay):
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
|
||||
|
||||
def sample_vectors(samples, num):
|
||||
num_samples, device = samples.shape[0], samples.device
|
||||
|
||||
if num_samples >= num:
|
||||
indices = torch.randperm(num_samples, device=device)[:num]
|
||||
else:
|
||||
indices = torch.randint(0, num_samples, (num,), device=device)
|
||||
|
||||
return samples[indices]
|
||||
|
||||
|
||||
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
||||
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if use_cosine_sim:
|
||||
dists = samples @ means.t()
|
||||
else:
|
||||
diffs = rearrange(samples, 'n d -> n () d') \
|
||||
- rearrange(means, 'c d -> () c d')
|
||||
dists = -(diffs ** 2).sum(dim=-1)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
|
||||
if use_cosine_sim:
|
||||
new_means = l2norm(new_means)
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
|
||||
return means, bins
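# --- Illustrative sketch (not part of the original patch) --------------------
# Clustering random vectors with the cosine-similarity variant of kmeans();
# the shapes are assumptions for the demo, and einops must be importable.
def _demo_kmeans():
    feats = torch.randn(256, 32)  # 256 vectors of dimension 32
    means, bins = kmeans(feats, num_clusters=8, num_iters=5, use_cosine_sim=True)
    # means: (8, 32) codewords, bins: (8,) cluster populations
    return means, bins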
|
||||
|
||||
|
||||
class EmbeddingEMA(nn.Module):
|
||||
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
|
||||
super().__init__()
|
||||
self.num_tokens = num_tokens
|
||||
self.codebook_dim = codebook_dim
|
||||
self.decay = decay
|
||||
self.eps = eps
|
||||
if codebook_init_path == '':
|
||||
if not kmeans_init:
|
||||
weight = torch.randn(num_tokens, codebook_dim)
|
||||
weight = l2norm(weight)
|
||||
else:
|
||||
weight = torch.zeros(num_tokens, codebook_dim)
|
||||
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
||||
else:
|
||||
print(f"load init codebook weight from {codebook_init_path}")
|
||||
codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
|
||||
weight = codebook_ckpt_weight.clone()
|
||||
self.register_buffer('initted', torch.Tensor([True]))
|
||||
|
||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
|
||||
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
||||
# self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
||||
self.update = True
|
||||
|
||||
@torch.jit.ignore
|
||||
def init_embed_(self, data):
|
||||
if self.initted:
|
||||
return
|
||||
print("Performing Kemans init for codebook")
|
||||
embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
|
||||
self.weight.data.copy_(embed)
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
def forward(self, embed_id):
|
||||
return F.embedding(embed_id, self.weight)
|
||||
|
||||
def cluster_size_ema_update(self, new_cluster_size):
|
||||
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
|
||||
|
||||
def embed_avg_ema_update(self, new_embed_avg):
|
||||
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
||||
|
||||
def weight_update(self, num_tokens):
|
||||
n = self.cluster_size.sum()
|
||||
smoothed_cluster_size = (
|
||||
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
||||
)
|
||||
# normalize embedding average with smoothed cluster size
|
||||
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
||||
# embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
|
||||
self.weight.data.copy_(embed_normalized)
|
||||
|
||||
|
||||
def norm_ema_inplace(moving_avg, new, decay):
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
moving_avg.data.copy_(l2norm(moving_avg.data))
|
||||
|
||||
|
||||
class NormEMAVectorQuantizer(nn.Module):
|
||||
def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
|
||||
statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
|
||||
super().__init__()
|
||||
self.codebook_dim = embedding_dim
|
||||
self.num_tokens = n_embed
|
||||
self.beta = beta
|
||||
self.decay = decay
|
||||
|
||||
# learnable = True if orthogonal_reg_weight > 0 else False
|
||||
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
|
||||
|
||||
self.statistic_code_usage = statistic_code_usage
|
||||
if statistic_code_usage:
|
||||
self.register_buffer('cluster_size', torch.zeros(n_embed))
|
||||
if distributed.is_available() and distributed.is_initialized():
|
||||
print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
|
||||
self.all_reduce_fn = distributed.all_reduce
|
||||
else:
|
||||
self.all_reduce_fn = nn.Identity()
|
||||
|
||||
def reset_cluster_size(self, device):
|
||||
if self.statistic_code_usage:
|
||||
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
|
||||
self.cluster_size = self.cluster_size.to(device)
|
||||
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
# z, 'b c h w -> b h w c'
|
||||
# z = rearrange(z, 'b c h w -> b h w c')
|
||||
# z = z.transpose(1, 2)
|
||||
z = l2norm(z)
|
||||
z_flattened = z.reshape(-1, self.codebook_dim)
|
||||
|
||||
self.embedding.init_embed_(z_flattened)
|
||||
|
||||
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
|
||||
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
|
||||
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
|
||||
|
||||
encoding_indices = torch.argmin(d, dim=1)
|
||||
|
||||
z_q = self.embedding(encoding_indices).view(z.shape)
|
||||
|
||||
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
||||
|
||||
if not self.training:
|
||||
with torch.no_grad():
|
||||
cluster_size = encodings.sum(0)
|
||||
self.all_reduce_fn(cluster_size)
|
||||
ema_inplace(self.cluster_size, cluster_size, self.decay)
|
||||
|
||||
if self.training and self.embedding.update:
|
||||
# EMA cluster size
|
||||
|
||||
bins = encodings.sum(0)
|
||||
self.all_reduce_fn(bins)
|
||||
|
||||
# self.embedding.cluster_size_ema_update(bins)
|
||||
ema_inplace(self.cluster_size, bins, self.decay)
|
||||
|
||||
zero_mask = (bins == 0)
|
||||
bins = bins.masked_fill(zero_mask, 1.)
|
||||
|
||||
embed_sum = z_flattened.t() @ encodings
|
||||
self.all_reduce_fn(embed_sum)
|
||||
|
||||
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
|
||||
embed_normalized = l2norm(embed_normalized)
|
||||
|
||||
embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
|
||||
embed_normalized)
|
||||
norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
|
||||
|
||||
# compute loss for embedding
|
||||
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
# z_q, 'b h w c -> b c h w'
|
||||
# z_q = rearrange(z_q, 'b h w c -> b c h w')
|
||||
# z_q = z_q.transpose(1, 2)
|
||||
return z_q, loss, encoding_indices
|
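For readers skimming the diff: the quantizer above reduces to a nearest-neighbour lookup on an l2-normalized codebook followed by a straight-through estimator. The standalone sketch below (shapes and seed are made up for illustration, not taken from the PR) reproduces just that step with plain PyTorch:

# Standalone sketch: nearest-neighbour codebook lookup + straight-through trick.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
codebook = F.normalize(torch.randn(512, 64), dim=-1)   # hypothetical (n_embed, dim)
z = F.normalize(torch.randn(8, 64), dim=-1)            # hypothetical batch of features

# squared-distance matrix, same expansion as the einsum expression in forward()
d = z.pow(2).sum(1, keepdim=True) + codebook.pow(2).sum(1) - 2 * z @ codebook.t()
indices = d.argmin(dim=1)
z_q = codebook[indices]

# straight-through: values come from the codebook, gradients flow to z
z_q = z + (z_q - z).detach()
print(indices.shape, z_q.shape)  # torch.Size([8]) torch.Size([8, 64])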
13
opencompass/models/ola/model/speech_encoder/builder.py
Normal file
@ -0,0 +1,13 @@
from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder
import torch.nn as nn


def build_speech_encoder(config):
    speech_encoder_type = getattr(config, 'speech_encoder_type', None)
    if "whisper" in speech_encoder_type.lower():
        return WhisperWrappedEncoder.load(config)
    elif "dual" in speech_encoder_type.lower():
        return DualWrappedEncoder(config)
    elif "none" in speech_encoder_type.lower():
        return None

    raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')
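A quick sketch of how the factory dispatches on speech_encoder_type. The config object and field values below are illustrative stand-ins, and even the 'none' path assumes the Ola dependencies (openai-whisper, the BEATs package) are importable, since the module pulls them in at import time:

# Illustrative only: exercising the dispatch logic of build_speech_encoder.
from types import SimpleNamespace
from opencompass.models.ola.model.speech_encoder.builder import build_speech_encoder

cfg = SimpleNamespace(speech_encoder_type='none')
assert build_speech_encoder(cfg) is None  # 'none' short-circuits to no encoder

# The 'dual' path additionally needs a whisper checkpoint name and a BEATs
# checkpoint path (both hypothetical here):
# cfg = SimpleNamespace(speech_encoder_type='dual',
#                       speech_encoder='large-v3',
#                       music_encoder='/path/to/BEATs.pt')
# encoder = build_speech_encoder(cfg)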
@ -0,0 +1,74 @@
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
import whisper

from opencompass.models.ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs


class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def replace_layer_norm(module):
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder


class DualWrappedEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):

        def replace_layer_norm(module):
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_ckpt['model'])
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True)
            # pad the shorter stream along time so the two can be concatenated
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds
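The pad-to-match-then-concatenate fusion in DualWrappedEncoder.forward can be checked in isolation; the sketch below uses random tensors with invented shapes rather than real Whisper/BEATs features:

# Illustrative shapes only: fuse two frame-level feature streams of unequal length
# by zero-padding the shorter one along time and concatenating on the channel dim.
import torch
import torch.nn.functional as F

speech = torch.randn(1, 1500, 1280)   # stand-in for Whisper encoder output
audio = torch.randn(1, 1400, 768)     # stand-in for BEATs output, shorter in time

if audio.size(1) < speech.size(1):
    audio = F.pad(audio, (0, 0, 0, speech.size(1) - audio.size(1)))
elif audio.size(1) > speech.size(1):
    speech = F.pad(speech, (0, 0, 0, audio.size(1) - speech.size(1)))

fused = torch.cat((speech, audio), dim=-1)
print(fused.shape)  # torch.Size([1, 1500, 2048])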
11
opencompass/models/ola/model/speech_projector/builder.py
Normal file
@ -0,0 +1,11 @@
from .speech_projector import EncoderProjectorConcat


def build_speech_projector(config):
    projector_type = getattr(config, 'speech_projector_type', 'linear')
    if projector_type == 'linear':
        return EncoderProjectorConcat(config)
    elif projector_type == 'none':
        return None

    raise ValueError(f'Unknown projector type: {projector_type}')
@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import math


class EncoderProjectorConcat(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.k = config.speech_encoder_ds_rate
        self.encoder_dim = config.speech_encoder_hidden_size
        self.llm_dim = config.hidden_size
        self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(2048, config.hidden_size)

        embed_std = 1 / math.sqrt(config.hidden_size)
        self.speech_newline = nn.Parameter(
            torch.randn(config.hidden_size) * embed_std
        )
        self.speech_begin = nn.Parameter(
            torch.randn(config.hidden_size) * embed_std
        )
        self.speech_end = nn.Parameter(
            torch.randn(config.hidden_size) * embed_std
        )

    def forward(self, x):
        batch_size, seq_len, dim = x.size()
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(batch_size, seq_len // self.k, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = torch.cat([
            x,
            self.speech_newline.reshape(1, 1, -1).expand(batch_size, 1, -1).to(x.dtype)
        ], dim=1)
        begin = self.speech_begin.reshape(1, -1).to(x.dtype)
        end = self.speech_end.reshape(1, -1).to(x.dtype)
        x = x.flatten(0, 1)
        x = torch.cat([begin, x, end], dim=0)
        # x = x.flatten(0, 1)
        return x
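The only non-obvious step in EncoderProjectorConcat.forward is the k-fold frame stacking before the two-layer MLP. A standalone shape check, with dimensions chosen purely for illustration (the real k comes from speech_encoder_ds_rate):

# Illustrative: stack every k consecutive frames into one wider vector.
import torch

k = 5
x = torch.randn(2, 123, 64)             # (batch, frames, dim); 123 is not divisible by k
discard = x.size(1) % k
if discard > 0:
    x = x[:, :-discard, :]              # drop the trailing remainder frames
stacked = x.contiguous().view(x.size(0), x.size(1) // k, x.size(2) * k)
print(stacked.shape)                    # torch.Size([2, 24, 320]) -> 5x fewer frames, 5x wider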
213
opencompass/models/ola/utils.py
Normal file
@ -0,0 +1,213 @@
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import torch
import logging
import logging.handlers
import requests
import transformers

from opencompass.models.ola.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # keep only the bias terms that belong to a LoRA-adapted module
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


def lengths_to_mask(lens):
    return ~lengths_to_padding_mask(lens)


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    data = "{" + '"input": ' + f'"{text}"' + "}"
    data = data.encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException as e:
        flagged = False
    except KeyError as e:
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
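Of these helpers, lengths_to_padding_mask is the one the speech pipeline leans on. A quick example of its output, assuming the opencompass package from this PR is importable (the lengths below are invented):

# Illustrative: padding mask for a batch with sequence lengths 2 and 4.
import torch
from opencompass.models.ola.utils import lengths_to_padding_mask

lens = torch.tensor([2, 4])
print(lengths_to_padding_mask(lens))
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])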
141
opencompass/models/ola_model.py
Normal file
@ -0,0 +1,141 @@
import os

os.environ['LOWRES_RESIZE'] = '384x32'
os.environ['HIGHRES_BASE'] = '0x32'
os.environ['VIDEO_RESIZE'] = "0x64"
os.environ['VIDEO_MAXRES'] = "480"
os.environ['VIDEO_MINRES'] = "288"
os.environ['MAXRES'] = '1536'
os.environ['MINRES'] = '0'
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
os.environ['LOAD_VISION_EARLY'] = '1'
os.environ['PAD2STRIDE'] = '1'

from typing import Dict, List, Optional, Union

import torch

from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.utils.prompt import PromptList

from opencompass.models.ola.conversation import conv_templates, SeparatorStyle
from opencompass.models.ola.model.builder import load_pretrained_model
from opencompass.models.ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
from opencompass.models.ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX

PromptType = Union[PromptList, str]


class OlaModel(BaseModel):
    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_path: Optional[str] = None,
                 model_config: Optional[str] = None,
                 meta_template: Optional[Dict] = None):

        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']
        self.max_seq_len = max_seq_len

        tokenizer, model, _, _ = load_pretrained_model(path, None)
        model = model.to('cuda').eval()
        model = model.bfloat16()
        self.tokenizer = tokenizer
        self.model = model
        self.gen_kwargs = {
            "max_new_tokens": 1024,
            "temperature": 0.2,
            "top_p": None,
            "num_beams": 1,
        }

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        assert len(inputs) == 1  # batch=1
        image_path = None
        audio_path = None
        video_path = None
        text = inputs[0]

        # Text-only evaluation: feed dummy image / speech tensors so the
        # multimodal forward pass has valid placeholders.
        images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
        images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
        image_sizes = [(224, 224)]

        USE_SPEECH = False
        speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
        speech_lengths = [torch.LongTensor([3000]).to('cuda')]
        speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
        speech_chunks = [torch.LongTensor([1]).to('cuda')]

        conv_mode = "qwen_1_5"
        if text:
            qs = text
        else:
            qs = ''
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')

        pad_token_ids = 151643

        attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images,
                images_highres=images_highres,
                image_sizes=image_sizes,
                modalities=['text'],
                speech=speechs,
                speech_lengths=speech_lengths,
                speech_chunks=speech_chunks,
                speech_wav=speech_wavs,
                attention_mask=attention_masks,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                do_sample=True if self.gen_kwargs["temperature"] > 0 else False,
                temperature=self.gen_kwargs["temperature"],
                top_p=self.gen_kwargs["top_p"],
                num_beams=self.gen_kwargs["num_beams"],
                # honour the per-dataset output budget passed in by OpenCompass
                max_new_tokens=max_out_len if max_out_len else self.gen_kwargs["max_new_tokens"],
            )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        out = []
        for output in outputs:
            output = output.strip()
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
            out.append(output)
        print('prompt---->', prompt)
        print('out---->', out)
        print()
        return out
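For quick local smoke-testing outside the OpenCompass runner, the wrapper can also be driven directly. This is only a sketch: it assumes a CUDA machine with the Ola runtime dependencies installed, and the prompt is arbitrary:

# Minimal sketch, not part of the diff: exercising OlaModel.generate directly.
from opencompass.models import OlaModel

model = OlaModel(path='THUdyh/Ola-7b', max_seq_len=2048)
print(model.generate(['What is 12 * 7?'], max_out_len=64)[0])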