mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
14 lines
531 B
Python
14 lines
531 B
Python
from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder
|
|
import torch.nn as nn
|
|
|
|
def build_speech_encoder(config):
    """Build the speech encoder selected by ``config.speech_encoder_type``.

    The type string is matched case-insensitively by substring, checked in
    this order:

    * contains ``"whisper"`` -> ``WhisperWrappedEncoder.load(config)``
    * contains ``"dual"``    -> ``DualWrappedEncoder(config)``
    * contains ``"none"``    -> ``None`` (no speech encoder)

    Args:
        config: Object exposing a ``speech_encoder_type`` string attribute;
            it is also forwarded unchanged to the selected encoder.

    Returns:
        The constructed encoder, or ``None`` when the type contains
        ``"none"``.

    Raises:
        ValueError: If ``speech_encoder_type`` is missing or ``None``, or if
            it matches none of the known encoder types.
    """
    speech_encoder_type = getattr(config, 'speech_encoder_type', None)
    # getattr falls back to None; fail with a clear configuration error
    # rather than an AttributeError from calling None.lower().
    if speech_encoder_type is None:
        raise ValueError(
            "config has no 'speech_encoder_type'; cannot build speech encoder"
        )

    encoder_type = speech_encoder_type.lower()

    # Substring checks are order-dependent: a type containing both
    # "whisper" and "dual" resolves to the Whisper encoder.
    if 'whisper' in encoder_type:
        return WhisperWrappedEncoder.load(config)
    if 'dual' in encoder_type:
        return DualWrappedEncoder(config)
    if 'none' in encoder_type:
        return None

    raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')
|