Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
Merge 37b894d4a1 into d572761cef
Commit 5582aeb9f9
10  examples/eval_ola.py  Normal file
@@ -0,0 +1,10 @@
from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import gsm8k_datasets
    from opencompass.configs.datasets.demo.demo_math_chat_gen import math_datasets
    from opencompass.configs.models.ola.ola import models as ola_models


datasets = math_datasets
models = ola_models
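A minimal variant of this config (not part of the commit) that evaluates Ola on both imported demo sets at once; plain list concatenation is all that is needed:

from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import gsm8k_datasets
    from opencompass.configs.datasets.demo.demo_math_chat_gen import math_datasets
    from opencompass.configs.models.ola.ola import models as ola_models

# Run GSM8K and MATH together instead of MATH only.
datasets = gsm8k_datasets + math_datasets
models = ola_models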
12  opencompass/configs/models/ola/ola.py  Normal file
@@ -0,0 +1,12 @@
from opencompass.models import OlaModel
models = [
    dict(
        type=OlaModel,
        path="THUdyh/Ola-7b",
        max_seq_len=2048,
        abbr='ola',
        max_out_len=1024,
        batch_size=1,
        run_cfg=dict(num_gpus=1),
    )
]
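If the checkpoint is already on disk, a second, purely hypothetical entry could point `path` at a local directory; every field mirrors the entry above and the path is a placeholder, not taken from the diff:

models += [
    dict(
        type=OlaModel,
        path='/path/to/local/Ola-7b',   # placeholder local checkpoint directory
        abbr='ola-local',
        max_seq_len=2048,
        max_out_len=1024,
        batch_size=1,
        run_cfg=dict(num_gpus=1),
    )
]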
@@ -35,6 +35,7 @@ from .openai_api import OpenAI # noqa: F401
 from .openai_api import OpenAISDK # noqa: F401
 from .pangu_api import PanGu # noqa: F401
 from .qwen_api import Qwen # noqa: F401
+from .ola_model import OlaModel # noqa: F401
 from .rendu_api import Rendu # noqa: F401
 from .sensetime_api import SenseTime # noqa: F401
 from .stepfun_api import StepFun # noqa: F401
65  opencompass/models/ola/arguments.py  Normal file
@@ -0,0 +1,65 @@
import transformers

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=False)
    tune_speech_projector: bool = field(default=False)
    tune_speech_encoder: bool = field(default=False)
    tune_speech_generator_only: bool = field(default=False)
    speech_encoder_type: Optional[str] = field(default=None)
    speech_encoder: Optional[str] = field(default=None)
    pretrain_speech_projector: Optional[str] = field(default=None)
    speech_projector_type: Optional[str] = field(default='linear')
    speech_encoder_ds_rate: int = 5
    speech_encoder_hidden_size: int = 1280


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    is_multimodal: bool = False
    input_type: str = field(default="mel")
    speech_normalize: bool = False
    mel_size: int = 128
    has_tgt_units: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    freeze_speech_projector: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    speech_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
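These dataclasses follow the usual transformers pattern and would typically be consumed through HfArgumentParser; the sketch below is illustrative only, and the data/output paths passed on the command line are placeholders:

from transformers import HfArgumentParser

from opencompass.models.ola.arguments import (DataArguments, ModelArguments,
                                              TrainingArguments)

parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses([
    '--model_name_or_path', 'THUdyh/Ola-7b',
    '--data_path', 'data/train.json',        # placeholder
    '--output_dir', 'work_dirs/ola_sft',     # placeholder
])
print(model_args.speech_projector_type, training_args.bits)   # linear 16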
14  opencompass/models/ola/constants.py  Normal file
@@ -0,0 +1,14 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
SPEECH_TOKEN_INDEX = -200
DEFAULT_SPEECH_TOKEN = "<speech>"
IMAGE_TOKEN_INDEX = -300
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
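The negative indices are placeholders that can never collide with real vocabulary ids (which are non-negative): IGNORE_INDEX masks positions out of the loss, while the speech/image indices mark where multimodal embeddings are spliced in later. A small illustration (the surrounding token ids are made up):

from opencompass.models.ola.constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX,
                                              SPEECH_TOKEN_INDEX)

token_ids = [1234, SPEECH_TOKEN_INDEX, 5678, IMAGE_TOKEN_INDEX]
labels = [IGNORE_INDEX if t < 0 else t for t in token_ids]
print(labels)   # [1234, -100, 5678, -100]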
254  opencompass/models/ola/conversation.py  Normal file
@@ -0,0 +1,254 @@
import dataclasses
from enum import auto, Enum
from typing import List, Any, Union, Tuple
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
    """Different separator style."""
    TWO = auto()
    PLAIN = auto()
    CHATML = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()
    QWEN2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    tokenizer_id: str = ""
    tokenizer: Any = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
            ret = "<|begin_of_text|>" + wrap_sys(self.system)
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                    ret += message.strip() + self.sep2
                else:
                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""

        elif self.sep_style == SeparatorStyle.CHATML:
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        raise ValueError("Tuple not supported in CHATML")
                        message, images = message
                        message = "<speech>" * len(images) + message
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.QWEN2:
            start = '<|im_start|>'
            end = '<|im_end|>\n'
            ret = start + 'system\n' + self.system + end
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message

                    if message.endswith('<|endoftext|>'):
                        message = message.replace('<|endoftext|>', '')
                        ret += start + role + "\n" + message + end + '<|endoftext|>'
                    else:
                        assert not '<|endoftext|>' in message, f"Invalid message: {message}"
                        ret += start + role + "\n" + message + end
                else:
                    ret += start + role + "\n"
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, speech = msg
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llama_3 = Conversation(
    system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
    roles=("user", "assistant"),
    version="llama_v3",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="",
    sep2="<|eot_id|>"
)


conv_qwen_v1 = Conversation(
    system="You are a helpful assistant.",
    roles=("user", "assistant"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.QWEN2,
)

conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="</s>",
)

conv_qwen = Conversation(
    system="""<|im_start|>system
You are a helpful assistant.""",
    roles=("<|im_start|>user", "<|im_start|>assistant"),
    version="qwen",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.CHATML,
    sep="<|im_end|>",
)

default_conversation = conv_llama_3
conv_templates = {
    "v1": conv_vicuna_v1,
    "plain": conv_plain,
    "llama_2": conv_llama_2,
    "llama_3": conv_llama_3,
    'v1_qwen2': conv_qwen_v1,
    "qwen_1_5": conv_qwen,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
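A short usage sketch of the templates defined above, using the CHATML-style "qwen_1_5" entry; the user message is an arbitrary example:

from opencompass.models.ola.conversation import conv_templates

conv = conv_templates['qwen_1_5'].copy()
conv.append_message(conv.roles[0], 'Describe the audio clip.')
conv.append_message(conv.roles[1], None)   # leave the assistant turn open for generation
print(conv.get_prompt())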
0  opencompass/models/ola/datasets/__init__.py  Normal file
231  opencompass/models/ola/datasets/preprocess.py  Normal file
@@ -0,0 +1,231 @@
import copy
import torch
import transformers
import tokenizers

from typing import Dict, Sequence

from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX
from opencompass.models.ola import conversation as conversation_lib
from opencompass.models.ola.model import *
from opencompass.models.ola.arguments import DataArguments
from opencompass.models.ola.constants import SPEECH_TOKEN_INDEX

from packaging import version

IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')


def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech><image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>\nUser's question in speech: <speech>\n")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    nl_tokens = tokenizer("\n").input_ids
    special_chunks = [image_token_index, nl_tokens, tokenizer("User's question in speech: ").input_ids, speech_token_idx, nl_tokens]

    for x in insert_separator(prompt_chunks, [special_chunks] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_speech:
        input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_speech:
                round_len = len(tokenizer_speech_token(rou, tokenizer))
                instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # FIXME: tokenizer bug
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_SPEECH_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_SPEECH_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_speech: bool = False
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_speech=has_speech)
    raise NotImplementedError
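A sketch of how tokenizer_speech_token splices the speech placeholder into the id sequence; the tokenizer name is only an example of a locally available Hugging Face tokenizer:

from transformers import AutoTokenizer

from opencompass.models.ola.constants import SPEECH_TOKEN_INDEX
from opencompass.models.ola.datasets.preprocess import tokenizer_speech_token

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-7B-Instruct')
ids = tokenizer_speech_token('Transcribe this clip: <speech>', tokenizer, return_tensors='pt')
assert (ids == SPEECH_TOKEN_INDEX).sum() == 1   # exactly one placeholder inserted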
271  opencompass/models/ola/mm_utils.py  Normal file
@@ -0,0 +1,271 @@
from PIL import Image
import base64
import math
import ast

import torch
from transformers import StoppingCriteria
import os
import io

if 'VIDEO_RESIZE' in os.environ:
    # highresxpatch
    VIDEO_RESIZE = os.environ['VIDEO_RESIZE']
    video_base, video_ps = VIDEO_RESIZE.split('x')
    video_base = int(video_base)
    video_ps = int(video_ps)
    print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}")
else:
    VIDEO_RESIZE = None  # no video resizing unless the env var is set

if 'HIGHRES_BASE' in os.environ:
    # highresxpatch
    HIGHRES_BASE = os.environ['HIGHRES_BASE']
    highres_base, highres_ps = HIGHRES_BASE.split('x')
    highres_base = int(highres_base)
    highres_ps = int(highres_ps)
    print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}")
else:
    HIGHRES_BASE = None

if 'MAXRES' in os.environ:
    # highresxpatch
    MAXRES = int(os.environ['MAXRES'])
    print(f"MAXRES is set as {MAXRES}")
else:
    MAXRES = 1536

if 'MINRES' in os.environ:
    # highresxpatch
    MINRES = int(os.environ['MINRES'])
    print(f"MINRES is set as {MINRES}")
else:
    MINRES = 0

if 'VIDEO_MAXRES' in os.environ:
    # highresxpatch
    VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES'])
    print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}")
else:
    VIDEO_MAXRES = 1536

if 'VIDEO_MINRES' in os.environ:
    # highresxpatch
    VIDEO_MINRES = int(os.environ['VIDEO_MINRES'])
    print(f"VIDEO_MINRES is set as {VIDEO_MINRES}")
else:
    VIDEO_MINRES = 0

if 'PAD2STRIDE' in os.environ:
    # highresxpatch
    PAD2STRIDE = True
    print("PAD2STRIDE is set")
else:
    PAD2STRIDE = False

if 'LOWRES_RESIZE' in os.environ:
    LOWRES_RESIZE = os.environ['LOWRES_RESIZE']
    print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}")
    if 'x' in LOWRES_RESIZE:
        size, ps = LOWRES_RESIZE.split('x')
        size = int(size)
        ps = int(ps)
        LOWRES_RESIZE = (size, ps)
    else:
        LOWRES_RESIZE = int(LOWRES_RESIZE)
else:
    LOWRES_RESIZE = None


def pad_image(image, target_resolution, value=0):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution
    # Create a new image with the target size and paste the resized image onto it
    new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
    paste_x = (target_width - original_width) // 2
    paste_y = (target_height - original_height) // 2
    new_image.paste(image, (paste_x, paste_y))
    return new_image

def resize_images(image, patch_size=14, base_size=896):
    h, w = image.size
    if base_size == 0:
        if h * w > MAXRES * MAXRES:
            # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
            scale = MAXRES * MAXRES / (h * w)
            scale = math.sqrt(scale)
        elif h * w < MINRES * MINRES:
            # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
            scale = MINRES * MINRES / (h * w)
            scale = math.sqrt(scale)
        else:
            scale = None
    else:
        scale = base_size * base_size / (h * w)
        scale = math.sqrt(scale)

    if scale is not None:
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        new_h = max(new_h, patch_size)
        new_w = max(new_w, patch_size)
        image = image.resize((new_h, new_w))
    elif PAD2STRIDE:
        if h % patch_size == 0:
            new_h = h
        else:
            new_h = (h // patch_size + 1) * patch_size

        if w % patch_size == 0:
            new_w = w
        else:
            new_w = (w // patch_size + 1) * patch_size
        image = pad_image(image, (new_h, new_w), value=127)
    else:
        scale = 1.0
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        new_h = max(new_h, patch_size)
        new_w = max(new_w, patch_size)
        image = image.resize((new_h, new_w))

    return image

def resize_video(image, patch_size=14, base_size=896):
    h, w = image.size
    if base_size == 0:
        if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
            # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
            scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
            scale = math.sqrt(scale)
        elif h * w < VIDEO_MINRES * VIDEO_MINRES:
            # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
            scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
            scale = math.sqrt(scale)
        else:
            scale = None
    else:
        scale = base_size * base_size / (h * w)
        scale = math.sqrt(scale)

    if scale is not None:
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))
    elif PAD2STRIDE:
        if h % patch_size == 0:
            new_h = h
        else:
            new_h = (h // patch_size + 1) * patch_size

        if w % patch_size == 0:
            new_w = w
        else:
            new_w = (w // patch_size + 1) * patch_size
        image = pad_image(image, (new_h, new_w), value=127)
    else:
        scale = 1.0
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))

    return image

def process_anyres_video(image, processor):
    if VIDEO_RESIZE is not None:
        image = resize_video(image, patch_size=video_ps, base_size=video_base)
        image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        return image.unsqueeze(0)
    else:
        raise ValueError("VIDEO_RESIZE is not set")

def process_anyres_highres_image(image, processor):
    processor2 = None
    if type(processor) is tuple:
        processor, processor2 = processor[0], processor[1]

    if HIGHRES_BASE is not None:
        image = resize_images(image, patch_size=highres_ps, base_size=highres_base)

    if processor2 is not None:
        image_original_resize = image.resize((processor2.size['shortest_edge'], processor.size['shortest_edge']))
        image_patches = [image_original_resize] + [image_original_resize]
        image_patches = [processor2.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                         for image_patch in image_patches]
    else:
        if LOWRES_RESIZE is not None:
            if type(LOWRES_RESIZE) is int:
                image_original_resize = resize_images(image, patch_size=14, base_size=LOWRES_RESIZE)
            else:
                image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
        else:
            image_original_resize = image.resize((336, 336))
        image_patches = [image_original_resize]
        image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                         for image_patch in image_patches]
    image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)

def read_image_patch(patch_info):
    if 'img_path' in patch_info.keys():
        image = Image.open(patch_info['img_path']).convert('RGB')
    else:
        if 'image_encoing' in patch_info.keys():
            patch_info['image_encoding'] = patch_info['image_encoing']
        image_file_name = patch_info['patch']
        start_bytes = int(patch_info['start_num'])
        file_size = int(patch_info['size'])

        with open(image_file_name, 'rb') as f:
            f.seek(start_bytes)
            if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
                image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
            else:
                image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
    return image


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # .all() handles keywords that tokenize to more than one id
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
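A quick check of the resizing helpers above: output dimensions get snapped to multiples of the patch size, with the area pulled toward base_size squared (the image size here is arbitrary):

from PIL import Image

from opencompass.models.ola.mm_utils import resize_images

img = Image.new('RGB', (1000, 750))
out = resize_images(img, patch_size=14, base_size=896)
print(out.size)   # both sides are multiples of 14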
1  opencompass/models/ola/model/__init__.py  Normal file
@@ -0,0 +1 @@
from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen
91  opencompass/models/ola/model/builder.py  Normal file
@@ -0,0 +1,91 @@
import os
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from opencompass.models.ola.model import *
from opencompass.models.ola.model.speech_encoder.builder import build_speech_encoder

def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.bfloat16

    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'

    model_cls = OlaQwenForCausalLM

    # Load Ola model
    if is_lora:
        assert model_base is not None, "model_base is required for LoRA models."
        from opencompass.models.ola.model.language_model.ola_qwen import OlaConfigQwen
        lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
        print('Loading Ola from base model...')
        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
        print('Loading additional Ola weights...')
        if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
            non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
        non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
        if any(k.startswith('model.model.') for k in non_lora_trainables):
            non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
        model.load_state_dict(non_lora_trainables, strict=False)

        from peft import PeftModel
        print('Loading LoRA weights...')
        model = PeftModel.from_pretrained(model, model_path)
        print('Merging LoRA weights...')
        model = model.merge_and_unload()
        print('Model is loaded...')
    elif model_base is not None:
        print('Loading Ola from base model...')
        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
        cfg_pretrained = AutoConfig.from_pretrained(model_path)
        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)

        speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
        speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
        model.load_state_dict(speech_projector_weights, strict=False)
        model = model.to(device=device)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = model_cls.from_pretrained(
            model_path,
            low_cpu_mem_usage=False,
            **kwargs
        )
        model = model.to(device=device)

    model.get_model().speech_encoder = build_speech_encoder(model.config)
    model.get_model().speech_encoder.to(device=device, dtype=torch.float16)

    image_processor = None
    model.resize_token_embeddings(len(tokenizer))
    vision_tower = model.get_vision_tower()
    print("Loading vision tower...")
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device)
        if device != "auto":
            vision_tower.to(device="cuda", dtype=torch.bfloat16)
        else:
            vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
    image_processor = vision_tower.image_processor
    print("Loading vision tower succeeded.")

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 16384

    return tokenizer, model, image_processor, context_len
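A hedged sketch of calling the builder above; it assumes a GPU, the speech-encoder assets referenced by the checkpoint's config, and the same hub id used in the model config:

from opencompass.models.ola.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path='THUdyh/Ola-7b',
    model_base=None,
    device='cuda',
    use_flash_attn=False,
)
print(context_len)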
237  opencompass/models/ola/model/language_model/ola_qwen.py  Normal file
@@ -0,0 +1,237 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

import transformers
from transformers import AutoConfig, AutoModelForCausalLM


from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM


class OlaConfigQwen(Qwen2Config):
    model_type = "ola_qwen"


class OlaQwenModel(OlaMetaModel, Qwen2Model):
    config_class = OlaConfigQwen

    def __init__(self, config: Qwen2Config):
        super(OlaQwenModel, self).__init__(config)


class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM):
    config_class = OlaConfigQwen

    def __init__(self, config):
        super(Qwen2ForCausalLM, self).__init__(config)

        config.rope_scaling = None
        self.model = OlaQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        speech: Optional[torch.FloatTensor] = None,
        speech_lengths: Optional[torch.LongTensor] = None,
        speech_chunks: Optional[torch.LongTensor] = None,
        speech_wav: Optional[torch.FloatTensor] = None,
        images: Optional[torch.FloatTensor] = None,
        images_highres: Optional[List[torch.FloatTensor]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        modalities: Optional[List[str]] = ["image"],
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_speech_vision_text(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                speech,
                speech_lengths,
                speech_chunks,
                speech_wav,
                images,
                modalities,
                image_sizes,
                images_highres
            )

        if labels is None:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict
            )
        else:
            return self.forward_llm_efficient(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict
            )


    def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_dim = hidden_states.size(-1)
        shift_labels = labels[..., 1:].contiguous().reshape(-1)
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
        assert shift_labels.size(0) == shift_hidden_states.size(0)
        mask = shift_labels > -1
        assert mask.float().sum() > 0
        shift_labels = shift_labels[mask]
        shift_hidden_states = shift_hidden_states[mask, :]
        logits = self.lm_head(shift_hidden_states)
        logits = logits.float()
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, shift_labels)


        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output


        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        speech: Optional[torch.Tensor] = None,
        speech_lengths: Optional[torch.Tensor] = None,
        speech_chunks: Optional[torch.Tensor] = None,
        speech_wav: Optional[torch.FloatTensor] = None,
        images: Optional[torch.Tensor] = None,
        images_highres: Optional[List[torch.FloatTensor]] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        (
            inputs,
            position_ids,
            attention_mask,
            _,
            inputs_embeds,
            _
        ) = self.prepare_inputs_labels_for_speech_vision_text(
            inputs,
            position_ids,
            attention_mask,
            None,
            None,
            speech,
            speech_lengths,
            speech_chunks,
            speech_wav,
            images,
            modalities,
            image_sizes,
            images_highres
        )

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        speech = kwargs.pop("speech", None)
        speech_lengths = kwargs.pop("speech_lengths", None)
        speech_chunks = kwargs.pop("speech_chunks", None)
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if speech is not None:
            inputs['speech'] = speech
            inputs['speech_lengths'] = speech_lengths
            inputs['speech_chunks'] = speech_chunks
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs

AutoConfig.register("ola_qwen", OlaConfigQwen)
AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM)
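Because of the two register() calls at the bottom, the custom config resolves through the transformers Auto classes once this module has been imported; a minimal check:

from transformers import AutoConfig

import opencompass.models.ola.model.language_model.ola_qwen  # noqa: F401  (runs the register calls)

cfg = AutoConfig.for_model('ola_qwen')
print(type(cfg).__name__)   # OlaConfigQwen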
@@ -0,0 +1,9 @@
import os
from .oryx_vit import SigLIPViTAnysizeWrapper

def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    print(f"Building OryxViTWrapper from {vision_tower}...")
    # path = vision_tower.split(":")[1]
    return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs)
1075  opencompass/models/ola/model/multimodal_encoder/oryx_vit.py  Normal file
File diff suppressed because it is too large
172
opencompass/models/ola/model/multimodal_projector/builder.py
Normal file
172
opencompass/models/ola/model/multimodal_projector/builder.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import re
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
from .pooler_projector import NormalizedDwPooler
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
|
||||||
|
class IdentityMap(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config(self):
|
||||||
|
return {"mm_projector_type": 'identity'}
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleResBlock(nn.Module):
|
||||||
|
def __init__(self, channels):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_norm = nn.LayerNorm(channels)
|
||||||
|
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
nn.Linear(channels, channels),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(channels, channels)
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
return x + self.proj(x)
|
||||||
|
|
||||||
|
class OlaMLP(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, twoview=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.proj1 = nn.Linear(in_channels, out_channels)
|
||||||
|
self.proj2 = nn.Linear(out_channels, out_channels)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.pooler = NormalizedDwPooler(out_channels)
|
||||||
|
|
||||||
|
embed_std = 1 / math.sqrt(out_channels)
|
||||||
|
self.image_newline = nn.Parameter(
|
||||||
|
torch.randn(out_channels) * embed_std
|
||||||
|
)
|
||||||
|
self.image_begin = nn.Parameter(
|
||||||
|
torch.randn(out_channels) * embed_std
|
||||||
|
)
|
||||||
|
self.image_end = nn.Parameter(
|
||||||
|
torch.randn(out_channels) * embed_std
|
||||||
|
)
|
||||||
|
|
||||||
|
if twoview:
|
||||||
|
self.image_sep = nn.Parameter(
|
||||||
|
torch.randn(out_channels) * embed_std
|
||||||
|
)
|
||||||
|
|
||||||
|
    def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
        if modalities in ['image', 'text']:
            h, w = size
            dtype = x.dtype
            x = x.reshape(x.shape[0], h, w, -1)
            x = self.proj1(x)
            x = self.pooler(x, forward_type='2x')
            x = self.act(x)
            x = self.proj2(x)

            b, h, w, c = x.shape
            x = torch.cat([
                x,
                self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
            ], dim=2)
            x = x.reshape(b, -1, c)

            if x2 is not None:
                h2, w2 = size2
                x2 = x2.reshape(x2.shape[0], h2, w2, -1)
                x2 = self.proj1(x2)
                x2 = self.pooler(x2, forward_type='2x')
                x2 = self.act(x2)
                x2 = self.proj2(x2)

                b2, h2, w2, c2 = x2.shape
                x2 = torch.cat([
                    x2,
                    self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
                ], dim=2)
                x2 = x2.reshape(b, -1, c)
                sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
                x = torch.cat([x, sep, x2], dim=1)

            begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
            end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
            x = torch.cat([begin, x, end], dim=1)
            return x
        elif modalities in ['video']:
            # x2 is the true feature, ignore x
            h, w = size
            dtype = x.dtype
            x = x.reshape(x.shape[0], h, w, -1)
            x1 = self.proj1(x)
            x1 = self.pooler(x1, forward_type='2x')
            x1 = self.proj2(x1).mean() * 0.0

            h2, w2 = size2
            x2 = x2.reshape(x2.shape[0], h2, w2, -1)
            x2 = self.proj1(x2)
            x2 = self.pooler(x2, forward_type='2x')
            x2 = self.act(x2)
            x2 = self.proj2(x2)

            b2, h2, w2, c = x2.shape
            x2 = torch.cat([
                x2,
                self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype)
            ], dim=2)

            x2 = x2.reshape(b2, -1, c)

            sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
            x2 = torch.cat([x2, sep], dim=1)

            x2 = x2.flatten(0, 1)

            begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
            end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
            x2 = torch.cat([begin, x2, end], dim=0)
            x2 = x2.unsqueeze(0)
            return x2
        else:
            raise ValueError(f'Unknown modalities: {modalities}')


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    elif projector_type == 'ola_mlp':
        return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type)
    if mlp_gelu_resnet_match:
        mlp_depth = int(mlp_gelu_resnet_match.group(1))
        res_depth = int(mlp_gelu_resnet_match.group(2))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        for _ in range(res_depth):
            modules.append(SimpleResBlock(config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
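For reference, a minimal self-contained sketch (not part of the patch) of what the `mlpNx_gelu` naming convention above expands to; the hidden sizes are assumed for illustration only:

import re
import torch.nn as nn

projector_type = 'mlp2x_gelu'
mm_hidden_size, hidden_size = 1024, 4096        # illustrative values, not Ola defaults

match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
depth = int(match.group(1))                     # 2
layers = [nn.Linear(mm_hidden_size, hidden_size)]
for _ in range(1, depth):
    layers.append(nn.GELU())
    layers.append(nn.Linear(hidden_size, hidden_size))
projector = nn.Sequential(*layers)              # Linear -> GELU -> Linear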
@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from transformers.models.clip.modeling_clip import CLIPVisionModel
import os


class PoolerProjector(nn.Module):
    def __init__(self, config, vision_cfg):
        super().__init__()
        self._config = config
        self.hw = vision_cfg.image_size // vision_cfg.patch_size

        self.conv_pool = nn.Conv2d(
            config.mm_hidden_size, config.hidden_size,
            kernel_size=2, stride=2
        )

        self.proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        height = width = self.hw
        assert height * width == x.shape[1]
        x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
        x = self.conv_pool(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'pooler'}


class NormalizedDwPooler(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.predictor = nn.Sequential(
            nn.Linear(dim*2, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x, forward_type='2x'):
        B, H, W, C = x.shape

        if forward_type == '2x':
            new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
            pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
            fused_x = torch.cat([new_x, pooled_x], dim=-1)
        elif forward_type == '1x':
            new_x = x.reshape(B, H, W, 1, C)
            fused_x = torch.cat([new_x, new_x], dim=-1)
        elif forward_type == '4x':
            new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
            pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
            fused_x = torch.cat([new_x, pooled_x], dim=-1)

        score = self.predictor(fused_x)
        normalized_score = F.softmax(score, dim=-2)
        new_x = (new_x * normalized_score).sum(dim=-2)
        return new_x
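A quick shape walkthrough of the learned 2x pooling above, with illustrative sizes (a sketch, not part of the patch):

import torch

pooler = NormalizedDwPooler(dim=8)
feat = torch.randn(2, 16, 16, 8)        # (B, H, W, C)
out = pooler(feat, forward_type='2x')
print(out.shape)                        # torch.Size([2, 8, 8, 8]): each 2x2 window is
                                        # reduced with learned, softmax-normalized weights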
20
opencompass/models/ola/model/multimodal_resampler/builder.py
Normal file
@ -0,0 +1,20 @@
import torch


class IdentityMap(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_resampler_type": None}


def build_vision_resampler(model_args, delay_load=False, **kwargs):
    # import pdb;pdb.set_trace()
    resampler_type = getattr(model_args, 'mm_resampler_type', None)
    if resampler_type is None:
        return IdentityMap()
    else:
        raise ValueError(f'Unknown resampler type: {resampler_type}')
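A small usage note, as a sketch under the builder's default (no `mm_resampler_type` configured): the resampler is then a pass-through.

from types import SimpleNamespace
import torch

resampler = build_vision_resampler(SimpleNamespace())   # mm_resampler_type missing -> IdentityMap
x = torch.randn(1, 4, 8)
assert torch.equal(resampler(x), x)                      # features pass through unchanged
print(resampler.config)                                  # {'mm_resampler_type': None}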
397
opencompass/models/ola/model/ola_arch.py
Normal file
@ -0,0 +1,397 @@
from abc import ABC, abstractmethod

import torch

from .speech_encoder.builder import build_speech_encoder
from .speech_projector.builder import build_speech_projector
from opencompass.models.ola.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX
from opencompass.models.ola.utils import lengths_to_padding_mask

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_resampler.builder import build_vision_resampler
from .multimodal_projector.builder import build_vision_projector

from opencompass.models.ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class OlaMetaModel:

    def __init__(self, config):
        super(OlaMetaModel, self).__init__(config)
        if hasattr(config, "speech_encoder"):
            self.speech_encoder = build_speech_encoder(config)
            self.speech_projector = build_speech_projector(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

    def get_speech_encoder(self):
        speech_encoder = getattr(self, 'speech_encoder', None)
        if type(speech_encoder) is list:
            speech_encoder = speech_encoder[0]
        return speech_encoder

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower
||||||
|
|
||||||
|
def initialize_speech_modules(self, model_args, fsdp=None):
|
||||||
|
self.config.speech_encoder = getattr(model_args, "speech_encoder", None)
|
||||||
|
self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None)
|
||||||
|
self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear')
|
||||||
|
self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5)
|
||||||
|
self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280)
|
||||||
|
self.config.music_encoder = getattr(model_args, 'music_encoder', None)
|
||||||
|
|
||||||
|
if self.get_speech_encoder() is None:
|
||||||
|
speech_encoder = build_speech_encoder(self.config)
|
||||||
|
if fsdp is not None and len(fsdp) > 0:
|
||||||
|
self.speech_encoder = [speech_encoder]
|
||||||
|
else:
|
||||||
|
self.speech_encoder = speech_encoder
|
||||||
|
|
||||||
|
if getattr(self, 'speech_projector', None) is None:
|
||||||
|
self.speech_projector = build_speech_projector(self.config)
|
||||||
|
else:
|
||||||
|
# In case it is frozen by LoRA
|
||||||
|
for p in self.speech_projector.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
if model_args.pretrain_speech_projector is not None:
|
||||||
|
pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu')
|
||||||
|
def get_w(weights, keyword):
|
||||||
|
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
||||||
|
print('Loading pretrain speech projector weights')
|
||||||
|
|
||||||
|
msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False)
|
||||||
|
print(msg)
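To make the `get_w` helper above concrete, a small illustration (key names and values are made up): it filters a full-model checkpoint down to one sub-module and strips the prefix so the keys match that module's own state_dict.

weights = {
    'model.speech_projector.linear.weight': 'W',
    'model.speech_projector.linear.bias': 'b',
    'model.embed_tokens.weight': 'E',
}
sub = {k.split('speech_projector.')[1]: v
       for k, v in weights.items() if 'speech_projector' in k}
# sub == {'linear.weight': 'W', 'linear.bias': 'b'}, which load_state_dict(..., strict=False) accepts.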
|
||||||
|
|
||||||
|
def initialize_vision_modules(self, model_args, fsdp=None):
|
||||||
|
vision_tower = model_args.vision_tower
|
||||||
|
mm_vision_select_layer = model_args.mm_vision_select_layer
|
||||||
|
mm_vision_select_feature = model_args.mm_vision_select_feature
|
||||||
|
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
||||||
|
|
||||||
|
self.config.mm_vision_tower = vision_tower
|
||||||
|
|
||||||
|
if self.get_vision_tower() is None:
|
||||||
|
vision_tower = build_vision_tower(model_args)
|
||||||
|
vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
|
||||||
|
## Get the mm_spatial_pool_mode and mm_spatial_pool_stride
|
||||||
|
for k, v in vision_resampler.config.items():
|
||||||
|
setattr(self.config, k, v)
|
||||||
|
|
||||||
|
if fsdp is not None and len(fsdp) > 0:
|
||||||
|
self.vision_tower = [vision_tower]
|
||||||
|
self.vision_resampler = [vision_resampler]
|
||||||
|
else:
|
||||||
|
self.vision_tower = vision_tower
|
||||||
|
self.vision_resampler = vision_resampler
|
||||||
|
else:
|
||||||
|
if fsdp is not None and len(fsdp) > 0:
|
||||||
|
vision_resampler = self.vision_resampler[0]
|
||||||
|
vision_tower = self.vision_tower[0]
|
||||||
|
else:
|
||||||
|
vision_resampler = self.vision_resampler
|
||||||
|
vision_tower = self.vision_tower
|
||||||
|
vision_tower.load_model()  # weights already loaded, so this is skipped
|
||||||
|
|
||||||
|
# In case it is frozen by LoRA
|
||||||
|
for p in self.vision_resampler.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
self.config.use_mm_proj = True
|
||||||
|
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
||||||
|
self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size)
|
||||||
|
|
||||||
|
self.config.mm_vision_select_layer = mm_vision_select_layer
|
||||||
|
self.config.mm_vision_select_feature = mm_vision_select_feature
|
||||||
|
|
||||||
|
if getattr(self, 'mm_projector', None) is None:
|
||||||
|
self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
|
||||||
|
else:
|
||||||
|
for p in self.mm_projector.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
if pretrain_mm_mlp_adapter is not None:
|
||||||
|
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
||||||
|
def get_w(weights, keyword):
|
||||||
|
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
||||||
|
|
||||||
|
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
||||||
|
print('Loading pretrain mm projector weights')
|
||||||
|
incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False)
|
||||||
|
print(incompatible_keys)
|
||||||
|
|
||||||
|
class OlaMetaForCausalLM(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_model(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_speech_encoder(self):
|
||||||
|
return self.get_model().get_speech_encoder()
|
||||||
|
|
||||||
|
def get_vision_tower(self):
|
||||||
|
return self.get_model().get_vision_tower()
|
||||||
|
|
||||||
|
def get_speech_projector(self):
|
||||||
|
return self.get_model().speech_projector
|
||||||
|
|
||||||
|
def encode_speech(self, speech, speech_lengths, speech_wav):
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
|
speech_encoder_type = self.config.speech_encoder_type
|
||||||
|
speech_encoder = self.get_speech_encoder()
|
||||||
|
if "whisper" in speech_encoder_type.lower():
|
||||||
|
encoder_outs = speech_encoder(speech.permute(0, 2, 1))
|
||||||
|
speech_lengths = (speech_lengths + 1) // 2
|
||||||
|
else:
|
||||||
|
encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=speech_wav)
|
||||||
|
speech_lengths = (speech_lengths + 1) // 2
|
||||||
|
speech_projector_type = self.config.speech_projector_type
|
||||||
|
speech_projector = self.get_speech_projector()
|
||||||
|
if speech_projector_type == "linear":
|
||||||
|
encoder_outs = speech_projector(encoder_outs)
|
||||||
|
speech_lengths = speech_lengths // speech_projector.k
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown speech projector: {speech_projector_type}')
|
||||||
|
# speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))]
|
||||||
|
return encoder_outs
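The length bookkeeping above works out as follows (a sketch with illustrative numbers; the stacking factor `speech_projector.k` is assumed to match the default speech_encoder_ds_rate of 5 used in initialize_speech_modules):

frames = 3000                          # mel frames fed to the speech encoder
after_encoder = (frames + 1) // 2      # 1500: the encoder halves the frame count
k = 5                                  # assumed speech_projector.k
after_projector = after_encoder // k   # 300 speech tokens handed to the LLM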
|
||||||
|
|
||||||
|
def prepare_inputs_labels_for_speech_vision_text(
|
||||||
|
self, input_ids, position_ids, attention_mask, past_key_values, labels,
|
||||||
|
speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, images_highres=None
|
||||||
|
):
|
||||||
|
speech_encoder = self.get_speech_encoder()
|
||||||
|
vision_tower = self.get_vision_tower()
|
||||||
|
|
||||||
|
if speech_encoder is None or input_ids.shape[1] == 1:
|
||||||
|
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
||||||
|
|
||||||
|
if vision_tower is None or input_ids.shape[1] == 1:
|
||||||
|
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
||||||
|
# encode speech
|
||||||
|
if not isinstance(speech, list):
|
||||||
|
speech = torch.split(speech, speech_chunks.tolist(), dim=0)
|
||||||
|
speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0)
|
||||||
|
speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0)
|
||||||
|
speech_features = []
|
||||||
|
for idx in range(len(speech)):
|
||||||
|
speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx]))
|
||||||
|
|
||||||
|
# encode vision
|
||||||
|
if isinstance(modalities, str):
|
||||||
|
modalities = [modalities]
|
||||||
|
|
||||||
|
video_idx_in_batch = []
|
||||||
|
for modal in range(len(modalities)):
|
||||||
|
if 'video' in modalities[modal]:
|
||||||
|
video_idx_in_batch.append(modal)
|
||||||
|
|
||||||
|
aimg = images[-1]
|
||||||
|
lowres_img = []
|
||||||
|
for idx, img_feat in enumerate(images):
|
||||||
|
if idx in video_idx_in_batch:
|
||||||
|
img_feat = aimg.new(1, 3, 128, 128).fill_(0)
|
||||||
|
lowres_img.append(img_feat)
|
||||||
|
|
||||||
|
lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img)
|
||||||
|
highres_img_features = []
|
||||||
|
highres_img_sizes = []
|
||||||
|
for idx, img_feat in enumerate(images_highres):
|
||||||
|
if img_feat.ndim == 5:
|
||||||
|
img_feat = img_feat.squeeze(1)
|
||||||
|
highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat)
|
||||||
|
highres_img_features.append(highres_img_feature)
|
||||||
|
highres_img_sizes.append(highres_img_size)
|
||||||
|
image_features = []
|
||||||
|
for idx in range(len(modalities)):
|
||||||
|
img_feat = self.get_model().mm_projector(lowres_img_features[idx],
|
||||||
|
lowres_img_sizes[idx],
|
||||||
|
highres_img_features[idx],
|
||||||
|
highres_img_sizes[idx],
|
||||||
|
modalities[idx])
|
||||||
|
image_features.append(img_feat.flatten(0, 1))
|
||||||
|
|
||||||
|
_labels = labels
|
||||||
|
_position_ids = position_ids
|
||||||
|
_attention_mask = attention_mask
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask.bool()
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
||||||
|
if labels is None:
|
||||||
|
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
||||||
|
|
||||||
|
# remove the padding using attention_mask -- FIXME
|
||||||
|
_input_ids = input_ids
|
||||||
|
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
||||||
|
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
||||||
|
|
||||||
|
new_input_embeds = []
|
||||||
|
new_labels = []
|
||||||
|
cur_speech_idx = 0
|
||||||
|
cur_image_idx = 0
|
||||||
|
for batch_idx, cur_input_ids in enumerate(input_ids):
|
||||||
|
|
||||||
|
num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
|
||||||
|
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
||||||
|
|
||||||
|
num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
|
||||||
|
|
||||||
|
if num_speech_images == 0:
|
||||||
|
cur_speech_features = speech_features[cur_speech_idx]
|
||||||
|
cur_images_features = image_features[cur_image_idx]
|
||||||
|
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
||||||
|
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0)
|
||||||
|
new_input_embeds.append(cur_input_embeds)
|
||||||
|
new_labels.append(labels[batch_idx])
|
||||||
|
cur_speech_idx += 1
|
||||||
|
cur_image_idx += 1
|
||||||
|
continue
|
||||||
|
speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]]
|
||||||
|
|
||||||
|
cur_input_ids_nospeech_image = []
|
||||||
|
cur_labels = labels[batch_idx]
|
||||||
|
cur_labels_nospeech_image = []
|
||||||
|
for i in range(len(speech_image_token_indices) - 1):
|
||||||
|
cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
|
||||||
|
cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
|
||||||
|
split_sizes = [x.shape[0] for x in cur_labels_nospeech_image]
|
||||||
|
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image))
|
||||||
|
cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0)
|
||||||
|
cur_new_input_embeds = []
|
||||||
|
cur_new_labels = []
|
||||||
|
|
||||||
|
for i in range(num_speech_images + 1):
|
||||||
|
cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i])
|
||||||
|
cur_new_labels.append(cur_labels_nospeech_image[i])
|
||||||
|
if i < num_speech_images:
|
||||||
|
if i < num_images:
|
||||||
|
cur_images_features = image_features[cur_image_idx]
|
||||||
|
cur_image_idx += 1
|
||||||
|
cur_new_input_embeds.append(cur_images_features)
|
||||||
|
cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
||||||
|
else:
|
||||||
|
cur_speech_features = speech_features[cur_speech_idx]
|
||||||
|
cur_speech_idx += 1
|
||||||
|
cur_new_input_embeds.append(cur_speech_features)
|
||||||
|
cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
||||||
|
|
||||||
|
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
||||||
|
|
||||||
|
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
||||||
|
cur_new_labels = torch.cat(cur_new_labels)
|
||||||
|
|
||||||
|
if num_images == 0:
|
||||||
|
cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0)
|
||||||
|
cur_image_idx += 1
|
||||||
|
|
||||||
|
if num_speech == 0:
|
||||||
|
cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0)
|
||||||
|
cur_speech_idx += 1
|
||||||
|
|
||||||
|
new_input_embeds.append(cur_new_input_embeds)
|
||||||
|
new_labels.append(cur_new_labels)
|
||||||
|
|
||||||
|
# Truncate sequences to max length as speech features can make the sequence longer
|
||||||
|
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
||||||
|
if tokenizer_model_max_length is not None:
|
||||||
|
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
||||||
|
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
||||||
|
|
||||||
|
# Combine them
|
||||||
|
max_len = max(x.shape[0] for x in new_input_embeds)
|
||||||
|
batch_size = len(new_input_embeds)
|
||||||
|
|
||||||
|
new_input_embeds_padded = []
|
||||||
|
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
||||||
|
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
||||||
|
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
||||||
|
|
||||||
|
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
||||||
|
cur_len = cur_new_embed.shape[0]
|
||||||
|
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
||||||
|
new_input_embeds_padded.append(torch.cat((
|
||||||
|
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
|
||||||
|
cur_new_embed
|
||||||
|
), dim=0))
|
||||||
|
if cur_len > 0:
|
||||||
|
new_labels_padded[i, -cur_len:] = cur_new_labels
|
||||||
|
attention_mask[i, -cur_len:] = True
|
||||||
|
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
||||||
|
else:
|
||||||
|
new_input_embeds_padded.append(torch.cat((
|
||||||
|
cur_new_embed,
|
||||||
|
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
|
||||||
|
), dim=0))
|
||||||
|
if cur_len > 0:
|
||||||
|
new_labels_padded[i, :cur_len] = cur_new_labels
|
||||||
|
attention_mask[i, :cur_len] = True
|
||||||
|
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
||||||
|
|
||||||
|
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
||||||
|
|
||||||
|
if _labels is None:
|
||||||
|
new_labels = None
|
||||||
|
else:
|
||||||
|
new_labels = new_labels_padded
|
||||||
|
|
||||||
|
if _attention_mask is None:
|
||||||
|
attention_mask = None
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
||||||
|
|
||||||
|
if _position_ids is None:
|
||||||
|
position_ids = None
|
||||||
|
|
||||||
|
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
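A toy illustration of the splicing performed above (constants and feature lengths are stand-ins, not the real values): placeholder token ids are replaced by the corresponding feature sequences, and those positions are masked out of the labels.

IMAGE, SPEECH, IGNORE = -200, -300, -100       # stand-ins for the real constants
ids = [1, 5, IMAGE, 7, SPEECH, 2]
feats = {IMAGE: ['img'] * 3, SPEECH: ['sp'] * 2}

spliced, labels = [], []
for tok in ids:
    if tok in feats:
        spliced += feats[tok]
        labels += [IGNORE] * len(feats[tok])
    else:
        spliced.append(tok)
        labels.append(tok)
# spliced: [1, 5, 'img', 'img', 'img', 7, 'sp', 'sp', 2]
# labels:  [1, 5, -100, -100, -100, 7, -100, -100, 2]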
|
||||||
|
|
||||||
|
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
||||||
|
if model_args.mm_use_im_patch_token:
|
||||||
|
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
||||||
|
self.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
if model_args.mm_use_im_start_end:
|
||||||
|
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||||
|
self.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
if num_new_tokens > 0:
|
||||||
|
input_embeddings = self.get_input_embeddings().weight.data
|
||||||
|
output_embeddings = self.get_output_embeddings().weight.data
|
||||||
|
|
||||||
|
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
||||||
|
dim=0, keepdim=True)
|
||||||
|
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
||||||
|
dim=0, keepdim=True)
|
||||||
|
|
||||||
|
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||||
|
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
||||||
|
|
||||||
|
if model_args.tune_mm_mlp_adapter:
|
||||||
|
for p in self.get_input_embeddings().parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
for p in self.get_output_embeddings().parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
if model_args.pretrain_mm_mlp_adapter:
|
||||||
|
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
||||||
|
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
||||||
|
assert num_new_tokens == 2
|
||||||
|
if input_embeddings.shape == embed_tokens_weight.shape:
|
||||||
|
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
||||||
|
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
||||||
|
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
||||||
|
elif model_args.mm_use_im_patch_token:
|
||||||
|
if model_args.tune_mm_mlp_adapter:
|
||||||
|
for p in self.get_input_embeddings().parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
for p in self.get_output_embeddings().parameters():
|
||||||
|
p.requires_grad = False
|
182
opencompass/models/ola/model/speech_encoder/beats/BEATs.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
# import torchaudio.compliance.kaldi as ta_kaldi
|
||||||
|
|
||||||
|
from .kaldi import fbank as kaldi_fbank
|
||||||
|
|
||||||
|
from .backbone import (
|
||||||
|
TransformerEncoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BEATsConfig:
|
||||||
|
def __init__(self, cfg=None):
|
||||||
|
self.input_patch_size: int = -1  # patch size of patch embedding
|
||||||
|
self.embed_dim: int = 512 # patch embedding dimension
|
||||||
|
self.conv_bias: bool = False # include bias in conv encoder
|
||||||
|
|
||||||
|
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||||
|
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||||
|
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||||
|
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||||
|
self.activation_fn: str = "gelu" # activation function to use
|
||||||
|
|
||||||
|
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
|
||||||
|
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||||
|
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
||||||
|
|
||||||
|
# dropouts
|
||||||
|
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||||
|
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||||
|
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||||
|
self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
|
||||||
|
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||||
|
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||||
|
|
||||||
|
# relative position embedding
|
||||||
|
self.relative_position_embedding: bool = False # apply relative position embedding
|
||||||
|
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||||
|
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||||
|
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
||||||
|
|
||||||
|
# label predictor
|
||||||
|
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
|
||||||
|
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
|
||||||
|
self.predictor_class: int = 527 # target class number for the predictor
|
||||||
|
|
||||||
|
if cfg is not None:
|
||||||
|
self.update(cfg)
|
||||||
|
|
||||||
|
def update(self, cfg: dict):
|
||||||
|
self.__dict__.update(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class BEATs(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: BEATsConfig,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
logger.info(f"BEATs Config: {cfg.__dict__}")
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
self.embed = cfg.embed_dim
|
||||||
|
self.post_extract_proj = (
|
||||||
|
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||||
|
if self.embed != cfg.encoder_embed_dim
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_patch_size = cfg.input_patch_size
|
||||||
|
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
||||||
|
bias=cfg.conv_bias)
|
||||||
|
|
||||||
|
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||||
|
|
||||||
|
assert not cfg.deep_norm or not cfg.layer_norm_first
|
||||||
|
self.encoder = TransformerEncoder(cfg)
|
||||||
|
self.layer_norm = LayerNorm(self.embed)
|
||||||
|
|
||||||
|
if cfg.finetuned_model:
|
||||||
|
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
|
||||||
|
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
|
||||||
|
else:
|
||||||
|
self.predictor = None
|
||||||
|
|
||||||
|
def forward_padding_mask(
|
||||||
|
self,
|
||||||
|
features: torch.Tensor,
|
||||||
|
padding_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
extra = padding_mask.size(1) % features.size(1)
|
||||||
|
if extra > 0:
|
||||||
|
padding_mask = padding_mask[:, :-extra]
|
||||||
|
padding_mask = padding_mask.view(
|
||||||
|
padding_mask.size(0), features.size(1), -1
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask.all(-1)
|
||||||
|
return padding_mask
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
fbanks = []
|
||||||
|
for waveform in source:
|
||||||
|
waveform = waveform.unsqueeze(0) * 2 ** 15
|
||||||
|
fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
||||||
|
fbanks.append(fbank)
|
||||||
|
fbank = torch.stack(fbanks, dim=0)
|
||||||
|
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
||||||
|
return fbank
|
||||||
|
|
||||||
|
def extract_features(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
feature_only=False,
|
||||||
|
):
|
||||||
|
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
||||||
|
|
||||||
|
fbank = fbank.unsqueeze(1)
|
||||||
|
features = self.patch_embedding(fbank)
|
||||||
|
features = features.reshape(features.shape[0], features.shape[1], -1)
|
||||||
|
features = features.transpose(1, 2)
|
||||||
|
features = self.layer_norm(features)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||||
|
|
||||||
|
if self.post_extract_proj is not None:
|
||||||
|
features = self.post_extract_proj(features)
|
||||||
|
|
||||||
|
x = self.dropout_input(features)
|
||||||
|
|
||||||
|
x, layer_results = self.encoder(
|
||||||
|
x,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not feature_only and self.predictor is not None:
|
||||||
|
x = self.predictor_dropout(x)
|
||||||
|
logits = self.predictor(x)
|
||||||
|
|
||||||
|
if padding_mask is not None and padding_mask.any():
|
||||||
|
logits[padding_mask] = 0
|
||||||
|
logits = logits.sum(dim=1)
|
||||||
|
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
|
||||||
|
else:
|
||||||
|
logits = logits.mean(dim=1)
|
||||||
|
|
||||||
|
lprobs = torch.sigmoid(logits)
|
||||||
|
|
||||||
|
return lprobs, padding_mask
|
||||||
|
else:
|
||||||
|
return x, padding_mask
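A hedged usage sketch for pure feature extraction (the config values below are assumptions for illustration, not taken from a released BEATs checkpoint):

import torch

cfg = BEATsConfig({'input_patch_size': 16, 'embed_dim': 512,
                   'encoder_layers': 2, 'encoder_embed_dim': 768})
model = BEATs(cfg).eval()
wav = torch.randn(1, 16000)                      # one second of 16 kHz audio
with torch.no_grad():
    feats, mask = model.extract_features(wav, feature_only=True)
print(feats.shape)                               # (1, num_patches, 768)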
|
174
opencompass/models/ola/model/speech_encoder/beats/Tokenizers.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
# import torchaudio.compliance.kaldi as ta_kaldi
|
||||||
|
|
||||||
|
from .kaldi import fbank as kaldi_fbank
|
||||||
|
|
||||||
|
from .backbone import (
|
||||||
|
TransformerEncoder,
|
||||||
|
)
|
||||||
|
from .quantizer import (
|
||||||
|
NormEMAVectorQuantizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizersConfig:
|
||||||
|
def __init__(self, cfg=None):
|
||||||
|
self.input_patch_size: int = -1  # patch size of patch embedding
|
||||||
|
self.embed_dim: int = 512 # patch embedding dimension
|
||||||
|
self.conv_bias: bool = False # include bias in conv encoder
|
||||||
|
|
||||||
|
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||||
|
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||||
|
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||||
|
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||||
|
self.activation_fn: str = "gelu" # activation function to use
|
||||||
|
|
||||||
|
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||||
|
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
||||||
|
|
||||||
|
# dropouts
|
||||||
|
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||||
|
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||||
|
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||||
|
self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
|
||||||
|
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||||
|
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||||
|
|
||||||
|
# relative position embedding
|
||||||
|
self.relative_position_embedding: bool = False # apply relative position embedding
|
||||||
|
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||||
|
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||||
|
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
||||||
|
|
||||||
|
# quantizer
|
||||||
|
self.quant_n: int = 1024 # codebook number in quantizer
|
||||||
|
self.quant_dim: int = 256 # codebook dimension in quantizer
|
||||||
|
|
||||||
|
if cfg is not None:
|
||||||
|
self.update(cfg)
|
||||||
|
|
||||||
|
def update(self, cfg: dict):
|
||||||
|
self.__dict__.update(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizers(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: TokenizersConfig,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
logger.info(f"Tokenizers Config: {cfg.__dict__}")
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
self.embed = cfg.embed_dim
|
||||||
|
self.post_extract_proj = (
|
||||||
|
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||||
|
if self.embed != cfg.encoder_embed_dim
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_patch_size = cfg.input_patch_size
|
||||||
|
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
||||||
|
bias=cfg.conv_bias)
|
||||||
|
|
||||||
|
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||||
|
|
||||||
|
assert not cfg.deep_norm or not cfg.layer_norm_first
|
||||||
|
self.encoder = TransformerEncoder(cfg)
|
||||||
|
self.layer_norm = LayerNorm(self.embed)
|
||||||
|
|
||||||
|
self.quantize = NormEMAVectorQuantizer(
|
||||||
|
n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
|
||||||
|
)
|
||||||
|
self.quant_n = cfg.quant_n
|
||||||
|
self.quantize_layer = nn.Sequential(
|
||||||
|
nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_padding_mask(
|
||||||
|
self,
|
||||||
|
features: torch.Tensor,
|
||||||
|
padding_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
extra = padding_mask.size(1) % features.size(1)
|
||||||
|
if extra > 0:
|
||||||
|
padding_mask = padding_mask[:, :-extra]
|
||||||
|
padding_mask = padding_mask.view(
|
||||||
|
padding_mask.size(0), features.size(1), -1
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask.all(-1)
|
||||||
|
return padding_mask
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
fbanks = []
|
||||||
|
for waveform in source:
|
||||||
|
waveform = waveform.unsqueeze(0) * 2 ** 15
|
||||||
|
fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
||||||
|
fbanks.append(fbank)
|
||||||
|
fbank = torch.stack(fbanks, dim=0)
|
||||||
|
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
||||||
|
return fbank
|
||||||
|
|
||||||
|
def extract_labels(
|
||||||
|
self,
|
||||||
|
source: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
fbank_mean: float = 15.41663,
|
||||||
|
fbank_std: float = 6.55582,
|
||||||
|
):
|
||||||
|
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
||||||
|
|
||||||
|
fbank = fbank.unsqueeze(1)
|
||||||
|
features = self.patch_embedding(fbank)
|
||||||
|
features = features.reshape(features.shape[0], features.shape[1], -1)
|
||||||
|
features = features.transpose(1, 2)
|
||||||
|
features = self.layer_norm(features)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||||
|
|
||||||
|
if self.post_extract_proj is not None:
|
||||||
|
features = self.post_extract_proj(features)
|
||||||
|
|
||||||
|
x = self.dropout_input(features)
|
||||||
|
|
||||||
|
x, layer_results = self.encoder(
|
||||||
|
x,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
quantize_input = self.quantize_layer(x)
|
||||||
|
quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
|
||||||
|
|
||||||
|
return embed_ind
|
782
opencompass/models/ola/model/speech_encoder/beats/backbone.py
Normal file
@ -0,0 +1,782 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import LayerNorm, Parameter
|
||||||
|
from .modules import (
|
||||||
|
GradMultiply,
|
||||||
|
SamePad,
|
||||||
|
get_activation_fn,
|
||||||
|
GLU_Linear,
|
||||||
|
quant_noise,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dropout = args.dropout
|
||||||
|
self.embedding_dim = args.encoder_embed_dim
|
||||||
|
|
||||||
|
self.pos_conv = nn.Conv1d(
|
||||||
|
self.embedding_dim,
|
||||||
|
self.embedding_dim,
|
||||||
|
kernel_size=args.conv_pos,
|
||||||
|
padding=args.conv_pos // 2,
|
||||||
|
groups=args.conv_pos_groups,
|
||||||
|
)
|
||||||
|
dropout = 0
|
||||||
|
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
||||||
|
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||||
|
nn.init.constant_(self.pos_conv.bias, 0)
|
||||||
|
|
||||||
|
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||||
|
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||||
|
|
||||||
|
if hasattr(args, "relative_position_embedding"):
|
||||||
|
self.relative_position_embedding = args.relative_position_embedding
|
||||||
|
self.num_buckets = args.num_buckets
|
||||||
|
self.max_distance = args.max_distance
|
||||||
|
else:
|
||||||
|
self.relative_position_embedding = False
|
||||||
|
self.num_buckets = 0
|
||||||
|
self.max_distance = 0
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TransformerSentenceEncoderLayer(
|
||||||
|
embedding_dim=self.embedding_dim,
|
||||||
|
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||||
|
num_attention_heads=args.encoder_attention_heads,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=args.attention_dropout,
|
||||||
|
activation_dropout=args.activation_dropout,
|
||||||
|
activation_fn=args.activation_fn,
|
||||||
|
layer_norm_first=args.layer_norm_first,
|
||||||
|
deep_norm=args.deep_norm,
|
||||||
|
has_relative_attention_bias=self.relative_position_embedding,
|
||||||
|
num_buckets=self.num_buckets,
|
||||||
|
max_distance=self.max_distance,
|
||||||
|
gru_rel_pos=args.gru_rel_pos,
|
||||||
|
encoder_layers=args.encoder_layers,
|
||||||
|
)
|
||||||
|
for i in range(args.encoder_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.relative_position_embedding:
|
||||||
|
for i in range(1, args.encoder_layers):
|
||||||
|
del self.layers[i].self_attn.relative_attention_bias
|
||||||
|
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
|
||||||
|
|
||||||
|
self.layer_norm_first = args.layer_norm_first
|
||||||
|
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
self.layerdrop = args.encoder_layerdrop
|
||||||
|
|
||||||
|
self.apply(init_bert_params)
|
||||||
|
|
||||||
|
if args.deep_norm:
|
||||||
|
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
|
||||||
|
for i in range(args.encoder_layers):
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
|
||||||
|
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
|
||||||
|
|
||||||
|
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
|
||||||
|
|
||||||
|
def forward(self, x, padding_mask=None, layer=None):
|
||||||
|
x, layer_results = self.extract_features(x, padding_mask, layer)
|
||||||
|
|
||||||
|
if self.layer_norm_first and layer is None:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
return x, layer_results
|
||||||
|
|
||||||
|
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
x[padding_mask] = 0
|
||||||
|
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||||
|
x_conv = x_conv.transpose(1, 2)
|
||||||
|
x = x + x_conv
|
||||||
|
|
||||||
|
if not self.layer_norm_first:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
# B x T x C -> T x B x C
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
layer_results = []
|
||||||
|
z = None
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
r = None
|
||||||
|
pos_bias = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
if self.layer_wise_gradient_decay_ratio != 1.0:
|
||||||
|
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
|
||||||
|
dropout_probability = np.random.random()
|
||||||
|
if not self.training or (dropout_probability > self.layerdrop):
|
||||||
|
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
if i == tgt_layer:
|
||||||
|
r = x
|
||||||
|
break
|
||||||
|
|
||||||
|
if r is not None:
|
||||||
|
x = r
|
||||||
|
|
||||||
|
# T x B x C -> B x T x C
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
return x, layer_results
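The layer-wise gradient decay above relies on GradMultiply from .modules; a minimal stand-in with the assumed behaviour (identity in the forward pass, gradient scaled on the way back) looks like this:

import torch

class GradMultiplySketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.clone()          # value unchanged in the forward pass

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None

x = torch.ones(3, requires_grad=True)
y = GradMultiplySketch.apply(x, 0.5).sum()
y.backward()
print(x.grad)                     # tensor([0.5000, 0.5000, 0.5000])

Applying this before every layer with a ratio below 1 means the gradient reaching earlier layers is scaled down geometrically with depth.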
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerSentenceEncoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: float = 768,
|
||||||
|
ffn_embedding_dim: float = 3072,
|
||||||
|
num_attention_heads: float = 8,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
attention_dropout: float = 0.1,
|
||||||
|
activation_dropout: float = 0.1,
|
||||||
|
activation_fn: str = "relu",
|
||||||
|
layer_norm_first: bool = False,
|
||||||
|
deep_norm: bool = False,
|
||||||
|
has_relative_attention_bias: bool = False,
|
||||||
|
num_buckets: int = 0,
|
||||||
|
max_distance: int = 0,
|
||||||
|
rescale_init: bool = False,
|
||||||
|
gru_rel_pos: bool = False,
|
||||||
|
encoder_layers: int = 0,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
|
||||||
|
self.activation_name = activation_fn
|
||||||
|
self.activation_fn = get_activation_fn(activation_fn)
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
self.embedding_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
dropout=attention_dropout,
|
||||||
|
self_attention=True,
|
||||||
|
has_relative_attention_bias=has_relative_attention_bias,
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
max_distance=max_distance,
|
||||||
|
rescale_init=rescale_init,
|
||||||
|
gru_rel_pos=gru_rel_pos,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||||
|
self.dropout3 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.layer_norm_first = layer_norm_first
|
||||||
|
|
||||||
|
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
||||||
|
else:
|
||||||
|
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||||
|
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||||
|
|
||||||
|
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
self.deep_norm = deep_norm
|
||||||
|
if self.deep_norm:
|
||||||
|
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
|
||||||
|
else:
|
||||||
|
self.deep_norm_alpha = 1
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
self_attn_mask: torch.Tensor = None,
|
||||||
|
self_attn_padding_mask: torch.Tensor = None,
|
||||||
|
need_weights: bool = False,
|
||||||
|
pos_bias=None
|
||||||
|
):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
if self.layer_norm_first:
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=False,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual + x
|
||||||
|
else:
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual * self.deep_norm_alpha + x
|
||||||
|
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual * self.deep_norm_alpha + x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
|
||||||
|
return x, attn, pos_bias
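The DeepNorm branch above scales the residual path by a depth-dependent constant; a one-line check of the formula set in __init__ (the layer count is illustrative):

import math

N = 12                                 # encoder_layers, illustrative
alpha = math.pow(2 * N, 1 / 4)         # ~2.213; x = residual * alpha + sublayer_out, then LayerNorm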
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Module):
|
||||||
|
"""Multi-headed attention.
|
||||||
|
|
||||||
|
See "Attention Is All You Need" for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
kdim=None,
|
||||||
|
vdim=None,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
|
self_attention=False,
|
||||||
|
encoder_decoder_attention=False,
|
||||||
|
q_noise=0.0,
|
||||||
|
qn_block_size=8,
|
||||||
|
has_relative_attention_bias=False,
|
||||||
|
num_buckets=32,
|
||||||
|
max_distance=128,
|
||||||
|
gru_rel_pos=False,
|
||||||
|
rescale_init=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout_module = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.has_relative_attention_bias = has_relative_attention_bias
|
||||||
|
self.num_buckets = num_buckets
|
||||||
|
self.max_distance = max_distance
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
||||||
|
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.q_head_dim = self.head_dim
|
||||||
|
self.k_head_dim = self.head_dim
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.self_attention = self_attention
|
||||||
|
self.encoder_decoder_attention = encoder_decoder_attention
|
||||||
|
|
||||||
|
assert not self.self_attention or self.qkv_same_dim, (
|
||||||
|
"Self-attention requires query, key and " "value to be of the same size"
|
||||||
|
)
|
||||||
|
|
||||||
|
k_bias = True
|
||||||
|
if rescale_init:
|
||||||
|
k_bias = False
|
||||||
|
|
||||||
|
k_embed_dim = embed_dim
|
||||||
|
q_embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.k_proj = quant_noise(
|
||||||
|
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.v_proj = quant_noise(
|
||||||
|
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.q_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||||
|
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
self.gru_rel_pos = gru_rel_pos
|
||||||
|
if self.gru_rel_pos:
|
||||||
|
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
||||||
|
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
if self.qkv_same_dim:
|
||||||
|
# Empirically observed the convergence to be much better with
|
||||||
|
# the scaled initialization
|
||||||
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
else:
|
||||||
|
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||||
|
|
||||||
|
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||||
|
if self.out_proj.bias is not None:
|
||||||
|
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||||
|
if self.bias_k is not None:
|
||||||
|
nn.init.xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
nn.init.xavier_normal_(self.bias_v)
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
||||||
|
|
||||||
|
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
||||||
|
num_buckets = self.num_buckets
|
||||||
|
max_distance = self.max_distance
|
||||||
|
relative_buckets = 0
|
||||||
|
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets = num_buckets // 2
|
||||||
|
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
||||||
|
relative_positions = torch.abs(relative_positions)
|
||||||
|
else:
|
||||||
|
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
||||||
|
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_positions < max_exact
|
||||||
|
|
||||||
|
relative_postion_if_large = max_exact + (
|
||||||
|
torch.log(relative_positions.float() / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
|
).to(torch.long)
|
||||||
|
relative_postion_if_large = torch.min(
|
||||||
|
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
||||||
|
return relative_buckets
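A standalone copy of the bucketing rule above (parameters shrunk for illustration: 32 buckets, max distance 128) shows how signed offsets map to bucket ids: small offsets keep a bucket of their own, larger ones share log-spaced buckets, and the sign selects the lower or upper half.

import math
import torch

def bucket(rel_pos, num_buckets=32, max_distance=128):
    num_buckets //= 2                                   # bidirectional: split by sign
    out = (rel_pos > 0).long() * num_buckets
    rel_pos = rel_pos.abs()
    max_exact = num_buckets // 2
    large = max_exact + (torch.log(rel_pos.float() / max_exact)
                         / math.log(max_distance / max_exact)
                         * (num_buckets - max_exact)).long()
    large = torch.min(large, torch.full_like(large, num_buckets - 1))
    return out + torch.where(rel_pos < max_exact, rel_pos, large)

print(bucket(torch.tensor([-64, -3, 2, 64])))           # negative offsets land in the lower half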
|
||||||
|
|
||||||
|
def compute_bias(self, query_length, key_length):
|
||||||
|
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
||||||
|
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
||||||
|
relative_position = memory_position - context_position
|
||||||
|
relative_position_bucket = self._relative_positions_bucket(
|
||||||
|
relative_position,
|
||||||
|
bidirectional=True
|
||||||
|
)
|
||||||
|
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
||||||
|
values = self.relative_attention_bias(relative_position_bucket)
|
||||||
|
values = values.permute([2, 0, 1])
|
||||||
|
return values
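
A rough standalone sketch of the bucketing that compute_bias relies on. The bucket count and maximum distance below are hypothetical values chosen for illustration, not the module's actual num_buckets / max_distance configuration; the clamp is an added guard against log(0) for offsets that the "small" branch selects anyway.

import math
import torch

num_buckets, max_distance = 32, 128  # illustrative settings only

def bucket(relative_positions: torch.Tensor) -> torch.Tensor:
    # Bidirectional T5-style bucketing, mirroring _relative_positions_bucket above.
    n = num_buckets // 2
    out = (relative_positions > 0).to(torch.long) * n
    rel = relative_positions.abs()
    max_exact = n // 2
    large = max_exact + (
        torch.log(rel.clamp(min=1).float() / max_exact)
        / math.log(max_distance / max_exact) * (n - max_exact)
    ).to(torch.long)
    large = torch.min(large, torch.full_like(large, n - 1))
    return out + torch.where(rel < max_exact, rel, large)

print(bucket(torch.tensor([-64, -8, -1, 0, 1, 8, 64])))
# nearby offsets keep distinct buckets, distant offsets share log-spaced buckets,
# and positive offsets land in the upper half of the bucket range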
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key: Optional[Tensor],
|
||||||
|
value: Optional[Tensor],
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
static_kv: bool = False,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
before_softmax: bool = False,
|
||||||
|
need_head_weights: bool = False,
|
||||||
|
position_bias: Optional[Tensor] = None
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||||
|
"""Input shape: Time x Batch x Channel
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||||
|
keys that are pads, of shape `(batch, src_len)`, where
|
||||||
|
padding elements are indicated by 1s.
|
||||||
|
need_weights (bool, optional): return the attention weights,
|
||||||
|
averaged over heads (default: False).
|
||||||
|
attn_mask (ByteTensor, optional): typically used to
|
||||||
|
implement causal attention, where the mask prevents the
|
||||||
|
attention from looking forward in time (default: None).
|
||||||
|
before_softmax (bool, optional): return the raw attention
|
||||||
|
weights and values before the attention softmax.
|
||||||
|
need_head_weights (bool, optional): return the attention
|
||||||
|
weights for each head. Implies *need_weights*. Default:
|
||||||
|
return the average attention weights over all heads.
|
||||||
|
"""
|
||||||
|
if need_head_weights:
|
||||||
|
need_weights = True
|
||||||
|
|
||||||
|
is_tpu = query.device.type == "xla"
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
|
src_len = tgt_len
|
||||||
|
assert embed_dim == self.embed_dim
|
||||||
|
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||||
|
if key is not None:
|
||||||
|
src_len, key_bsz, _ = key.size()
|
||||||
|
if not torch.jit.is_scripting():
|
||||||
|
assert key_bsz == bsz
|
||||||
|
assert value is not None
|
||||||
|
assert (src_len, bsz) == value.shape[:2]
|
||||||
|
|
||||||
|
if self.has_relative_attention_bias and position_bias is None:
|
||||||
|
position_bias = self.compute_bias(tgt_len, src_len)
|
||||||
|
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if incremental_state is not None:
|
||||||
|
saved_state = self._get_input_buffer(incremental_state)
|
||||||
|
if saved_state is not None and "prev_key" in saved_state:
|
||||||
|
# previous time steps are cached - no need to recompute
|
||||||
|
# key and value if they are static
|
||||||
|
if static_kv:
|
||||||
|
assert self.encoder_decoder_attention and not self.self_attention
|
||||||
|
key = value = None
|
||||||
|
else:
|
||||||
|
saved_state = None
|
||||||
|
|
||||||
|
if self.self_attention:
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(query)
|
||||||
|
v = self.v_proj(query)
|
||||||
|
elif self.encoder_decoder_attention:
|
||||||
|
# encoder-decoder attention
|
||||||
|
q = self.q_proj(query)
|
||||||
|
if key is None:
|
||||||
|
assert value is None
|
||||||
|
k = v = None
|
||||||
|
else:
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(key)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert key is not None and value is not None
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(value)
|
||||||
|
q *= self.scaling
|
||||||
|
alpha = 32
|
||||||
|
q *= 1 / alpha
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
assert self.bias_v is not None
|
||||||
|
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
||||||
|
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||||
|
)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = torch.cat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
q = (
|
||||||
|
q.contiguous()
|
||||||
|
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if k is not None:
|
||||||
|
k = (
|
||||||
|
k.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if v is not None:
|
||||||
|
v = (
|
||||||
|
v.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if saved_state is not None:
|
||||||
|
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||||
|
if "prev_key" in saved_state:
|
||||||
|
_prev_key = saved_state["prev_key"]
|
||||||
|
assert _prev_key is not None
|
||||||
|
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
k = prev_key
|
||||||
|
else:
|
||||||
|
assert k is not None
|
||||||
|
k = torch.cat([prev_key, k], dim=1)
|
||||||
|
src_len = k.size(1)
|
||||||
|
if "prev_value" in saved_state:
|
||||||
|
_prev_value = saved_state["prev_value"]
|
||||||
|
assert _prev_value is not None
|
||||||
|
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
v = prev_value
|
||||||
|
else:
|
||||||
|
assert v is not None
|
||||||
|
v = torch.cat([prev_value, v], dim=1)
|
||||||
|
prev_key_padding_mask: Optional[Tensor] = None
|
||||||
|
if "prev_key_padding_mask" in saved_state:
|
||||||
|
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||||
|
assert k is not None and v is not None
|
||||||
|
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
prev_key_padding_mask=prev_key_padding_mask,
|
||||||
|
batch_size=bsz,
|
||||||
|
src_len=k.size(1),
|
||||||
|
static_kv=static_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||||
|
# In this branch incremental_state is never None
|
||||||
|
assert incremental_state is not None
|
||||||
|
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||||
|
assert k is not None
|
||||||
|
assert k.size(1) == src_len
|
||||||
|
|
||||||
|
# This is part of a workaround to get around fork/join parallelism
|
||||||
|
# not supporting Optional types.
|
||||||
|
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||||
|
key_padding_mask = None
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.size(0) == bsz
|
||||||
|
assert key_padding_mask.size(1) == src_len
|
||||||
|
|
||||||
|
if self.add_zero_attn:
|
||||||
|
assert v is not None
|
||||||
|
src_len += 1
|
||||||
|
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
||||||
|
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
||||||
|
)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = torch.cat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
||||||
|
key_padding_mask
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
|
||||||
|
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||||
|
|
||||||
|
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_weights += attn_mask
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
# don't attend to padding symbols
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
if not is_tpu:
|
||||||
|
attn_weights = attn_weights.masked_fill(
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if before_softmax:
|
||||||
|
return attn_weights, v, position_bias
|
||||||
|
|
||||||
|
if position_bias is not None:
|
||||||
|
attn_mask_rel_pos = position_bias
|
||||||
|
if self.gru_rel_pos == 1:
|
||||||
|
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
|
||||||
|
_B, _H, _L, __ = query_layer.size()
|
||||||
|
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
||||||
|
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
||||||
|
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||||
|
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
|
||||||
|
|
||||||
|
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attn_mask_rel_pos
|
||||||
|
|
||||||
|
attn_weights_float = F.softmax(
|
||||||
|
attn_weights, dim=-1
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||||
|
attn_probs = self.dropout_module(attn_weights)
|
||||||
|
|
||||||
|
assert v is not None
|
||||||
|
attn = torch.bmm(attn_probs, v)
|
||||||
|
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||||
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||||
|
attn = self.out_proj(attn)
|
||||||
|
attn_weights: Optional[Tensor] = None
|
||||||
|
if need_weights:
|
||||||
|
attn_weights = attn_weights_float.view(
|
||||||
|
bsz, self.num_heads, tgt_len, src_len
|
||||||
|
).transpose(1, 0)
|
||||||
|
if not need_head_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_weights = attn_weights.mean(dim=0)
|
||||||
|
|
||||||
|
return attn, attn_weights, position_bias
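
The alpha = 32 rescaling used in the forward pass above (q divided by alpha before the matmul, the scores multiplied back by alpha after subtracting the row max) is a numerical-stability device; it does not change the softmax output. A tiny check of that identity with made-up scores:

import torch
import torch.nn.functional as F

scores = torch.tensor([[10.0, 12.0, 9.0, 30.0]])  # arbitrary toy attention scores
alpha = 32.0
scaled = scores / alpha
stabilized = (scaled - scaled.max(dim=-1, keepdim=True).values) * alpha  # = scores - row max
print(torch.allclose(F.softmax(scores, dim=-1), F.softmax(stabilized, dim=-1)))  # True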
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_prev_key_padding_mask(
|
||||||
|
key_padding_mask: Optional[Tensor],
|
||||||
|
prev_key_padding_mask: Optional[Tensor],
|
||||||
|
batch_size: int,
|
||||||
|
src_len: int,
|
||||||
|
static_kv: bool,
|
||||||
|
) -> Optional[Tensor]:
|
||||||
|
# saved key padding masks have shape (bsz, seq_len)
|
||||||
|
if prev_key_padding_mask is not None and static_kv:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
||||||
|
)
|
||||||
|
# During incremental decoding, as the padding token enters and
|
||||||
|
# leaves the frame, there will be a time when prev or current
|
||||||
|
# is None
|
||||||
|
elif prev_key_padding_mask is not None:
|
||||||
|
if src_len > prev_key_padding_mask.size(1):
|
||||||
|
filler = torch.zeros(
|
||||||
|
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||||
|
device=prev_key_padding_mask.device,
|
||||||
|
)
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[prev_key_padding_mask.float(), filler.float()], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask.float()
|
||||||
|
elif key_padding_mask is not None:
|
||||||
|
if src_len > key_padding_mask.size(1):
|
||||||
|
filler = torch.zeros(
|
||||||
|
(batch_size, src_len - key_padding_mask.size(1)),
|
||||||
|
device=key_padding_mask.device,
|
||||||
|
)
|
||||||
|
new_key_padding_mask = torch.cat(
|
||||||
|
[filler.float(), key_padding_mask.float()], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = key_padding_mask.float()
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
return new_key_padding_mask
|
||||||
|
|
||||||
|
def _get_input_buffer(
|
||||||
|
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
||||||
|
) -> Dict[str, Optional[Tensor]]:
|
||||||
|
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||||
|
return empty_result
|
||||||
|
|
||||||
|
def _set_input_buffer(
|
||||||
|
self,
|
||||||
|
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||||
|
buffer: Dict[str, Optional[Tensor]],
|
||||||
|
):
|
||||||
|
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||||
|
|
||||||
|
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||||
|
return attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
def init_bert_params(module):
|
||||||
|
"""
|
||||||
|
Initialize the weights specific to the BERT Model.
|
||||||
|
This overrides the default initializations depending on the specified arguments.
|
||||||
|
1. If normal_init_linear_weights is set then weights of linear
|
||||||
|
layer will be initialized using the normal distribution and
|
||||||
|
bias will be set to the specified value.
|
||||||
|
2. If normal_init_embed_weights is set then weights of embedding
|
||||||
|
layer will be initialized using the normal distribution.
|
||||||
|
3. If normal_init_proj_weights is set then weights of
|
||||||
|
in_project_weight for MultiHeadAttention initialized using
|
||||||
|
the normal distribution (to be validated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def normal_(data):
|
||||||
|
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||||
|
# so that the RNG is consistent with and without FSDP
|
||||||
|
data.copy_(
|
||||||
|
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
normal_(module.weight.data)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
normal_(module.weight.data)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
if isinstance(module, MultiheadAttention):
|
||||||
|
normal_(module.q_proj.weight.data)
|
||||||
|
normal_(module.k_proj.weight.data)
|
||||||
|
normal_(module.v_proj.weight.data)
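
A minimal usage sketch (assumed, not part of this patch): init_bert_params is written to be applied recursively with nn.Module.apply after the model has been built, so Linear/Embedding (and MultiheadAttention projection) weights get the N(0, 0.02) init and biases / padding rows are zeroed.

import torch.nn as nn

toy = nn.Sequential(nn.Linear(16, 16), nn.Embedding(10, 16, padding_idx=0))  # toy model for illustration
toy.apply(init_bert_params)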
|
813
opencompass/models/ola/model/speech_encoder/beats/kaldi.py
Normal file
@ -0,0 +1,813 @@
|
|||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
# import torchaudio
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_mel_banks",
|
||||||
|
"inverse_mel_scale",
|
||||||
|
"inverse_mel_scale_scalar",
|
||||||
|
"mel_scale",
|
||||||
|
"mel_scale_scalar",
|
||||||
|
"spectrogram",
|
||||||
|
"fbank",
|
||||||
|
"mfcc",
|
||||||
|
"vtln_warp_freq",
|
||||||
|
"vtln_warp_mel_freq",
|
||||||
|
]
|
||||||
|
|
||||||
|
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
|
||||||
|
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
|
||||||
|
# 1 milliseconds = 0.001 seconds
|
||||||
|
MILLISECONDS_TO_SECONDS = 0.001
|
||||||
|
|
||||||
|
# window types
|
||||||
|
HAMMING = "hamming"
|
||||||
|
HANNING = "hanning"
|
||||||
|
POVEY = "povey"
|
||||||
|
RECTANGULAR = "rectangular"
|
||||||
|
BLACKMAN = "blackman"
|
||||||
|
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_epsilon(device, dtype):
|
||||||
|
return EPSILON.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _next_power_of_2(x: int) -> int:
|
||||||
|
r"""Returns the smallest power of 2 that is greater than x"""
|
||||||
|
return 1 if x == 0 else 2 ** (x - 1).bit_length()
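
With the Kaldi-style defaults used further down in this file (25 ms frames at 16 kHz, i.e. 400 samples per frame), this rounds the FFT size up to 512; powers of two pass through unchanged:

assert _next_power_of_2(0) == 1
assert _next_power_of_2(400) == 512   # 25 ms frame at 16 kHz
assert _next_power_of_2(512) == 512   # already a power of two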
|
||||||
|
|
||||||
|
|
||||||
|
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
|
||||||
|
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
|
||||||
|
representing how the window is shifted along the waveform. Each row is a frame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform (Tensor): Tensor of size ``num_samples``
|
||||||
|
window_size (int): Frame length
|
||||||
|
window_shift (int): Frame shift
|
||||||
|
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
|
||||||
|
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||||
|
depends only on the frame_shift, and we reflect the data at the ends.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
|
||||||
|
"""
|
||||||
|
assert waveform.dim() == 1
|
||||||
|
num_samples = waveform.size(0)
|
||||||
|
strides = (window_shift * waveform.stride(0), waveform.stride(0))
|
||||||
|
|
||||||
|
if snip_edges:
|
||||||
|
if num_samples < window_size:
|
||||||
|
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
|
||||||
|
else:
|
||||||
|
m = 1 + (num_samples - window_size) // window_shift
|
||||||
|
else:
|
||||||
|
reversed_waveform = torch.flip(waveform, [0])
|
||||||
|
m = (num_samples + (window_shift // 2)) // window_shift
|
||||||
|
pad = window_size // 2 - window_shift // 2
|
||||||
|
pad_right = reversed_waveform
|
||||||
|
if pad > 0:
|
||||||
|
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
|
||||||
|
# but we want [2, 1, 0, 0, 1, 2]
|
||||||
|
pad_left = reversed_waveform[-pad:]
|
||||||
|
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
|
||||||
|
else:
|
||||||
|
# pad is negative so we want to trim the waveform at the front
|
||||||
|
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
|
||||||
|
|
||||||
|
sizes = (m, window_size)
|
||||||
|
return waveform.as_strided(sizes, strides)
|
||||||
|
|
||||||
|
|
||||||
|
def _feature_window_function(
|
||||||
|
window_type: str,
|
||||||
|
window_size: int,
|
||||||
|
blackman_coeff: float,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: int,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Returns a window function with the given type and size"""
|
||||||
|
if window_type == HANNING:
|
||||||
|
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
|
||||||
|
elif window_type == HAMMING:
|
||||||
|
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
|
||||||
|
elif window_type == POVEY:
|
||||||
|
# like hanning but goes to zero at edges
|
||||||
|
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
|
||||||
|
elif window_type == RECTANGULAR:
|
||||||
|
return torch.ones(window_size, device=device, dtype=dtype)
|
||||||
|
elif window_type == BLACKMAN:
|
||||||
|
a = 2 * math.pi / (window_size - 1)
|
||||||
|
window_function = torch.arange(window_size, device=device, dtype=dtype)
|
||||||
|
# can't use torch.blackman_window as they use different coefficients
|
||||||
|
return (
|
||||||
|
blackman_coeff
|
||||||
|
- 0.5 * torch.cos(a * window_function)
|
||||||
|
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
|
||||||
|
).to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid window type " + window_type)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
|
||||||
|
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
|
||||||
|
device, dtype = strided_input.device, strided_input.dtype
|
||||||
|
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
|
||||||
|
if energy_floor == 0.0:
|
||||||
|
return log_energy
|
||||||
|
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_waveform_and_window_properties(
|
||||||
|
waveform: Tensor,
|
||||||
|
channel: int,
|
||||||
|
sample_frequency: float,
|
||||||
|
frame_shift: float,
|
||||||
|
frame_length: float,
|
||||||
|
round_to_power_of_two: bool,
|
||||||
|
preemphasis_coefficient: float,
|
||||||
|
) -> Tuple[Tensor, int, int, int]:
|
||||||
|
r"""Gets the waveform and window properties"""
|
||||||
|
channel = max(channel, 0)
|
||||||
|
assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
|
||||||
|
waveform = waveform[channel, :] # size (n)
|
||||||
|
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
|
||||||
|
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
|
||||||
|
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
|
||||||
|
|
||||||
|
assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
|
||||||
|
window_size, len(waveform)
|
||||||
|
)
|
||||||
|
assert 0 < window_shift, "`window_shift` must be greater than 0"
|
||||||
|
assert padded_window_size % 2 == 0, (
|
||||||
|
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
|
||||||
|
)
|
||||||
|
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
|
||||||
|
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
|
||||||
|
return waveform, window_shift, window_size, padded_window_size
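
A quick sketch of the window arithmetic with the default parameters (dummy waveform, values not from this patch): 10 ms shift and 25 ms frames at 16 kHz give a shift of 160 samples, a window of 400 samples, padded to 512 for the FFT.

import torch

wave = torch.randn(1, 16000)  # 1 s of dummy mono audio at 16 kHz
w, shift, size, padded = _get_waveform_and_window_properties(
    wave, channel=-1, sample_frequency=16000.0, frame_shift=10.0,
    frame_length=25.0, round_to_power_of_two=True, preemphasis_coefficient=0.97)
print(shift, size, padded)  # 160 400 512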
|
||||||
|
|
||||||
|
|
||||||
|
def _get_window(
|
||||||
|
waveform: Tensor,
|
||||||
|
padded_window_size: int,
|
||||||
|
window_size: int,
|
||||||
|
window_shift: int,
|
||||||
|
window_type: str,
|
||||||
|
blackman_coeff: float,
|
||||||
|
snip_edges: bool,
|
||||||
|
raw_energy: bool,
|
||||||
|
energy_floor: float,
|
||||||
|
dither: float,
|
||||||
|
remove_dc_offset: bool,
|
||||||
|
preemphasis_coefficient: float,
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
|
r"""Gets a window and its log energy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
|
||||||
|
"""
|
||||||
|
device, dtype = waveform.device, waveform.dtype
|
||||||
|
epsilon = _get_epsilon(device, dtype)
|
||||||
|
|
||||||
|
# size (m, window_size)
|
||||||
|
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
|
||||||
|
|
||||||
|
if dither != 0.0:
|
||||||
|
rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
|
||||||
|
strided_input = strided_input + rand_gauss * dither
|
||||||
|
|
||||||
|
if remove_dc_offset:
|
||||||
|
# Subtract each row/frame by its mean
|
||||||
|
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
|
||||||
|
strided_input = strided_input - row_means
|
||||||
|
|
||||||
|
if raw_energy:
|
||||||
|
# Compute the log energy of each row/frame before applying preemphasis and
|
||||||
|
# window function
|
||||||
|
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
|
||||||
|
|
||||||
|
if preemphasis_coefficient != 0.0:
|
||||||
|
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
|
||||||
|
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
|
||||||
|
0
|
||||||
|
) # size (m, window_size + 1)
|
||||||
|
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
|
||||||
|
|
||||||
|
# Apply window_function to each row/frame
|
||||||
|
window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
|
||||||
|
0
|
||||||
|
) # size (1, window_size)
|
||||||
|
strided_input = strided_input * window_function # size (m, window_size)
|
||||||
|
|
||||||
|
# Pad columns with zero until we reach size (m, padded_window_size)
|
||||||
|
if padded_window_size != window_size:
|
||||||
|
padding_right = padded_window_size - window_size
|
||||||
|
strided_input = torch.nn.functional.pad(
|
||||||
|
strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
# Compute energy after window function (not the raw one)
|
||||||
|
if not raw_energy:
|
||||||
|
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
|
||||||
|
|
||||||
|
return strided_input, signal_log_energy
|
||||||
|
|
||||||
|
|
||||||
|
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
|
||||||
|
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
|
||||||
|
# it returns size (m, n)
|
||||||
|
if subtract_mean:
|
||||||
|
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
|
||||||
|
tensor = tensor - col_means
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram(
|
||||||
|
waveform: Tensor,
|
||||||
|
blackman_coeff: float = 0.42,
|
||||||
|
channel: int = -1,
|
||||||
|
dither: float = 0.0,
|
||||||
|
energy_floor: float = 1.0,
|
||||||
|
frame_length: float = 25.0,
|
||||||
|
frame_shift: float = 10.0,
|
||||||
|
min_duration: float = 0.0,
|
||||||
|
preemphasis_coefficient: float = 0.97,
|
||||||
|
raw_energy: bool = True,
|
||||||
|
remove_dc_offset: bool = True,
|
||||||
|
round_to_power_of_two: bool = True,
|
||||||
|
sample_frequency: float = 16000.0,
|
||||||
|
snip_edges: bool = True,
|
||||||
|
subtract_mean: bool = False,
|
||||||
|
window_type: str = POVEY,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
|
||||||
|
compute-spectrogram-feats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||||
|
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||||
|
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||||
|
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||||
|
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||||
|
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||||
|
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||||
|
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||||
|
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||||
|
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||||
|
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||||
|
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||||
|
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||||
|
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||||
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||||
|
to FFT. (Default: ``True``)
|
||||||
|
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||||
|
specified there) (Default: ``16000.0``)
|
||||||
|
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||||
|
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||||
|
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||||
|
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||||
|
it this way. (Default: ``False``)
|
||||||
|
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||||
|
(Default: ``'povey'``)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A spectrogram identical to what Kaldi would output. The shape is
|
||||||
|
(m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
|
||||||
|
"""
|
||||||
|
device, dtype = waveform.device, waveform.dtype
|
||||||
|
epsilon = _get_epsilon(device, dtype)
|
||||||
|
|
||||||
|
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||||
|
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(waveform) < min_duration * sample_frequency:
|
||||||
|
# signal is too short
|
||||||
|
return torch.empty(0)
|
||||||
|
|
||||||
|
strided_input, signal_log_energy = _get_window(
|
||||||
|
waveform,
|
||||||
|
padded_window_size,
|
||||||
|
window_size,
|
||||||
|
window_shift,
|
||||||
|
window_type,
|
||||||
|
blackman_coeff,
|
||||||
|
snip_edges,
|
||||||
|
raw_energy,
|
||||||
|
energy_floor,
|
||||||
|
dither,
|
||||||
|
remove_dc_offset,
|
||||||
|
preemphasis_coefficient,
|
||||||
|
)
|
||||||
|
|
||||||
|
# size (m, padded_window_size // 2 + 1), complex-valued
|
||||||
|
fft = torch.fft.rfft(strided_input)
|
||||||
|
|
||||||
|
# Convert the FFT into a power spectrum
|
||||||
|
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
|
||||||
|
power_spectrum[:, 0] = signal_log_energy
|
||||||
|
|
||||||
|
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
|
||||||
|
return power_spectrum
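
A minimal usage sketch with dummy data (not part of this patch): for a 1 s mono clip at the default 16 kHz, the 25 ms / 10 ms framing above yields 98 frames of 257 spectral bins.

import torch

torch.manual_seed(0)
wave = torch.randn(1, 16000)                 # dummy 1 s mono clip at 16 kHz
spec = spectrogram(wave, sample_frequency=16000.0)
print(spec.shape)                            # torch.Size([98, 257]) -> (frames, 512 // 2 + 1)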
|
||||||
|
|
||||||
|
|
||||||
|
def inverse_mel_scale_scalar(mel_freq: float) -> float:
|
||||||
|
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
|
||||||
|
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def mel_scale_scalar(freq: float) -> float:
|
||||||
|
return 1127.0 * math.log(1.0 + freq / 700.0)
|
||||||
|
|
||||||
|
|
||||||
|
def mel_scale(freq: Tensor) -> Tensor:
|
||||||
|
return 1127.0 * (1.0 + freq / 700.0).log()
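
The mel helpers above are inverses of each other on the 1127 * ln(1 + f/700) scale; a quick round-trip check with arbitrary frequencies:

import torch

freqs = torch.tensor([100.0, 1000.0, 4000.0])
print(torch.allclose(inverse_mel_scale(mel_scale(freqs)), freqs, atol=1e-2))  # True
print(round(mel_scale_scalar(1000.0), 1))  # 1000.0 (1 kHz sits near 1000 mel on this scale)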
|
||||||
|
|
||||||
|
|
||||||
|
def vtln_warp_freq(
|
||||||
|
vtln_low_cutoff: float,
|
||||||
|
vtln_high_cutoff: float,
|
||||||
|
low_freq: float,
|
||||||
|
high_freq: float,
|
||||||
|
vtln_warp_factor: float,
|
||||||
|
freq: Tensor,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""This computes a VTLN warping function that is not the same as HTK's one,
|
||||||
|
but has similar inputs (this function has the advantage of never producing
|
||||||
|
empty bins).
|
||||||
|
|
||||||
|
This function computes a warp function F(freq), defined between low_freq
|
||||||
|
and high_freq inclusive, with the following properties:
|
||||||
|
F(low_freq) == low_freq
|
||||||
|
F(high_freq) == high_freq
|
||||||
|
The function is continuous and piecewise linear with two inflection
|
||||||
|
points.
|
||||||
|
The lower inflection point (measured in terms of the unwarped
|
||||||
|
frequency) is at frequency l, determined as described below.
|
||||||
|
The higher inflection point is at a frequency h, determined as
|
||||||
|
described below.
|
||||||
|
If l <= f <= h, then F(f) = f/vtln_warp_factor.
|
||||||
|
If the higher inflection point (measured in terms of the unwarped
|
||||||
|
frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
|
||||||
|
Since (by the last point) F(h) == h/vtln_warp_factor, then
|
||||||
|
max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
|
||||||
|
h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
|
||||||
|
= vtln_high_cutoff * min(1, vtln_warp_factor).
|
||||||
|
If the lower inflection point (measured in terms of the unwarped
|
||||||
|
frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
|
||||||
|
This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
|
||||||
|
= vtln_low_cutoff * max(1, vtln_warp_factor)
|
||||||
|
Args:
|
||||||
|
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
|
||||||
|
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
|
||||||
|
low_freq (float): Lower frequency cutoffs in mel computation
|
||||||
|
high_freq (float): Upper frequency cutoffs in mel computation
|
||||||
|
vtln_warp_factor (float): Vtln warp factor
|
||||||
|
freq (Tensor): given frequency in Hz
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Freq after vtln warp
|
||||||
|
"""
|
||||||
|
assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
|
||||||
|
assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
|
||||||
|
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
|
||||||
|
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
|
||||||
|
scale = 1.0 / vtln_warp_factor
|
||||||
|
Fl = scale * l # F(l)
|
||||||
|
Fh = scale * h # F(h)
|
||||||
|
assert l > low_freq and h < high_freq
|
||||||
|
# slope of left part of the 3-piece linear function
|
||||||
|
scale_left = (Fl - low_freq) / (l - low_freq)
|
||||||
|
# [slope of center part is just "scale"]
|
||||||
|
|
||||||
|
# slope of right part of the 3-piece linear function
|
||||||
|
scale_right = (high_freq - Fh) / (high_freq - h)
|
||||||
|
|
||||||
|
res = torch.empty_like(freq)
|
||||||
|
|
||||||
|
outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
|
||||||
|
before_l = torch.lt(freq, l) # freq < l
|
||||||
|
before_h = torch.lt(freq, h) # freq < h
|
||||||
|
after_h = torch.ge(freq, h) # freq >= h
|
||||||
|
|
||||||
|
# order of operations matter here (since there is overlapping frequency regions)
|
||||||
|
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
|
||||||
|
res[before_h] = scale * freq[before_h]
|
||||||
|
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
|
||||||
|
res[outside_low_high_freq] = freq[outside_low_high_freq]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def vtln_warp_mel_freq(
|
||||||
|
vtln_low_cutoff: float,
|
||||||
|
vtln_high_cutoff: float,
|
||||||
|
low_freq,
|
||||||
|
high_freq: float,
|
||||||
|
vtln_warp_factor: float,
|
||||||
|
mel_freq: Tensor,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
|
||||||
|
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
|
||||||
|
low_freq (float): Lower frequency cutoffs in mel computation
|
||||||
|
high_freq (float): Upper frequency cutoffs in mel computation
|
||||||
|
vtln_warp_factor (float): Vtln warp factor
|
||||||
|
mel_freq (Tensor): Given frequency in Mel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: ``mel_freq`` after vtln warp
|
||||||
|
"""
|
||||||
|
return mel_scale(
|
||||||
|
vtln_warp_freq(
|
||||||
|
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mel_banks(
|
||||||
|
num_bins: int,
|
||||||
|
window_length_padded: int,
|
||||||
|
sample_freq: float,
|
||||||
|
low_freq: float,
|
||||||
|
high_freq: float,
|
||||||
|
vtln_low: float,
|
||||||
|
vtln_high: float,
|
||||||
|
vtln_warp_factor: float,
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
(Tensor, Tensor): The tuple consists of ``bins`` (which is
|
||||||
|
melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
|
||||||
|
center frequencies of bins of size (``num_bins``)).
|
||||||
|
"""
|
||||||
|
assert num_bins > 3, "Must have more than 3 mel bins"
|
||||||
|
assert window_length_padded % 2 == 0
|
||||||
|
num_fft_bins = window_length_padded / 2
|
||||||
|
nyquist = 0.5 * sample_freq
|
||||||
|
|
||||||
|
if high_freq <= 0.0:
|
||||||
|
high_freq += nyquist
|
||||||
|
|
||||||
|
assert (
|
||||||
|
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
|
||||||
|
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
|
||||||
|
|
||||||
|
# fft-bin width [think of it as Nyquist-freq / half-window-length]
|
||||||
|
fft_bin_width = sample_freq / window_length_padded
|
||||||
|
mel_low_freq = mel_scale_scalar(low_freq)
|
||||||
|
mel_high_freq = mel_scale_scalar(high_freq)
|
||||||
|
|
||||||
|
# divide by num_bins+1 in next line because of end-effects where the bins
|
||||||
|
# spread out to the sides.
|
||||||
|
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
|
||||||
|
|
||||||
|
if vtln_high < 0.0:
|
||||||
|
vtln_high += nyquist
|
||||||
|
|
||||||
|
assert vtln_warp_factor == 1.0 or (
|
||||||
|
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
|
||||||
|
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
|
||||||
|
vtln_low, vtln_high, low_freq, high_freq
|
||||||
|
)
|
||||||
|
|
||||||
|
bin = torch.arange(num_bins).unsqueeze(1)
|
||||||
|
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
|
||||||
|
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
|
||||||
|
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
|
||||||
|
|
||||||
|
if vtln_warp_factor != 1.0:
|
||||||
|
left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
|
||||||
|
center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
|
||||||
|
right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
|
||||||
|
|
||||||
|
center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
|
||||||
|
# size(1, num_fft_bins)
|
||||||
|
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
|
||||||
|
|
||||||
|
# size (num_bins, num_fft_bins)
|
||||||
|
up_slope = (mel - left_mel) / (center_mel - left_mel)
|
||||||
|
down_slope = (right_mel - mel) / (right_mel - center_mel)
|
||||||
|
|
||||||
|
if vtln_warp_factor == 1.0:
|
||||||
|
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
|
||||||
|
bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
|
||||||
|
else:
|
||||||
|
# warping can move the order of left_mel, center_mel, right_mel anywhere
|
||||||
|
bins = torch.zeros_like(up_slope)
|
||||||
|
up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
|
||||||
|
down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
|
||||||
|
bins[up_idx] = up_slope[up_idx]
|
||||||
|
bins[down_idx] = down_slope[down_idx]
|
||||||
|
|
||||||
|
return bins, center_freqs
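
A usage sketch with the Kaldi-style defaults from the functions below (values chosen for illustration): 23 triangular filters over a 512-point window at 16 kHz, no VTLN warping, giving a (num_bins, num_fft_bins) filterbank matrix plus the bin center frequencies.

banks, center_freqs = get_mel_banks(
    num_bins=23, window_length_padded=512, sample_freq=16000.0,
    low_freq=20.0, high_freq=0.0, vtln_low=100.0, vtln_high=-500.0,
    vtln_warp_factor=1.0)
print(banks.shape)  # torch.Size([23, 256])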
|
||||||
|
|
||||||
|
|
||||||
|
def fbank(
|
||||||
|
waveform: Tensor,
|
||||||
|
blackman_coeff: float = 0.42,
|
||||||
|
channel: int = -1,
|
||||||
|
dither: float = 0.0,
|
||||||
|
energy_floor: float = 1.0,
|
||||||
|
frame_length: float = 25.0,
|
||||||
|
frame_shift: float = 10.0,
|
||||||
|
high_freq: float = 0.0,
|
||||||
|
htk_compat: bool = False,
|
||||||
|
low_freq: float = 20.0,
|
||||||
|
min_duration: float = 0.0,
|
||||||
|
num_mel_bins: int = 23,
|
||||||
|
preemphasis_coefficient: float = 0.97,
|
||||||
|
raw_energy: bool = True,
|
||||||
|
remove_dc_offset: bool = True,
|
||||||
|
round_to_power_of_two: bool = True,
|
||||||
|
sample_frequency: float = 16000.0,
|
||||||
|
snip_edges: bool = True,
|
||||||
|
subtract_mean: bool = False,
|
||||||
|
use_energy: bool = False,
|
||||||
|
use_log_fbank: bool = True,
|
||||||
|
use_power: bool = True,
|
||||||
|
vtln_high: float = -500.0,
|
||||||
|
vtln_low: float = 100.0,
|
||||||
|
vtln_warp: float = 1.0,
|
||||||
|
window_type: str = POVEY,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
|
||||||
|
compute-fbank-feats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||||
|
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||||
|
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||||
|
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||||
|
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||||
|
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||||
|
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||||
|
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||||
|
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||||
|
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||||
|
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
|
||||||
|
(Default: ``0.0``)
|
||||||
|
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
|
||||||
|
(need to change other parameters). (Default: ``False``)
|
||||||
|
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
|
||||||
|
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||||
|
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
|
||||||
|
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||||
|
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||||
|
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||||
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||||
|
to FFT. (Default: ``True``)
|
||||||
|
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||||
|
specified there) (Default: ``16000.0``)
|
||||||
|
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||||
|
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||||
|
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||||
|
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||||
|
it this way. (Default: ``False``)
|
||||||
|
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
|
||||||
|
use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
|
||||||
|
use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
|
||||||
|
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
|
||||||
|
negative, offset from high-mel-freq (Default: ``-500.0``)
|
||||||
|
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
|
||||||
|
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
|
||||||
|
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||||
|
(Default: ``'povey'``)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
|
||||||
|
where m is calculated in _get_strided
|
||||||
|
"""
|
||||||
|
device, dtype = waveform.device, waveform.dtype
|
||||||
|
|
||||||
|
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||||
|
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(waveform) < min_duration * sample_frequency:
|
||||||
|
# signal is too short
|
||||||
|
return torch.empty(0, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
|
||||||
|
strided_input, signal_log_energy = _get_window(
|
||||||
|
waveform,
|
||||||
|
padded_window_size,
|
||||||
|
window_size,
|
||||||
|
window_shift,
|
||||||
|
window_type,
|
||||||
|
blackman_coeff,
|
||||||
|
snip_edges,
|
||||||
|
raw_energy,
|
||||||
|
energy_floor,
|
||||||
|
dither,
|
||||||
|
remove_dc_offset,
|
||||||
|
preemphasis_coefficient,
|
||||||
|
)
|
||||||
|
|
||||||
|
# size (m, padded_window_size // 2 + 1)
|
||||||
|
spectrum = torch.fft.rfft(strided_input).abs()
|
||||||
|
if use_power:
|
||||||
|
spectrum = spectrum.pow(2.0)
|
||||||
|
|
||||||
|
# size (num_mel_bins, padded_window_size // 2)
|
||||||
|
mel_energies, _ = get_mel_banks(
|
||||||
|
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
|
||||||
|
)
|
||||||
|
mel_energies = mel_energies.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
|
||||||
|
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
|
||||||
|
|
||||||
|
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
|
||||||
|
mel_energies = torch.mm(spectrum, mel_energies.T)
|
||||||
|
if use_log_fbank:
|
||||||
|
# avoid log of zero (which should be prevented anyway by dithering)
|
||||||
|
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
|
||||||
|
|
||||||
|
# if use_energy then add it as the last column for htk_compat == true else first column
|
||||||
|
if use_energy:
|
||||||
|
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
|
||||||
|
# returns size (m, num_mel_bins + 1)
|
||||||
|
if htk_compat:
|
||||||
|
mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
|
||||||
|
else:
|
||||||
|
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
|
||||||
|
|
||||||
|
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
|
||||||
|
return mel_energies
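
A minimal usage sketch with dummy data (not part of this patch), e.g. a 128-bin log-mel front end of the kind speech encoders typically consume; for a 1 s clip at 16 kHz this gives 98 frames of 128 features.

import torch

torch.manual_seed(0)
wave = torch.randn(1, 16000)                 # dummy 1 s mono clip at 16 kHz
feats = fbank(wave, num_mel_bins=128, sample_frequency=16000.0,
              frame_length=25.0, frame_shift=10.0)
print(feats.shape)                           # torch.Size([98, 128])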
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
|
||||||
|
# returns a dct matrix of size (num_mel_bins, num_ceps)
|
||||||
|
# size (num_mel_bins, num_mel_bins)
|
||||||
|
import torchaudio  # local import: the module-level `import torchaudio` above is commented out
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
|
||||||
|
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
|
||||||
|
# this would be the first column in the dct_matrix for torchaudio as it expects a
|
||||||
|
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
|
||||||
|
# expects a left multiply e.g. dct_matrix * vector).
|
||||||
|
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
|
||||||
|
dct_matrix = dct_matrix[:, :num_ceps]
|
||||||
|
return dct_matrix
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
|
||||||
|
# returns size (num_ceps)
|
||||||
|
# Compute liftering coefficients (scaling on cepstral coeffs)
|
||||||
|
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
|
||||||
|
i = torch.arange(num_ceps)
|
||||||
|
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
|
||||||
|
|
||||||
|
|
||||||
|
def mfcc(
|
||||||
|
waveform: Tensor,
|
||||||
|
blackman_coeff: float = 0.42,
|
||||||
|
cepstral_lifter: float = 22.0,
|
||||||
|
channel: int = -1,
|
||||||
|
dither: float = 0.0,
|
||||||
|
energy_floor: float = 1.0,
|
||||||
|
frame_length: float = 25.0,
|
||||||
|
frame_shift: float = 10.0,
|
||||||
|
high_freq: float = 0.0,
|
||||||
|
htk_compat: bool = False,
|
||||||
|
low_freq: float = 20.0,
|
||||||
|
num_ceps: int = 13,
|
||||||
|
min_duration: float = 0.0,
|
||||||
|
num_mel_bins: int = 23,
|
||||||
|
preemphasis_coefficient: float = 0.97,
|
||||||
|
raw_energy: bool = True,
|
||||||
|
remove_dc_offset: bool = True,
|
||||||
|
round_to_power_of_two: bool = True,
|
||||||
|
sample_frequency: float = 16000.0,
|
||||||
|
snip_edges: bool = True,
|
||||||
|
subtract_mean: bool = False,
|
||||||
|
use_energy: bool = False,
|
||||||
|
vtln_high: float = -500.0,
|
||||||
|
vtln_low: float = 100.0,
|
||||||
|
vtln_warp: float = 1.0,
|
||||||
|
window_type: str = POVEY,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
|
||||||
|
compute-mfcc-feats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||||
|
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||||
|
cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
|
||||||
|
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||||
|
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||||
|
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||||
|
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||||
|
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||||
|
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||||
|
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||||
|
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||||
|
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
|
||||||
|
(Default: ``0.0``)
|
||||||
|
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
|
||||||
|
features (need to change other parameters). (Default: ``False``)
|
||||||
|
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
|
||||||
|
num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
|
||||||
|
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||||
|
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
|
||||||
|
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||||
|
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||||
|
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||||
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||||
|
to FFT. (Default: ``True``)
|
||||||
|
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||||
|
specified there) (Default: ``16000.0``)
|
||||||
|
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||||
|
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||||
|
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||||
|
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||||
|
it this way. (Default: ``False``)
|
||||||
|
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
|
||||||
|
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
|
||||||
|
negative, offset from high-mel-freq (Default: ``-500.0``)
|
||||||
|
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
|
||||||
|
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
|
||||||
|
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||||
|
(Default: ``"povey"``)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
|
||||||
|
where m is calculated in _get_strided
|
||||||
|
"""
|
||||||
|
assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
|
||||||
|
|
||||||
|
device, dtype = waveform.device, waveform.dtype
|
||||||
|
|
||||||
|
# The mel_energies should be computed from the power spectrum (use_power=True), without mean
|
||||||
|
# subtraction (subtract_mean=False), and in the log domain (use_log_fbank=True).
|
||||||
|
# size (m, num_mel_bins + use_energy)
|
||||||
|
feature = fbank(
|
||||||
|
waveform=waveform,
|
||||||
|
blackman_coeff=blackman_coeff,
|
||||||
|
channel=channel,
|
||||||
|
dither=dither,
|
||||||
|
energy_floor=energy_floor,
|
||||||
|
frame_length=frame_length,
|
||||||
|
frame_shift=frame_shift,
|
||||||
|
high_freq=high_freq,
|
||||||
|
htk_compat=htk_compat,
|
||||||
|
low_freq=low_freq,
|
||||||
|
min_duration=min_duration,
|
||||||
|
num_mel_bins=num_mel_bins,
|
||||||
|
preemphasis_coefficient=preemphasis_coefficient,
|
||||||
|
raw_energy=raw_energy,
|
||||||
|
remove_dc_offset=remove_dc_offset,
|
||||||
|
round_to_power_of_two=round_to_power_of_two,
|
||||||
|
sample_frequency=sample_frequency,
|
||||||
|
snip_edges=snip_edges,
|
||||||
|
subtract_mean=False,
|
||||||
|
use_energy=use_energy,
|
||||||
|
use_log_fbank=True,
|
||||||
|
use_power=True,
|
||||||
|
vtln_high=vtln_high,
|
||||||
|
vtln_low=vtln_low,
|
||||||
|
vtln_warp=vtln_warp,
|
||||||
|
window_type=window_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_energy:
|
||||||
|
# size (m)
|
||||||
|
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
|
||||||
|
# offset is 0 if htk_compat==True else 1
|
||||||
|
mel_offset = int(not htk_compat)
|
||||||
|
feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
|
||||||
|
|
||||||
|
# size (num_mel_bins, num_ceps)
|
||||||
|
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# size (m, num_ceps)
|
||||||
|
feature = feature.matmul(dct_matrix)
|
||||||
|
|
||||||
|
if cepstral_lifter != 0.0:
|
||||||
|
# size (1, num_ceps)
|
||||||
|
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
|
||||||
|
feature *= lifter_coeffs.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# if use_energy then replace the last column for htk_compat == true else first column
|
||||||
|
if use_energy:
|
||||||
|
feature[:, 0] = signal_log_energy
|
||||||
|
|
||||||
|
if htk_compat:
|
||||||
|
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
|
||||||
|
feature = feature[:, 1:] # size (m, num_ceps - 1)
|
||||||
|
if not use_energy:
|
||||||
|
# scale on C0 (actually removing a scale we previously added that's
|
||||||
|
# part of one common definition of the cosine transform.)
|
||||||
|
energy *= math.sqrt(2)
|
||||||
|
|
||||||
|
feature = torch.cat((feature, energy), dim=1)
|
||||||
|
|
||||||
|
feature = _subtract_column_mean(feature, subtract_mean)
|
||||||
|
return feature
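
A usage sketch with dummy data, assuming torchaudio is installed (the DCT matrix above is built with torchaudio.functional.create_dct): a 1 s clip at 16 kHz yields 98 frames of 13 cepstra with the defaults.

import torch

torch.manual_seed(0)
wave = torch.randn(1, 16000)                 # dummy 1 s mono clip at 16 kHz
ceps = mfcc(wave, num_ceps=13, num_mel_bins=23, sample_frequency=16000.0)
print(ceps.shape)                            # torch.Size([98, 13])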
|
218
opencompass/models/ola/model/speech_encoder/beats/modules.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
||||||
|
# Copyright (c) 2022 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class GradMultiply(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, scale):
|
||||||
|
ctx.scale = scale
|
||||||
|
res = x.new(x)
|
||||||
|
return res
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad):
|
||||||
|
return grad * ctx.scale, None
|
||||||
|
|
||||||
|
|
||||||
|
class SamePad(nn.Module):
|
||||||
|
def __init__(self, kernel_size, causal=False):
|
||||||
|
super().__init__()
|
||||||
|
if causal:
|
||||||
|
self.remove = kernel_size - 1
|
||||||
|
else:
|
||||||
|
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.remove > 0:
|
||||||
|
x = x[:, :, : -self.remove]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Swish(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Swish, self).__init__()
|
||||||
|
self.act = torch.nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.act(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GLU_Linear(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||||
|
super(GLU_Linear, self).__init__()
|
||||||
|
|
||||||
|
self.glu_type = glu_type
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
if glu_type == "sigmoid":
|
||||||
|
self.glu_act = torch.nn.Sigmoid()
|
||||||
|
elif glu_type == "swish":
|
||||||
|
self.glu_act = Swish()
|
||||||
|
elif glu_type == "relu":
|
||||||
|
self.glu_act = torch.nn.ReLU()
|
||||||
|
elif glu_type == "gelu":
|
||||||
|
self.glu_act = torch.nn.GELU()
|
||||||
|
|
||||||
|
if bias_in_glu:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||||
|
else:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
||||||
|
x = self.linear(x)
|
||||||
|
|
||||||
|
if self.glu_type == "bilinear":
|
||||||
|
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
||||||
|
else:
|
||||||
|
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_accurate(x):
|
||||||
|
if not hasattr(gelu_accurate, "_a"):
|
||||||
|
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||||
|
return (
|
||||||
|
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation_fn(activation: str):
|
||||||
|
"""Returns the activation function corresponding to `activation`"""
|
||||||
|
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
elif activation == "gelu":
|
||||||
|
return gelu
|
||||||
|
elif activation == "gelu_fast":
|
||||||
|
warnings.warn(
|
||||||
|
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
||||||
|
)
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "gelu_accurate":
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "tanh":
|
||||||
|
return torch.tanh
|
||||||
|
elif activation == "linear":
|
||||||
|
return lambda x: x
|
||||||
|
elif activation == "glu":
|
||||||
|
return lambda x: x
|
||||||
|
else:
|
||||||
|
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||||
|
|
||||||
|
|
||||||
|
def quant_noise(module, p, block_size):
|
||||||
|
"""
|
||||||
|
Wraps modules and applies quantization noise to the weights for
|
||||||
|
subsequent quantization with Iterative Product Quantization as
|
||||||
|
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- module: nn.Module
|
||||||
|
- p: amount of Quantization Noise
|
||||||
|
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||||
|
|
||||||
|
Remarks:
|
||||||
|
- Module weights must have the right sizes wrt the block size
|
||||||
|
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||||
|
- For more detail on how to quantize by blocks with convolutional weights,
|
||||||
|
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||||
|
- We implement the simplest form of noise here as stated in the paper
|
||||||
|
which consists in randomly dropping blocks
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if no quantization noise, don't register hook
|
||||||
|
if p <= 0:
|
||||||
|
return module
|
||||||
|
|
||||||
|
# supported modules
|
||||||
|
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||||
|
|
||||||
|
# test whether module.weight has the right sizes wrt block_size
|
||||||
|
is_conv = module.weight.ndim == 4
|
||||||
|
|
||||||
|
# 2D matrix
|
||||||
|
if not is_conv:
|
||||||
|
assert (
|
||||||
|
module.weight.size(1) % block_size == 0
|
||||||
|
), "Input features must be a multiple of block sizes"
|
||||||
|
|
||||||
|
# 4D matrix
|
||||||
|
else:
|
||||||
|
# 1x1 convolutions
|
||||||
|
if module.kernel_size == (1, 1):
|
||||||
|
assert (
|
||||||
|
module.in_channels % block_size == 0
|
||||||
|
), "Input channels must be a multiple of block sizes"
|
||||||
|
# regular convolutions
|
||||||
|
else:
|
||||||
|
k = module.kernel_size[0] * module.kernel_size[1]
|
||||||
|
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||||
|
|
||||||
|
def _forward_pre_hook(mod, input):
|
||||||
|
# no noise for evaluation
|
||||||
|
if mod.training:
|
||||||
|
if not is_conv:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_features = weight.size(1)
|
||||||
|
out_features = weight.size(0)
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
mask = torch.zeros(
|
||||||
|
in_features // block_size * out_features, device=weight.device
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_channels = mod.in_channels
|
||||||
|
out_channels = mod.out_channels
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
if mod.kernel_size == (1, 1):
|
||||||
|
mask = torch.zeros(
|
||||||
|
int(in_channels // block_size * out_channels),
|
||||||
|
device=weight.device,
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||||
|
else:
|
||||||
|
mask = torch.zeros(
|
||||||
|
weight.size(0), weight.size(1), device=weight.device
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = (
|
||||||
|
mask.unsqueeze(2)
|
||||||
|
.unsqueeze(3)
|
||||||
|
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale weights and apply mask
|
||||||
|
mask = mask.to(
|
||||||
|
torch.bool
|
||||||
|
) # x.bool() is not currently supported in TorchScript
|
||||||
|
s = 1 / (1 - p)
|
||||||
|
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||||
|
|
||||||
|
module.register_forward_pre_hook(_forward_pre_hook)
|
||||||
|
return module
|
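A brief sketch of how quant_noise might be applied (illustrative only; the layer sizes and p below are made up, and the import path assumes this patch is installed). During training, the registered pre-forward hook zeroes random weight blocks with probability p and rescales the surviving weights by 1/(1-p); in eval mode it is a no-op.

# Hypothetical usage sketch for quant_noise: wrap a linear layer with quantization noise.
import torch
import torch.nn as nn
from opencompass.models.ola.model.speech_encoder.beats.modules import quant_noise

layer = nn.Linear(256, 512)                   # in_features (256) must be divisible by block_size
noisy_layer = quant_noise(layer, p=0.1, block_size=8)

noisy_layer.train()
y_train = noisy_layer(torch.randn(4, 256))    # forward pass with randomly dropped weight blocks

noisy_layer.eval()
y_eval = noisy_layer(torch.randn(4, 256))     # hook does nothing outside training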
215
opencompass/models/ola/model/speech_encoder/beats/quantizer.py
Normal file
@ -0,0 +1,215 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on VQGAN code bases
# https://github.com/CompVis/taming-transformers
# --------------------------------------------------------'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as distributed

try:
    from einops import rearrange, repeat
except ImportError:
    pass


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            diffs = rearrange(samples, 'n d -> n () d') \
                - rearrange(means, 'c d -> () c d')
            dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.eps = eps
        if codebook_init_path == '':
            if not kmeans_init:
                weight = torch.randn(num_tokens, codebook_dim)
                weight = l2norm(weight)
            else:
                weight = torch.zeros(num_tokens, codebook_dim)
            self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        else:
            print(f"load init codebook weight from {codebook_init_path}")
            codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
            weight = codebook_ckpt_weight.clone()
            self.register_buffer('initted', torch.Tensor([True]))

        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.initted:
            return
        print("Performing Kemans init for codebook")
        embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
        self.weight.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
        self.weight.data.copy_(embed_normalized)


def norm_ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
    moving_avg.data.copy_(l2norm(moving_avg.data))


class NormEMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay

        # learnable = True if orthogonal_reg_weight > 0 else False
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(n_embed))
        if distributed.is_available() and distributed.is_initialized():
            print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
            self.all_reduce_fn = distributed.all_reduce
        else:
            self.all_reduce_fn = nn.Identity()

    def reset_cluster_size(self, device):
        if self.statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        # z = rearrange(z, 'b c h w -> b h w c')
        # z = z.transpose(1, 2)
        z = l2norm(z)
        z_flattened = z.reshape(-1, self.codebook_dim)

        self.embedding.init_embed_(z_flattened)

        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)

        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        if not self.training:
            with torch.no_grad():
                cluster_size = encodings.sum(0)
                self.all_reduce_fn(cluster_size)
                ema_inplace(self.cluster_size, cluster_size, self.decay)

        if self.training and self.embedding.update:
            # EMA cluster size

            bins = encodings.sum(0)
            self.all_reduce_fn(bins)

            # self.embedding.cluster_size_ema_update(bins)
            ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = z_flattened.t() @ encodings
            self.all_reduce_fn(embed_sum)

            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
            embed_normalized = l2norm(embed_normalized)

            embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
                                           embed_normalized)
            norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        # z_q = rearrange(z_q, 'b h w c -> b c h w')
        # z_q = z_q.transpose(1, 2)
        return z_q, loss, encoding_indices
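For reference, a minimal sketch of the quantizer's interface (not part of the patch; the import path assumes this patch is installed and the sizes are illustrative). The forward pass l2-normalizes the inputs, assigns each frame to its nearest codeword, and returns the straight-through quantized tensor, the commitment loss, and the code indices.

# Hypothetical usage sketch for NormEMAVectorQuantizer: quantize 768-d frame embeddings
# against a 1024-entry codebook (kmeans_init=False so no einops/kmeans warm-up is needed).
import torch
from opencompass.models.ola.model.speech_encoder.beats.quantizer import NormEMAVectorQuantizer

vq = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=768, beta=1.0, kmeans_init=False)
frames = torch.randn(2, 50, 768)                 # (batch, time, dim)
z_q, commit_loss, codes = vq(frames)
print(z_q.shape, commit_loss.item(), codes.shape)  # (2, 50, 768), scalar loss, (100,) indices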
13
opencompass/models/ola/model/speech_encoder/builder.py
Normal file
@ -0,0 +1,13 @@
from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder
import torch.nn as nn

def build_speech_encoder(config):
    speech_encoder_type = getattr(config, 'speech_encoder_type', None)
    if "whisper" in speech_encoder_type.lower():
        return WhisperWrappedEncoder.load(config)
    elif "dual" in speech_encoder_type.lower():
        return DualWrappedEncoder(config)
    elif "none" in speech_encoder_type.lower():
        return None

    raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')
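For orientation, a sketch of the config object this builder expects (not part of the patch; attribute names are the ones read above and by DualWrappedEncoder, while the values and checkpoint path are placeholders):

# Hypothetical config sketch for build_speech_encoder; values are placeholders.
from types import SimpleNamespace
from opencompass.models.ola.model.speech_encoder.builder import build_speech_encoder

config = SimpleNamespace(
    speech_encoder_type='dual',               # 'whisper', 'dual', or 'none'
    speech_encoder='large-v3',                # name or path passed to whisper.load_model
    music_encoder='/path/to/BEATs_iter3.pt',  # BEATs checkpoint path (placeholder)
)
# encoder = build_speech_encoder(config)      # would return a DualWrappedEncoder here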
74
opencompass/models/ola/model/speech_encoder/speech_encoder.py
Normal file
@ -0,0 +1,74 @@
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
import whisper

from opencompass.models.ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs

class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def replace_layer_norm(module):
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder

class DualWrappedEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(cls, model_config):

        def replace_layer_norm(module):
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(cls, model_config):
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_ckpt['model'])
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True)
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds
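The dual encoder time-aligns the Whisper and BEATs feature streams by zero-padding the shorter one along the time axis, then concatenates them along the feature dimension. A shape-only sketch of that alignment, using dummy tensors in place of the real encoders (the sizes are illustrative, not part of the patch):

# Hypothetical shape sketch of the pad-and-concat step in DualWrappedEncoder.forward.
import torch
import torch.nn.functional as F

speech_embeds = torch.randn(1, 1500, 1280)   # Whisper-style features (batch, T1, D1)
audio_embeds = torch.randn(1, 1496, 768)     # BEATs-style features (batch, T2, D2)

if audio_embeds.size(1) < speech_embeds.size(1):
    audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
elif audio_embeds.size(1) > speech_embeds.size(1):
    speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))

fused = torch.cat((speech_embeds, audio_embeds), dim=-1)
print(fused.shape)   # torch.Size([1, 1500, 2048])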
11
opencompass/models/ola/model/speech_projector/builder.py
Normal file
@ -0,0 +1,11 @@
from .speech_projector import EncoderProjectorConcat


def build_speech_projector(config):
    projector_type = getattr(config, 'speech_projector_type', 'linear')
    if projector_type == 'linear':
        return EncoderProjectorConcat(config)
    elif projector_type == 'none':
        return None

    raise ValueError(f'Unknown projector type: {projector_type}')
47
opencompass/models/ola/model/speech_projector/speech_projector.py
Normal file
@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import math

class EncoderProjectorConcat(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.k = config.speech_encoder_ds_rate
        self.encoder_dim = config.speech_encoder_hidden_size
        self.llm_dim = config.hidden_size
        self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(2048, config.hidden_size)

        embed_std = 1 / math.sqrt(config.hidden_size)
        self.speech_newline = nn.Parameter(
            torch.randn(config.hidden_size) * embed_std
        )
        self.speech_begin = nn.Parameter(
            torch.randn(config.hidden_size) * embed_std
        )
        self.speech_end = nn.Parameter(
            torch.randn(config.hidden_size) * embed_std
        )

    def forward(self, x):
        batch_size, seq_len, dim = x.size()
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(batch_size, seq_len // self.k, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = torch.cat([
            x,
            self.speech_newline.reshape(1, 1, -1).expand(batch_size, 1, -1).to(x.dtype)
        ], dim=1)
        begin = self.speech_begin.reshape(1, -1).to(x.dtype)
        end = self.speech_end.reshape(1, -1).to(x.dtype)
        x = x.flatten(0, 1)
        x = torch.cat([begin, x, end], dim=0)
        # x = x.flatten(0, 1)
        return x
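The projector downsamples by concatenating k consecutive encoder frames before the two-layer MLP, appends a learned newline embedding per sample, flattens the batch, and wraps the whole sequence in learned begin/end embeddings. A shape-only sketch (not part of the patch; the config values and import path are assumptions based on this patch's module layout):

# Hypothetical shape sketch for EncoderProjectorConcat; values are illustrative.
from types import SimpleNamespace
import torch
from opencompass.models.ola.model.speech_projector.speech_projector import EncoderProjectorConcat

cfg = SimpleNamespace(
    speech_encoder_ds_rate=5,          # k: frames concatenated per step
    speech_encoder_hidden_size=1280,   # per-frame encoder dim
    hidden_size=3584,                  # LLM embedding dim (illustrative)
)
proj = EncoderProjectorConcat(cfg)

x = torch.randn(2, 103, 1280)          # (batch, frames, dim); 3 trailing frames are discarded
out = proj(x)
# 103 -> 100 frames -> 100 // 5 = 20 steps, +1 newline per sample = 21 rows each,
# flattened to 2 * 21 = 42 rows, +begin +end = 44 rows of size hidden_size.
print(out.shape)                       # torch.Size([44, 3584])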
213
opencompass/models/ola/utils.py
Normal file
@ -0,0 +1,213 @@
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import torch
import logging
import logging.handlers
import requests
import transformers

from opencompass.models.ola.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias:
            if bias_name in lora_bias_names:
                to_return[bias_name] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


def lengths_to_mask(lens):
    return ~lengths_to_padding_mask(lens)


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    data = "{" + '"input": ' + f'"{text}"' + "}"
    data = data.encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException as e:
        flagged = False
    except KeyError as e:
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
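As a quick illustration of the masking helpers above (not part of the patch; the import path assumes this patch is installed), lengths_to_padding_mask marks padded positions as True and lengths_to_mask is its negation:

# Hypothetical usage sketch for lengths_to_padding_mask / lengths_to_mask.
import torch
from opencompass.models.ola.utils import lengths_to_padding_mask, lengths_to_mask

lens = torch.tensor([3, 1, 2])
print(lengths_to_padding_mask(lens))
# tensor([[False, False, False],
#         [False,  True,  True],
#         [False, False,  True]])
print(lengths_to_mask(lens))   # True on valid (non-padded) positions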
141
opencompass/models/ola_model.py
Normal file
@ -0,0 +1,141 @@
import os

os.environ['LOWRES_RESIZE'] = '384x32'
os.environ['HIGHRES_BASE'] = '0x32'
os.environ['VIDEO_RESIZE'] = "0x64"
os.environ['VIDEO_MAXRES'] = "480"
os.environ['VIDEO_MINRES'] = "288"
os.environ['MAXRES'] = '1536'
os.environ['MINRES'] = '0'
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
os.environ['LOAD_VISION_EARLY'] = '1'
os.environ['PAD2STRIDE'] = '1'

import argparse
import copy
import re
import sys
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
import transformers
from PIL import Image

from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.utils.prompt import PromptList
from opencompass.models.ola.conversation import conv_templates, SeparatorStyle
from opencompass.models.ola.model.builder import load_pretrained_model
from opencompass.models.ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
from opencompass.models.ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
from opencompass.models.ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX

PromptType = Union[PromptList, str]


class OlaModel(BaseModel):
    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_path: Optional[str] = None,
                 model_config: Optional[str] = None,
                 meta_template: Optional[Dict] = None):

        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

        tokenizer, model, _, _ = load_pretrained_model(path, None)
        model = model.to('cuda').eval()
        model = model.bfloat16()
        self.tokenizer = tokenizer
        self.model = model
        self.gen_kwargs = {
            "max_new_tokens": 1024,
            "temperature": 0.2,
            "top_p": None,
            "num_beams": 1,
        }

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        assert len(inputs) == 1  # batch=1
        image_path = None
        audio_path = None
        video_path = None
        text = inputs[0]

        images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
        images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
        image_sizes = [(224, 224)]

        USE_SPEECH = False
        speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
        speech_lengths = [torch.LongTensor([3000]).to('cuda')]
        speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
        speech_chunks = [torch.LongTensor([1]).to('cuda')]

        conv_mode = "qwen_1_5"
        if text:
            qs = text
        else:
            qs = ''
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')

        pad_token_ids = 151643

        attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images,
                images_highres=images_highres,
                image_sizes=image_sizes,
                modalities=['text'],
                speech=speechs,
                speech_lengths=speech_lengths,
                speech_chunks=speech_chunks,
                speech_wav=speech_wavs,
                attention_mask=attention_masks,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                do_sample=True if self.gen_kwargs["temperature"] > 0 else False,
                temperature=self.gen_kwargs["temperature"],
                top_p=self.gen_kwargs["top_p"],
                num_beams=self.gen_kwargs["num_beams"],
                max_new_tokens=self.gen_kwargs["max_new_tokens"],
            )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        out = []
        for output in outputs:
            output = output.strip()
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
            out.append(output)
        print("prompt---->", prompt)
        print("out---->", out)
        print("\n")
        return out