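"""Speech encoder wrappers for the Ola model in OpenCompass.

``WhisperWrappedEncoder`` loads a Whisper encoder with its custom LayerNorm
layers replaced by standard ``nn.LayerNorm``; ``DualWrappedEncoder`` pairs that
Whisper encoder with a BEATs audio encoder and concatenates their features.
"""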
import types

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor

import whisper

from opencompass.models.ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs


class WhisperWrappedEncoder:
    """Loads a Whisper encoder whose custom LayerNorm layers are replaced by
    standard ``nn.LayerNorm`` modules with the original weights preserved."""

    @classmethod
    def load(cls, model_config):

        def replace_layer_norm(module):
            # Recursively swap whisper.model.LayerNorm (which up-casts inputs
            # to float32) for a plain nn.LayerNorm with identical shape, eps,
            # affine setting and parameters.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder
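

# Illustrative usage: `model_config` is assumed to expose a `speech_encoder`
# attribute naming a Whisper variant or checkpoint path, e.g.
#
#     cfg = types.SimpleNamespace(speech_encoder='large-v3')   # hypothetical value
#     encoder = WhisperWrappedEncoder.load(cfg)
#     feats = encoder(torch.randn(1, 128, 3000))               # (batch, n_mels, n_frames) log-mel
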

class DualWrappedEncoder(nn.Module):
    """Combines a Whisper speech encoder with a BEATs audio encoder and
    concatenates their features along the channel dimension."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        # Same LayerNorm-replacement logic as WhisperWrappedEncoder.load, so
        # delegate to it instead of duplicating the code.
        return WhisperWrappedEncoder.load(model_config)

    def load_beats(self, model_config):
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_ckpt['model'])
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            # Whisper branch: x is a batch of log-mel features.
            speech_embeds = self.whisper_model(x)
            # BEATs branch: raw waveform plus an optional padding mask.
            audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True)
            # Pad the shorter stream along the time axis so both sequences match.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
            # Concatenate the two feature streams channel-wise, then cast to bfloat16.
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds
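

# Minimal usage sketch: it assumes a config object exposing `speech_encoder`
# (a Whisper variant name or checkpoint) and `music_encoder` (a local BEATs
# checkpoint); the values below are placeholders, not defaults shipped with
# OpenCompass.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        speech_encoder='large-v3',                # hypothetical Whisper variant
        music_encoder='/path/to/BEATs_iter3.pt',  # hypothetical BEATs checkpoint
    )
    encoder = DualWrappedEncoder(cfg)

    mel = torch.randn(1, 128, 3000)                       # 30 s of log-mel features (128 bins for large-v3)
    wav = torch.randn(1, 16000 * 30)                      # matching 30 s waveform at 16 kHz
    mask = torch.zeros(1, 16000 * 30, dtype=torch.bool)   # True marks padded samples; none here

    embeds = encoder(mel, raw_wav=wav, audio_padding_mask=mask)
    # Whisper and BEATs features are padded to a common length, concatenated
    # channel-wise, and returned in bfloat16.
    print(embeds.shape, embeds.dtype)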