|
import os |
|
import numpy as np |
|
|
|
import torchaudio |
|
import torch |
|
from torch import nn |
|
|
|
from speechbrain.lobes.models.Xvector import Xvector |
|
from speechbrain.lobes.features import Fbank |
|
from speechbrain.processing.features import InputNormalization |
|
|
|
|
|
class Extractor(nn.Module):
    """Speaker-embedding extractor built from SpeechBrain components.

    The forward pass chains four sub-modules: ``Fbank`` feature extraction,
    sentence-level ``InputNormalization``, the ``Xvector`` embedding network
    and, optionally, a global ``InputNormalization`` applied to the
    embeddings.
    """

    # Attribute names of the sub-modules; each one is looked up as
    # ``<model_path>/<name>.ckpt`` when restoring pretrained state.
    model_dict = [
        "mean_var_norm",
        "compute_features",
        "embedding_model",
        "mean_var_norm_emb",
    ]

    def __init__(self, model_path, n_mels=24, device="cpu"):
        """Build the sub-modules and restore any checkpoints found on disk.

        Args:
            model_path: Directory searched for ``<module name>.ckpt`` files.
            n_mels: Number of filterbank channels; also used as the input
                channel count of the embedding network.
            device: Device the sub-modules are moved to and on which
                ``forward`` runs.
        """
        super().__init__()
        self.device = device
        self.compute_features = Fbank(n_mels=n_mels)
        self.mean_var_norm = InputNormalization(
            norm_type="sentence", std_norm=False
        )
        self.embedding_model = Xvector(
            in_channels=n_mels,
            activation=torch.nn.LeakyReLU,
            tdnn_blocks=5,
            tdnn_channels=[512, 512, 512, 512, 1500],
            tdnn_kernel_sizes=[5, 3, 3, 1, 1],
            tdnn_dilations=[1, 2, 3, 1, 1],
            lin_neurons=512,
        )
        self.mean_var_norm_emb = InputNormalization(
            norm_type="global", std_norm=False
        )
        self._load_pretrained(model_path)

    def _load_pretrained(self, model_path):
        """Restore each sub-module from ``model_path`` and move it to the device."""
        for mod_name in self.model_dict:
            module = getattr(self, mod_name)
            filename = os.path.join(model_path, f"{mod_name}.ckpt")
            # A missing checkpoint is not an error: the module simply keeps
            # the state it was constructed with.
            if os.path.exists(filename):
                if hasattr(module, "_load"):
                    # Modules exposing their own ``_load`` hook restore
                    # themselves from the file.
                    print(f"Load: {filename}")
                    module._load(filename)
                else:
                    print(f"Load State Dict: {filename}")
                    # map_location lets a checkpoint saved on a GPU be
                    # restored on a host that only has the target device.
                    module.load_state_dict(
                        torch.load(filename, map_location=self.device)
                    )
            # Called per module (not once on ``self``) so that any custom
            # ``to`` override of a sub-module is honoured, and done for every
            # module whether or not a checkpoint was found.
            module.to(self.device)

    @torch.no_grad()
    def forward(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for a waveform or a batch of waveforms.

        Args:
            wavs: Waveform tensor. A 1-D tensor is treated as a single
                utterance and gets a leading batch dimension.
            wav_lens: Per-utterance length tensor handed to the
                normalisation and embedding modules. Defaults to all ones,
                one entry per batch item (presumably SpeechBrain relative
                lengths, i.e. "use the full signal" -- confirm).
            normalize: When true, additionally pass the embeddings through
                ``mean_var_norm_emb``.

        Returns:
            The tensor produced by the embedding model (after the optional
            normalisation), on ``self.device``.
        """
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        feats = self.compute_features(wavs)
        feats = self.mean_var_norm(feats, wav_lens)
        embeddings = self.embedding_model(feats, wav_lens)

        if normalize:
            # Each embedding is a single "frame", so it is always full length.
            full_lens = torch.ones(embeddings.shape[0], device=self.device)
            embeddings = self.mean_var_norm_emb(embeddings, full_lens)
        return embeddings
|
|
|
|
|
# Driver script: build the extractor, trace it with TorchScript, save the
# traced module and check that the reloaded module still produces embeddings.
MODEL_PATH = "pretrained_models/spkrec-xvect-voxceleb"
signal, fs = torchaudio.load('audio.wav')

# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA. The chosen device is also part of the output file name.
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = Extractor(MODEL_PATH, device=device)

# Inference only: freeze every parameter and switch to evaluation mode.
for param in extractor.parameters():
    param.requires_grad = False
extractor.eval()

# Reference embeddings from the eager (untraced) model.
embeddings_x = extractor(signal).cpu().squeeze()

# NOTE: tracing records only the operations executed for this example input,
# so Python-level branches in ``forward`` (input rank, ``wav_lens``,
# ``normalize``) are fixed to the path taken here.
traced_model = torch.jit.trace(extractor, signal)
torch.jit.save(traced_model, f"model_{device}.pt")
embeddings_t = traced_model(signal).squeeze()
print(embeddings_t)

# Round-trip through disk to confirm the serialized module is usable.
model = torch.jit.load(f"model_{device}.pt")
emb_m = model(signal).squeeze()
print(emb_m)
|
|