from __future__ import print_function

import torch
import torch.nn as nn
import numpy as np
from torch.utils.checkpoint import checkpoint_sequential

# Prefer the official CUDA implementation of Mamba2; if it is unavailable
# (e.g. no CUDA build of mamba_ssm), fall back to the local torch port.
try:
    from mamba_ssm.modules.mamba2 import Mamba2
except Exception as e:
    print('Exception while loading Mamba2 modules: {}'.format(str(e)))
    print('Falling back to the local torch implementation!')
    from .ex_bi_mamba2 import Mamba2

class MambaBlock(nn.Module):
    """Bidirectional Mamba2 block: one SSM scans the sequence forward, a
    second scans it time-reversed, and the two residual branches are
    concatenated along the feature axis, so (B, T, C) -> (B, T, 2C)."""
    def __init__(self, in_channels):
        super(MambaBlock, self).__init__()
        self.forward_mamba2 = Mamba2(
            d_model=in_channels,
            d_state=128,
            d_conv=4,
            expand=4,
            headdim=64,
        )

        self.backward_mamba2 = Mamba2(
            d_model=in_channels,
            d_state=128,
            d_conv=4,
            expand=4,
            headdim=64,
        )

    def forward(self, input):
        # Forward branch: scan the sequence as-is.
        forward_f = input
        forward_f_output = self.forward_mamba2(forward_f)
        # Backward branch: flip along the time axis, run the second SSM,
        # then flip the result back so both branches are time-aligned.
        backward_f = torch.flip(input, [1])
        backward_f_output = self.backward_mamba2(backward_f)
        backward_f_output2 = torch.flip(backward_f_output, [1])
        # Residual connection on each branch, then concatenate features.
        output = torch.cat([forward_f_output + input, backward_f_output2 + input], -1)
        return output
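
# A minimal shape sketch (illustration only; assumes a working Mamba2 build
# and a CUDA device, since the official kernels are CUDA-only):
#
#     block = MambaBlock(in_channels=128).cuda()
#     x = torch.randn(2, 100, 128).cuda()   # (batch, time, channels)
#     y = block(x)                          # -> (2, 100, 256)
#
# Mamba2 requires in_channels * expand to be divisible by headdim
# (128 * 4 / 64 = 8 heads here).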


class TAC(nn.Module):
    """
    A transform-average-concatenate (TAC) module for communication across
    parallel streams (here: audio channels). Each stream is transformed,
    the transforms are averaged across streams, and the shared average is
    concatenated back onto every stream before an output projection.
    """
    def __init__(self, input_size, hidden_size):
        super(TAC, self).__init__()

        self.input_size = input_size
        self.eps = torch.finfo(torch.float32).eps

        self.input_norm = nn.GroupNorm(1, input_size, self.eps)
        self.TAC_input = nn.Sequential(nn.Linear(input_size, hidden_size),
                                       nn.Tanh()
                                       )
        self.TAC_mean = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                      nn.Tanh()
                                      )
        self.TAC_output = nn.Sequential(nn.Linear(hidden_size*2, input_size),
                                        nn.Tanh()
                                        )

    def forward(self, input):
        # input: (batch, G, N, T), where G indexes the parallel streams
        batch_size, G, N = input.shape[:3]
        output = self.input_norm(input.view(batch_size*G, N, -1)).view(batch_size, G, N, -1)
        T = output.shape[-1]

        # Transform: per-stream projection of every time frame
        group_input = output
        group_input = group_input.permute(0,3,1,2).contiguous().view(-1, N)
        group_output = self.TAC_input(group_input).view(batch_size, T, G, -1)

        # Average: pool across streams, project, broadcast back to all streams
        group_mean = group_output.mean(2).view(batch_size*T, -1)
        group_mean = self.TAC_mean(group_mean).unsqueeze(1).expand(batch_size*T, G, group_mean.shape[-1]).contiguous()

        # Concatenate: per-stream feature + shared average, project, residual
        group_output = group_output.view(batch_size*T, G, -1)
        group_output = torch.cat([group_output, group_mean], 2)
        group_output = self.TAC_output(group_output.view(-1, group_output.shape[-1]))
        group_output = group_output.view(batch_size, T, G, -1).permute(0,2,3,1).contiguous()
        output = input + group_output.view(input.shape)

        return output
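
# TAC is shape-preserving; a quick sketch of the expected round trip
# (illustration only, dimensions chosen arbitrarily):
#
#     tac = TAC(input_size=64, hidden_size=192)
#     x = torch.randn(3, 2, 64, 50)      # (batch, streams, features, time)
#     assert tac(x).shape == x.shape     # residual output, same shape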


class ResMamba(nn.Module):
    """Residual wrapper around the bidirectional MambaBlock: GroupNorm ->
    dropout -> MambaBlock -> linear projection back to input_size, plus a
    residual connection. Operates on (batch, features, time) tensors."""
    def __init__(self, input_size, hidden_size, dropout=0., bidirectional=True):
        super(ResMamba, self).__init__()

        # `hidden_size` and `bidirectional` are kept for interface
        # compatibility but are unused: the MambaBlock is always
        # bidirectional and sizes itself from input_size alone.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.eps = torch.finfo(torch.float32).eps

        self.norm = nn.GroupNorm(1, input_size, self.eps)
        self.dropout = nn.Dropout(p=dropout)
        # Named `rnn`, presumably as a drop-in replacement for the
        # recurrent layer of RNN-based dual-path models.
        self.rnn = MambaBlock(input_size)
        # MambaBlock doubles the feature dimension; project back down.
        self.proj = nn.Linear(input_size*2, input_size)

    def forward(self, input):
        # (B, N, T) -> (B, T, N) for the sequence model, then back.
        rnn_output = self.rnn(self.dropout(self.norm(input)).transpose(1, 2).contiguous())
        rnn_output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])).view(input.shape[0],
                                                                                           input.shape[2],
                                                                                           input.shape[1])

        return input + rnn_output.transpose(1, 2).contiguous()
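
# ResMamba is likewise shape-preserving over (batch, features, time)
# inputs (illustration only; subject to the same Mamba2 headdim
# divisibility constraint noted above):
#
#     layer = ResMamba(input_size=64, hidden_size=128)
#     x = torch.randn(3, 64, 50)
#     assert layer(x).shape == x.shape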


class BSNet(nn.Module):
    """One band-split processing block: sequence modeling within each band,
    communication across bands, and TAC communication across audio
    channels."""
    def __init__(self, in_channel, nband=7):
        super(BSNet, self).__init__()

        self.nband = nband
        self.feature_dim = in_channel // nband

        self.band_rnn = ResMamba(self.feature_dim, self.feature_dim*2)
        self.band_comm = ResMamba(self.feature_dim, self.feature_dim*2)
        self.channel_comm = TAC(self.feature_dim, self.feature_dim*3)

    def forward(self, input):
        # input: (B, nch, nband*feature_dim, T)
        B, nch, N, T = input.shape

        # Sequence modeling within each band, along the time axis
        band_output = self.band_rnn(input.view(B*nch*self.nband, self.feature_dim, -1)).view(B*nch, self.nband, -1, T)

        # Communication across bands, along the band axis
        band_output = band_output.permute(0,3,2,1).contiguous().view(B*nch*T, -1, self.nband)
        output = self.band_comm(band_output).view(B*nch, T, -1, self.nband).permute(0,3,2,1).contiguous()

        # Communication across audio channels via TAC
        output = output.view(B, nch, self.nband, -1, T).transpose(1,2).contiguous().view(B*self.nband, nch, -1, T)
        output = self.channel_comm(output).view(B, self.nband, nch, -1, T).transpose(1,2).contiguous()

        return output.view(B, nch, N, T)
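
# The first two sub-modules realize the usual dual-path pattern: the same
# ResMamba architecture is applied once along time (band_rnn) and once
# along the band axis (band_comm). Shape sketch (illustration only;
# nband=57 matches the default band layout of Separator below):
#
#     net = BSNet(in_channel=57*128, nband=57)
#     x = torch.randn(1, 2, 57*128, 20)   # (B, nch, nband*feature_dim, T)
#     assert net(x).shape == x.shape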


class Separator(nn.Module):
    def __init__(self, sr=44100, win=2048, stride=512, feature_dim=128, num_repeat_mask=8, num_repeat_map=4, num_output=4):
        super(Separator, self).__init__()

        self.sr = sr
        self.win = win
        self.stride = stride
        self.group = self.win // 2
        self.enc_dim = self.win // 2 + 1
        self.feature_dim = feature_dim
        self.num_output = num_output
        self.eps = torch.finfo(torch.float32).eps

        # Band-split layout: narrow bands at low frequencies, progressively
        # wider bands above; each bandwidth in Hz is converted to FFT bins.
        bandwidth_50 = int(np.floor(50 / (sr / 2.) * self.enc_dim))
        bandwidth_100 = int(np.floor(100 / (sr / 2.) * self.enc_dim))
        bandwidth_250 = int(np.floor(250 / (sr / 2.) * self.enc_dim))
        bandwidth_500 = int(np.floor(500 / (sr / 2.) * self.enc_dim))
        bandwidth_1k = int(np.floor(1000 / (sr / 2.) * self.enc_dim))
        bandwidth_2k = int(np.floor(2000 / (sr / 2.) * self.enc_dim))
        self.band_width = [bandwidth_50]*20
        self.band_width += [bandwidth_100]*10
        self.band_width += [bandwidth_250]*8
        self.band_width += [bandwidth_500]*8
        self.band_width += [bandwidth_1k]*8
        self.band_width += [bandwidth_2k]*2
        # Whatever remains of the spectrum becomes one final band.
        self.band_width.append(self.enc_dim - np.sum(self.band_width))
        self.nband = len(self.band_width)
        print(self.band_width)
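
        # Worked example with the defaults (sr=44100, win=2048, enc_dim=1025):
        # 50 Hz maps to floor(50/22050*1025) = 2 bins, 100 Hz to 4, 250 Hz
        # to 11, 500 Hz to 23, 1 kHz to 46 and 2 kHz to 92 bins, giving
        # [2]*20 + [4]*10 + [11]*8 + [23]*8 + [46]*8 + [92]*2 plus one
        # remainder band of 1025 - 904 = 121 bins, i.e. nband = 57.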

        # 1x1-conv bottlenecks mapping each band's stacked real/imaginary
        # spectrogram (2 * band_width channels) to the shared feature
        # dimension; one branch for the masking path, one for the mapping path.
        self.BN_mask = nn.ModuleList([])
        for i in range(self.nband):
            self.BN_mask.append(nn.Sequential(nn.GroupNorm(1, self.band_width[i]*2, self.eps),
                                              nn.Conv1d(self.band_width[i]*2, self.feature_dim, 1)
                                              )
                                )

        self.BN_map = nn.ModuleList([])
        for i in range(self.nband):
            self.BN_map.append(nn.Sequential(nn.GroupNorm(1, self.band_width[i]*2, self.eps),
                                             nn.Conv1d(self.band_width[i]*2, self.feature_dim, 1)
                                             )
                               )

        # Stacked BSNet blocks for the masking and mapping stages.
        self.separator_mask = []
        for i in range(num_repeat_mask):
            self.separator_mask.append(BSNet(self.nband*self.feature_dim, self.nband))
        self.separator_mask = nn.Sequential(*self.separator_mask)

        self.separator_map = []
        for i in range(num_repeat_map):
            self.separator_map.append(BSNet(self.nband*self.feature_dim, self.nband))
        self.separator_map = nn.Sequential(*self.separator_map)

        # Fuses mapping-path features with the masking-stage output.
        self.in_conv = nn.Conv1d(self.feature_dim*2, self.feature_dim, 1)
        self.Tanh = nn.Tanh()
        # Per-band estimation heads for the mask and the residual map.
        self.mask = nn.ModuleList([])
        self.map = nn.ModuleList([])
        for i in range(self.nband):
            self.mask.append(nn.Sequential(nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
                                           nn.Conv1d(self.feature_dim, self.feature_dim*self.num_output, 1),
                                           nn.Tanh(),
                                           nn.Conv1d(self.feature_dim*self.num_output, self.feature_dim*self.num_output, 1, groups=self.num_output),
                                           nn.Tanh(),
                                           nn.Conv1d(self.feature_dim*self.num_output, self.band_width[i]*4*self.num_output, 1, groups=self.num_output)
                                           )
                             )
            self.map.append(nn.Sequential(nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
                                          nn.Conv1d(self.feature_dim, self.feature_dim*self.num_output, 1),
                                          nn.Tanh(),
                                          nn.Conv1d(self.feature_dim*self.num_output, self.feature_dim*self.num_output, 1, groups=self.num_output),
                                          nn.Tanh(),
                                          nn.Conv1d(self.feature_dim*self.num_output, self.band_width[i]*4*self.num_output, 1, groups=self.num_output)
                                          )
                            )
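
        # Each head's output is reshaped in forward() to
        # (B*nch, 2, 2, num_output, band_width, T): the first size-2 axis
        # separates value vs. gate for multiplicative gating, the second
        # separates real vs. imaginary parts, hence the factor of 4 in the
        # final Conv1d's output channels.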

    def pad_input(self, input, window, stride):
        """
        Zero-pad the input so that a whole number of windows fits, with
        `stride` extra zeros on both ends. Note: this helper is currently
        not called by forward(), which relies on torch.stft's own padding.
        """
        batch_size, nsample = input.shape

        # Extra samples needed so that (nsample + rest) fills complete
        # windows after the stride offset.
        rest = window - (stride + nsample % window) % window
        if rest > 0:
            pad = torch.zeros(batch_size, rest).type(input.type())
            input = torch.cat([input, pad], 1)
        pad_aux = torch.zeros(batch_size, stride).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 1)

        return input, rest
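
    # Example (illustration only): window=2048, stride=512, nsample=3000
    # gives rest = 2048 - (512 + 3000 % 2048) % 2048 = 584, so the padded
    # signal has length 512 + 3000 + 584 + 512 = 4608 samples.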

    def forward(self, input):
        # input: (batch, nch, nsample) waveform
        batch_size, nch, nsample = input.shape
        input = input.view(batch_size*nch, -1)

        # STFT frontend
        spec = torch.stft(input, n_fft=self.win, hop_length=self.stride,
                          window=torch.hann_window(self.win).to(input.device).type(input.type()),
                          return_complex=True)

        # Split the spectrogram into the predefined subbands; keep a
        # real/imaginary-stacked copy for the feature extractors and the
        # complex copy for applying the estimated masks.
        spec_RI = torch.stack([spec.real, spec.imag], 1)
        subband_spec_RI = []
        subband_spec = []
        band_idx = 0
        for i in range(len(self.band_width)):
            subband_spec_RI.append(spec_RI[:,:,band_idx:band_idx+self.band_width[i]].contiguous())
            subband_spec.append(spec[:,band_idx:band_idx+self.band_width[i]])
            band_idx += self.band_width[i]

        # Per-band bottleneck features for the masking and mapping paths
        subband_feature_mask = []
        for i in range(len(self.band_width)):
            subband_feature_mask.append(self.BN_mask[i](subband_spec_RI[i].view(batch_size*nch, self.band_width[i]*2, -1)))
        subband_feature_mask = torch.stack(subband_feature_mask, 1)

        subband_feature_map = []
        for i in range(len(self.band_width)):
            subband_feature_map.append(self.BN_map[i](subband_spec_RI[i].view(batch_size*nch, self.band_width[i]*2, -1)))
        subband_feature_map = torch.stack(subband_feature_map, 1)

        # Masking stage, gradient-checkpointed in 2 segments to save memory
        sep_output = checkpoint_sequential(self.separator_mask, 2, subband_feature_mask.view(batch_size, nch, self.nband*self.feature_dim, -1))
        sep_output = sep_output.view(batch_size*nch, self.nband, self.feature_dim, -1)
        # Mapping stage, conditioned on the masking-stage output
        combined = torch.cat((subband_feature_map, sep_output), dim=2)
        combined1 = combined.reshape(batch_size*nch*self.nband, self.feature_dim*2, -1)
        combined2 = self.Tanh(self.in_conv(combined1))
        combined3 = combined2.reshape(batch_size*nch, self.nband, self.feature_dim, -1)
        sep_output2 = checkpoint_sequential(self.separator_map, 2, combined3.view(batch_size, nch, self.nband*self.feature_dim, -1))
        sep_output2 = sep_output2.view(batch_size*nch, self.nband, self.feature_dim, -1)

        sep_subband_spec = []
        sep_subband_spec_mask = []
        for i in range(self.nband):
            # Gated complex mask: value * sigmoid(gate), split into R/I parts
            this_output = self.mask[i](sep_output[:,i]).view(batch_size*nch, 2, 2, self.num_output, self.band_width[i], -1)
            this_mask = this_output[:,0] * torch.sigmoid(this_output[:,1])
            this_mask_real = this_mask[:,0]
            this_mask_imag = this_mask[:,1]

            # Normalize the masks so that, across the num_output sources,
            # the real parts sum to 1 and the imaginary parts sum to 0;
            # the separated sources then add back up to the mixture.
            this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1)
            this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1)
            this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output
            this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output
            # Complex multiplication of the subband spectrogram by the mask
            est_spec_real = subband_spec[i].real.unsqueeze(1) * this_mask_real - subband_spec[i].imag.unsqueeze(1) * this_mask_imag
            est_spec_imag = subband_spec[i].real.unsqueeze(1) * this_mask_imag + subband_spec[i].imag.unsqueeze(1) * this_mask_real

            # The mapping path predicts an additive residual refinement
            this_output2 = self.map[i](sep_output2[:,i]).view(batch_size*nch, 2, 2, self.num_output, self.band_width[i], -1)
            this_map = this_output2[:,0] * torch.sigmoid(this_output2[:,1])
            this_map_real = this_map[:,0]
            this_map_imag = this_map[:,1]
            est_spec_real2 = est_spec_real + this_map_real
            est_spec_imag2 = est_spec_imag + this_map_imag

            sep_subband_spec.append(torch.complex(est_spec_real2, est_spec_imag2))
            sep_subband_spec_mask.append(torch.complex(est_spec_real, est_spec_imag))

        sep_subband_spec = torch.cat(sep_subband_spec, 2)
        est_spec_mask = torch.cat(sep_subband_spec_mask, 2)

        # Inverse STFT back to waveforms for the refined (mask + map)
        # estimate and for the mask-only intermediate estimate.
        output = torch.istft(sep_subband_spec.view(batch_size*nch*self.num_output, self.enc_dim, -1),
                             n_fft=self.win, hop_length=self.stride,
                             window=torch.hann_window(self.win).to(input.device).type(input.type()), length=nsample)
        output_mask = torch.istft(est_spec_mask.view(batch_size*nch*self.num_output, self.enc_dim, -1),
                                  n_fft=self.win, hop_length=self.stride,
                                  window=torch.hann_window(self.win).to(input.device).type(input.type()), length=nsample)

        # (B, num_output, nch, nsample); the mask-only estimate is computed
        # but only the refined estimate is returned.
        output = output.view(batch_size, nch, self.num_output, -1).transpose(1,2).contiguous()
        output_mask = output_mask.view(batch_size, nch, self.num_output, -1).transpose(1,2).contiguous()

        return output


if __name__ == '__main__':
    model = Separator().cuda()
    # 3 seconds of stereo silence at 44.1 kHz as a smoke-test input
    arr = np.zeros((1, 2, 3*44100), dtype=np.float32)
    x = torch.from_numpy(arr).cuda()
    res = model(x)
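    # With the defaults (num_output=4), res should have shape
    # (1, 4, 2, 3*44100): (batch, sources, channels, samples).
    print(res.shape)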