import types

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor

import whisper

from opencompass.models.ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs


class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def replace_layer_norm(module):
            # Swap whisper's custom LayerNorm for the stock nn.LayerNorm
            # while preserving the trained parameters.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape,
                                                  eps=child.eps,
                                                  elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder


class DualWrappedEncoder(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):

        def replace_layer_norm(module):
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape,
                                                  eps=child.eps,
                                                  elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_ckpt['model'])
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            # Whisper branch: encode the log-mel spectrogram input.
            speech_embeds = self.whisper_model(x)
            # BEATs branch: extract embeddings from the raw waveform.
            audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(),
                                                                padding_mask=audio_padding_mask,
                                                                feature_only=True)
            # Pad the shorter sequence along the time axis so both branches align.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
            # Concatenate along the feature dimension and cast for the downstream model.
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds
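

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the original module).
    # The checkpoint name and path below are placeholders; substitute the
    # Whisper model name and local BEATs checkpoint your setup actually uses.
    cfg = types.SimpleNamespace(
        speech_encoder="large-v3",            # any name accepted by whisper.load_model
        music_encoder="/path/to/beats.pt",    # placeholder path to a BEATs checkpoint
    )
    encoder = DualWrappedEncoder(cfg)

    # Whisper encoders expect a log-mel spectrogram of shape (B, n_mels, 3000);
    # large-v3 uses 128 mel bins, earlier Whisper checkpoints use 80.
    mel = torch.randn(1, 128, 3000)
    # BEATs consumes the raw 16 kHz waveform plus a boolean padding mask
    # (True marks padded samples).
    wav = torch.randn(1, 16000 * 30)
    pad_mask = torch.zeros(wav.shape, dtype=torch.bool)

    embeds = encoder(mel, raw_wav=wav, audio_padding_mask=pad_mask)
    # (1, T, whisper_dim + beats_dim), returned in bfloat16.
    print(embeds.shape, embeds.dtype)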