|
from typing import Tuple |
|
|
|
import torch.nn as nn |
|
|
|
from .deepunet import DeepUnet |
|
|
|
|
|
class E2E(nn.Module):
    """End-to-end pitch-estimation head: DeepUnet encoder-decoder, a 2D conv
    that reduces the channel count to 3, and a classification head over 360
    pitch bins (BiGRU + linear when ``n_gru`` > 0, otherwise a plain linear).
    """

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: Tuple[int, int],
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        """
        Args:
            n_blocks: number of intermediate residual blocks in DeepUnet.
            n_gru: number of stacked BiGRU layers; 0 disables the GRU head.
            kernel_size: 2D pooling kernel used inside DeepUnet.
            en_de_layers: encoder/decoder depth of DeepUnet.
            inter_layers: intermediate-layer count of DeepUnet.
            in_channels: channels of the input spectrogram (1 for mel).
            en_out_channels: channel width of the DeepUnet encoder output.
        """
        super().__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            # Fix: `BiGRU` is a module-level class, not an attribute of E2E;
            # the original `self.BiGRU(...)` raised AttributeError.
            self.fc = nn.Sequential(
                BiGRU(3 * 128, 256, n_gru),  # 3 conv channels x 128 mel bins
                nn.Linear(512, 360),  # 2 x 256 hidden (bidirectional) -> 360 pitch bins
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            # Fix: `nn.N_MELS` / `nn.N_CLASS` do not exist in torch.nn; use the
            # same dimensions as the GRU branch (N_MELS = 128, N_CLASS = 360).
            self.fc = nn.Sequential(
                nn.Linear(3 * 128, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )

    def forward(self, mel):
        """Map a mel spectrogram to per-frame pitch-bin probabilities.

        Assumes ``mel`` is (batch, n_mels, time) — TODO confirm against caller.
        Returns a tensor of shape (batch, time, 360) with values in (0, 1).
        """
        # (B, n_mels, T) -> (B, 1, T, n_mels): add channel dim for the conv stack.
        mel = mel.transpose(-1, -2).unsqueeze(1)
        # (B, 1, T, n_mels) -> (B, 3, T, n_mels) -> (B, T, 3 * n_mels)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x
|
|
|
class BiGRU(nn.Module):
    """Thin wrapper around a bidirectional, batch-first ``nn.GRU`` that
    returns only the output sequence (hidden state is discarded).

    Output feature size is ``2 * hidden_features`` (forward + backward).
    """

    def __init__(
        self,
        input_features: int,
        hidden_features: int,
        num_layers: int,
    ):
        super().__init__()
        # batch_first: inputs/outputs are (batch, seq, features).
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Run the GRU and drop the final hidden state."""
        output, _ = self.gru(x)
        return output
|
|