import copy
import logging

import torch
import torch.nn.functional as F


class HiFiGANPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN period discriminator module."""

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        period=3,
        kernel_sizes=[5, 3],
        channels=32,
        downsample_scales=[3, 3, 3, 3, 1],
        channel_increasing_factor=4,
        max_downsample_channels=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
        use_weight_norm=True,
    ):
        """Initialize HiFiGANPeriodDiscriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            period (int): Period.
            kernel_sizes (list): Kernel sizes of initial conv layers and the final conv layer.
            channels (int): Number of initial channels.
            downsample_scales (list): List of downsampling scales.
            channel_increasing_factor (int): Factor by which the channels grow after each downsampling layer.
            max_downsample_channels (int): Maximum number of downsampling channels.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
        assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."

        self.period = period
        self.convs = torch.nn.ModuleList()
        in_chs = in_channels
        out_chs = channels
        for downsample_scale in downsample_scales:
            self.convs += [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_chs,
                        out_chs,
                        (kernel_sizes[0], 1),
                        (downsample_scale, 1),
                        padding=((kernel_sizes[0] - 1) // 2, 0),
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            out_chs = min(out_chs * channel_increasing_factor, max_downsample_channels)
        # An odd kernel with (k - 1) // 2 padding keeps the folded time-axis
        # length unchanged, consistent with the odd-kernel assertion above.
        self.output_conv = torch.nn.Conv2d(
            in_chs,
            out_channels,
            (kernel_sizes[1], 1),
            1,
            padding=((kernel_sizes[1] - 1) // 2, 0),
        )

        if use_weight_norm:
            self.apply_weight_norm()
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            list: List of each layer's output tensors.

        """
        # Pad so that T is divisible by the period, then fold the time axis
        # into a 2D map of shape (T // period, period).
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        # Collect intermediate feature maps, e.g. for feature-matching losses.
        outs = []
        for layer in self.convs:
            x = layer(x)
            outs += [x]
        x = self.output_conv(x)
        x = torch.flatten(x, 1, -1)
        outs += [x]

        return outs

    def apply_weight_norm(self):
        """Apply weight normalization to all of the Conv2d layers."""

        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)
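

# A minimal smoke-test sketch, not part of the original module: the helper
# name, batch size, and signal length here are illustrative assumptions.
# It checks that the discriminator returns one feature map per conv block
# plus the flattened final output.
def _demo_period_discriminator():
    d = HiFiGANPeriodDiscriminator(period=3)
    x = torch.randn(2, 1, 16000)  # dummy waveform (B, in_channels, T)
    outs = d(x)
    assert len(outs) == len(d.convs) + 1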


class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN multi-period discriminator module."""

    def __init__(
        self,
        periods=[2, 3, 5, 7, 11],
        **kwargs,
    ):
        """Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (list): List of periods.
            **kwargs: Keyword arguments for each HiFiGANPeriodDiscriminator;
                the period argument is overwritten per discriminator.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()
        for period in periods:
            params = copy.deepcopy(kwargs)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: List of lists of each discriminator's per-layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
        return outs
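

# A minimal usage sketch, not in the original file: the __main__ guard,
# batch size, and signal length are assumptions for illustration. Each
# period discriminator returns its per-layer feature maps, so the result
# is one list per period (defaults: 2, 3, 5, 7, 11).
if __name__ == "__main__":
    mpd = HiFiGANMultiPeriodDiscriminator()
    x = torch.randn(2, 1, 16000)  # dummy waveform (B, 1, T)
    outs = mpd(x)
    print(len(outs), [len(o) for o in outs])  # -> 5 [6, 6, 6, 6, 6]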