# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# flake8: noqa E722
from types import SimpleNamespace

import torch

try:
    from torch.nn import SyncBatchNorm
except ImportError:
    from torch.nn import BatchNorm2d as SyncBatchNorm
from torch import nn
from torch.nn import functional as F
from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock
from .misc import PartialSequential, ApplyNoise


class AdaptiveNorm(nn.Module):
    r"""Adaptive normalization layer. The layer first normalizes the input, then
    performs an affine transformation using parameters computed from the
    conditional inputs.

    Args:
        num_features (int): Number of channels in the input tensor.
        cond_dims (int): Number of channels in the conditional inputs.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``.
        projection (bool): If ``True``, project the conditional input to gamma
            and beta using a fully connected layer, otherwise directly use
            the conditional input as gamma and beta.
        projection_bias (bool): If ``True``, use bias in the fully connected
            projection layer.
        separate_projection (bool): If ``True``, we will use two different
            layers for gamma and beta. Otherwise, we will use one layer. It
            matters only if you apply any weight norms to this layer.
        input_dim (int): Number of dimensions of the input tensor.
        activation_norm_type (str):
            Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
        activation_norm_params (obj, optional, default=None):
            Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used as
            keyword arguments when initializing activation normalization.
            If ``None``, ``affine=False`` is used.
        apply_noise (bool): If ``True``, an ``ApplyNoise`` layer is applied to
            the normalized input before the affine transformation.
        add_bias (bool): If ``True``, ``forward`` returns
            ``x * (1 + gamma) + beta``. Otherwise it returns the scaled
            tensor and ``beta`` as two separate outputs.
        input_scale (float): Factor the conditional input is multiplied by
            before it is converted into gamma and beta.
        init_gain (float): Stored as ``self.init_gain``; it is not used
            anywhere else inside this class.
    """

    def __init__(self, num_features, cond_dims, weight_norm_type='',
                 projection=True,
                 projection_bias=True,
                 separate_projection=False,
                 input_dim=2,
                 activation_norm_type='instance',
                 activation_norm_params=None,
                 apply_noise=False,
                 add_bias=True,
                 input_scale=1.0,
                 init_gain=1.0):
        super().__init__()
        if activation_norm_params is None:
            activation_norm_params = SimpleNamespace(affine=False)
        self.norm = get_activation_norm_layer(num_features,
                                              activation_norm_type,
                                              input_dim,
                                              **vars(activation_norm_params))
        self.noise_layer = ApplyNoise() if apply_noise else None

        if projection:
            if separate_projection:
                # Independent layers so that weight norms act on gamma and
                # beta separately.
                self.fc_gamma = \
                    LinearBlock(cond_dims, num_features,
                                weight_norm_type=weight_norm_type,
                                bias=projection_bias)
                self.fc_beta = \
                    LinearBlock(cond_dims, num_features,
                                weight_norm_type=weight_norm_type,
                                bias=projection_bias)
            else:
                # One layer predicts gamma and beta stacked along channels.
                self.fc = LinearBlock(cond_dims, num_features * 2,
                                      weight_norm_type=weight_norm_type,
                                      bias=projection_bias)

        self.projection = projection
        self.separate_projection = separate_projection
        self.input_scale = input_scale
        self.add_bias = add_bias
        self.conditional = True
        self.init_gain = init_gain

    @staticmethod
    def _append_singleton_dims(tensor, target_ndim):
        r"""Append trailing size-1 dims until ``tensor`` has ``target_ndim``
        dimensions. Tensors that already have enough dims are returned as is.
        """
        for _ in range(target_ndim - tensor.dim()):
            tensor = tensor.unsqueeze(-1)
        return tensor

    def forward(self, x, y, noise=None, **_kwargs):
        r"""Adaptive Normalization forward.

        Args:
            x (N x C1 x * tensor): Input tensor.
            y (N x C2 tensor): Conditional information.
            noise (optional): Forwarded to the noise layer when
                ``apply_noise`` was enabled.
        Returns:
            out (N x C1 x * tensor): Output tensor when ``add_bias`` is
            ``True``. Otherwise a tuple ``(x * (1 + gamma), beta)`` where the
            trailing singleton dimensions of ``beta`` have been squeezed.
        """
        y = y * self.input_scale
        if self.projection and self.separate_projection:
            gamma = self._append_singleton_dims(self.fc_gamma(y), x.dim())
            beta = self._append_singleton_dims(self.fc_beta(y), x.dim())
        else:
            if self.projection:
                y = self.fc(y)
            # Without projection the conditional input itself already holds
            # gamma and beta stacked along the channel dimension.
            y = self._append_singleton_dims(y, x.dim())
            gamma, beta = y.chunk(2, 1)
        if self.norm is not None:
            x = self.norm(x)
        if self.noise_layer is not None:
            x = self.noise_layer(x, noise=noise)
        if self.add_bias:
            return torch.addcmul(beta, x, 1 + gamma)
        # Squeeze every dim after the channel dim, last one first. For 4D
        # tensors this is exactly ``squeeze(3).squeeze(2)``, but it no longer
        # raises for inputs with fewer (or more) spatial dimensions.
        for dim in range(beta.dim() - 1, 1, -1):
            beta = beta.squeeze(dim)
        return x * (1 + gamma), beta


class SpatiallyAdaptiveNorm(nn.Module):
    r"""Spatially Adaptive Normalization (SPADE) layer. The input is normalized
    and then modulated by per-pixel gamma and beta maps predicted from each
    conditional input.

    Args:
        num_features (int) : Number of channels in the input tensor.
        cond_dims (int or list of int) : List of numbers of channels
            in the input.
        num_filters (int or list of int): Number of filters in SPADE. A value
            of ``0`` skips the hidden convolution for that conditional input.
        kernel_size (int): Kernel size of the convolutional filters in
            the SPADE layer.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
        separate_projection (bool): If ``True``, we will use two different
            layers for gamma and beta. Otherwise, we will use one layer. It
            matters only if you apply any weight norms to this layer.
        activation_norm_type (str):
            Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
            ``'layer'``,  ``'layer_2d'``, ``'group'``.
        activation_norm_params (obj, optional, default=None):
            Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used as
            keyword arguments when initializing activation normalization.
            If ``None``, ``affine=False`` is used.
        bias_only (bool): If ``True``, only beta is added to the normalized
            input and gamma is ignored.
        partial (bool or list of bool): Per conditional input, whether to
            build its layers from ``PartialConv2dBlock`` /
            ``PartialSequential`` instead of ``Conv2dBlock`` /
            ``nn.Sequential``. Not supported with ``separate_projection``.
        interpolation (str): Interpolation mode used to resize every
            conditional map to the spatial size of the input.
    """

    def __init__(self,
                 num_features,
                 cond_dims,
                 num_filters=128,
                 kernel_size=3,
                 weight_norm_type='',
                 separate_projection=False,
                 activation_norm_type='sync_batch',
                 activation_norm_params=None,
                 bias_only=False,
                 partial=False,
                 interpolation='nearest'):
        super().__init__()
        if activation_norm_params is None:
            activation_norm_params = SimpleNamespace(affine=False)
        padding = kernel_size // 2
        self.separate_projection = separate_projection
        self.mlps = nn.ModuleList()
        self.gammas = nn.ModuleList()
        self.betas = nn.ModuleList()
        self.bias_only = bias_only
        self.interpolation = interpolation

        # Make cond_dims a list.
        if not isinstance(cond_dims, list):
            cond_dims = [cond_dims]

        # Make num_filters a list.
        if not isinstance(num_filters, list):
            num_filters = [num_filters] * len(cond_dims)
        else:
            assert len(num_filters) >= len(cond_dims)

        # Make partial a list.
        if not isinstance(partial, list):
            partial = [partial] * len(cond_dims)
        else:
            assert len(partial) >= len(cond_dims)

        for i, cond_dim in enumerate(cond_dims):
            mlp = []
            conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock
            sequential = PartialSequential if partial[i] else nn.Sequential

            # Optional hidden convolution shared by gamma and beta.
            if num_filters[i] > 0:
                mlp += [conv_block(cond_dim,
                                   num_filters[i],
                                   kernel_size,
                                   padding=padding,
                                   weight_norm_type=weight_norm_type,
                                   nonlinearity='relu')]
            # Without a hidden layer the projection reads the raw condition.
            mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i]

            if self.separate_projection:
                if partial[i]:
                    raise NotImplementedError(
                        'Separate projection not yet implemented for ' +
                        'partial conv')
                self.mlps.append(nn.Sequential(*mlp))
                self.gammas.append(
                    conv_block(mlp_ch, num_features,
                               kernel_size,
                               padding=padding,
                               weight_norm_type=weight_norm_type))
                self.betas.append(
                    conv_block(mlp_ch, num_features,
                               kernel_size,
                               padding=padding,
                               weight_norm_type=weight_norm_type))
            else:
                # A single convolution outputs gamma and beta stacked along
                # the channel dimension.
                mlp += [conv_block(mlp_ch, num_features * 2, kernel_size,
                                   padding=padding,
                                   weight_norm_type=weight_norm_type)]
                self.mlps.append(sequential(*mlp))

        self.norm = get_activation_norm_layer(num_features,
                                              activation_norm_type,
                                              2,
                                              **vars(activation_norm_params))
        self.conditional = True

    def forward(self, x, *cond_inputs, **_kwargs):
        r"""Spatially Adaptive Normalization (SPADE) forward.

        Args:
            x (N x C1 x H x W tensor) : Input tensor.
            cond_inputs (list of tensors) : Conditional maps for SPADE.
                Entries that are ``None`` are skipped.
        Returns:
            output (4D tensor) : Output tensor.
        """
        output = self.norm(x) if self.norm is not None else x
        spatial_size = x.size()[2:]
        for i, cond_input in enumerate(cond_inputs):
            if cond_input is None:
                continue
            label_map = F.interpolate(cond_input, size=spatial_size,
                                      mode=self.interpolation)
            if self.separate_projection:
                hidden = self.mlps[i](label_map)
                gamma = self.gammas[i](hidden)
                beta = self.betas[i](hidden)
            else:
                affine_params = self.mlps[i](label_map)
                gamma, beta = affine_params.chunk(2, dim=1)
            if self.bias_only:
                output = output + beta
            else:
                output = output * (1 + gamma) + beta
        return output


class DualAdaptiveNorm(nn.Module):
    r"""Adaptive normalization driven by one or more conditional inputs, each
    of which may be either a vector or a spatial map. The input is normalized
    once and then modulated in sequence by the gamma and beta predicted from
    every conditional input.

    Args:
        num_features (int): Number of channels in the input tensor.
        cond_dims (int or list of int): Number of channels of each
            conditional input.
        projection_bias (bool): Passed as ``bias`` to every projection layer.
        weight_norm_type (str): Type of weight normalization passed to every
            projection layer.
        activation_norm_type (str): Type of activation normalization applied
            to the input before modulation.
        activation_norm_params (obj, optional, default=None):
            Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used as
            keyword arguments when initializing activation normalization.
            If ``None``, ``affine=False`` is used.
        apply_noise (bool): Accepted but not used by this class.
        bias_only (bool): If ``True``, only beta is added to the normalized
            input and gamma is ignored.
        init_gain (float): Passed as ``init_gain`` to every projection layer.
        fc_scale (optional): Passed as ``output_scale`` to every projection
            layer.
        is_spatial (list of bool, optional, default=None): Per conditional
            input, whether gamma and beta are predicted by ``Conv2dBlock``
            (``True``) or ``LinearBlock`` (``False``). ``None`` means
            ``False`` for every conditional input.
    """

    def __init__(self,
                 num_features,
                 cond_dims,
                 projection_bias=True,
                 weight_norm_type='',
                 activation_norm_type='instance',
                 activation_norm_params=None,
                 apply_noise=False,
                 bias_only=False,
                 init_gain=1.0,
                 fc_scale=None,
                 is_spatial=None):
        super().__init__()
        if activation_norm_params is None:
            activation_norm_params = SimpleNamespace(affine=False)
        self.mlps = nn.ModuleList()
        self.gammas = nn.ModuleList()
        self.betas = nn.ModuleList()
        self.bias_only = bias_only

        # Make cond_dims a list.
        if not isinstance(cond_dims, list):
            cond_dims = [cond_dims]

        if is_spatial is None:
            is_spatial = [False] * len(cond_dims)
        self.is_spatial = is_spatial

        for cond_dim, this_is_spatial in zip(cond_dims, is_spatial):
            kwargs = dict(weight_norm_type=weight_norm_type,
                          bias=projection_bias,
                          init_gain=init_gain,
                          output_scale=fc_scale)
            if this_is_spatial:
                # NOTE(review): the positional ``1, 1, 0`` are presumably
                # kernel_size, stride and padding (a 1x1 convolution) -
                # confirm against the Conv2dBlock signature.
                self.gammas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
                self.betas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
            else:
                self.gammas.append(LinearBlock(cond_dim, num_features, **kwargs))
                self.betas.append(LinearBlock(cond_dim, num_features, **kwargs))

        self.norm = get_activation_norm_layer(num_features,
                                              activation_norm_type,
                                              2,
                                              **vars(activation_norm_params))
        self.conditional = True

    def forward(self, x, *cond_inputs, **_kwargs):
        r"""Dual adaptive normalization forward.

        Args:
            x (N x C x H x W tensor): Input tensor.
            cond_inputs (list of tensors): One entry per configured
                conditional input; entries that are ``None`` are skipped.
                2D entries are broadcast over the spatial dimensions, 4D
                entries are bilinearly resized when their projection does
                not match the shape of ``x``.
        Returns:
            output (4D tensor): Output tensor.
        """
        assert len(cond_inputs) == len(self.gammas)
        output = self.norm(x) if self.norm is not None else x
        spatial_size = x.size()[2:]
        for cond, gamma_layer, beta_layer in zip(cond_inputs, self.gammas, self.betas):
            if cond is None:
                continue
            gamma = gamma_layer(cond)
            beta = beta_layer(cond)
            if cond.dim() == 4 and gamma.shape != x.shape:
                gamma = F.interpolate(gamma, size=spatial_size, mode='bilinear')
                beta = F.interpolate(beta, size=spatial_size, mode='bilinear')
            elif cond.dim() == 2:
                # Vector conditions: add singleton spatial dims to broadcast.
                gamma = gamma[:, :, None, None]
                beta = beta[:, :, None, None]
            if self.bias_only:
                output = output + beta
            else:
                output = output * (1 + gamma) + beta
        return output


class HyperSpatiallyAdaptiveNorm(nn.Module):
    r"""Spatially Adaptive Normalization (SPADE) layer whose first conditional
    branch can use convolution weights supplied at forward time.

    Args:
        num_features (int) : Number of channels in the input tensor.
        cond_dims (int or list of int) : List of numbers of channels
            in the conditional input.
        num_filters (int): Number of filters in SPADE. ``0`` skips the hidden
            convolution. Must be ``0`` for the hyper branch.
        kernel_size (int): Kernel size of the convolutional filters in
            the SPADE layer.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
        activation_norm_type (str):
            Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
            ``'layer'``,  ``'layer_2d'``, ``'group'``.
        is_hyper (bool): Whether to use hyper SPADE. If ``True``, the first
            conditional input is processed by a ``HyperConv2d`` instead of
            regular convolution blocks.
    """

    def __init__(self, num_features, cond_dims,
                 num_filters=0, kernel_size=3,
                 weight_norm_type='',
                 activation_norm_type='sync_batch', is_hyper=True):
        super().__init__()
        padding = kernel_size // 2
        self.mlps = nn.ModuleList()
        if not isinstance(cond_dims, list):
            cond_dims = [cond_dims]

        for i, cond_dim in enumerate(cond_dims):
            mlp = []
            if not is_hyper or (i != 0):
                # Regular SPADE branch with its own learned convolutions.
                if num_filters > 0:
                    mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size,
                                        padding=padding,
                                        weight_norm_type=weight_norm_type,
                                        nonlinearity='relu')]
                mlp_ch = cond_dim if num_filters == 0 else num_filters
                mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size,
                                    padding=padding,
                                    weight_norm_type=weight_norm_type)]
                mlp = nn.Sequential(*mlp)
            else:
                # Hyper branch: only the first conditional input uses it.
                if num_filters > 0:
                    raise ValueError('Multi hyper layer not supported yet.')
                mlp = HyperConv2d(padding=padding)
            self.mlps.append(mlp)

        self.norm = get_activation_norm_layer(num_features,
                                              activation_norm_type,
                                              2,
                                              affine=False)

        self.conditional = True

    def forward(self, x, *cond_inputs,
                norm_weights=(None, None), **_kwargs):
        r"""Spatially Adaptive Normalization (SPADE) forward.

        Args:
            x (4D tensor) : Input tensor.
            cond_inputs (list of tensors) : Conditional maps for SPADE. An
                entry may also be a ``[cond_input, mask]`` pair; gamma and
                beta are then multiplied by ``1 - mask``. Entries that are
                ``None`` are skipped.
            norm_weights (5D tensor or list of tensors): conv weights or
                [weights, biases]. Passed as ``conv_weights`` to the layer of
                the first conditional input when its first element is not
                ``None``.
        Returns:
            output (4D tensor) : Output tensor.
        """
        # ``'none'`` is a valid activation norm type, in which case no
        # normalization layer exists and the input is used as is.
        output = self.norm(x) if self.norm is not None else x
        spatial_size = x.size()[2:]
        for i, cond in enumerate(cond_inputs):
            if cond is None:
                continue
            if isinstance(cond, (list, tuple)):
                cond_input, mask = cond
                mask = F.interpolate(mask, size=spatial_size, mode='bilinear', align_corners=False)
            else:
                cond_input = cond
                mask = None
            label_map = F.interpolate(cond_input, size=spatial_size)
            if norm_weights is None or norm_weights[0] is None or i != 0:
                affine_params = self.mlps[i](label_map)
            else:
                affine_params = self.mlps[i](label_map,
                                             conv_weights=norm_weights)
            gamma, beta = affine_params.chunk(2, dim=1)
            if mask is not None:
                gamma = gamma * (1 - mask)
                beta = beta * (1 - mask)
            output = output * (1 + gamma) + beta
        return output


class LayerNorm2d(nn.Module):
    r"""Layer Normalization as introduced in
    https://arxiv.org/abs/1607.06450.
    This is the usual way to apply layer normalization in CNNs.
    Note that unlike the pytorch implementation which applies per-element
    scale and bias, here it applies per-channel scale and bias, similar to
    batch/instance normalization.

    Args:
        num_features (int): Number of channels in the input tensor.
        eps (float, optional, default=1e-5): a value added to the
            denominator for numerical stability.
        channel_only (bool, optional, default=False): If ``True``, the mean
            and standard deviation are computed over the channel dimension
            only. Otherwise they are computed over all non-batch dimensions.
        affine (bool, optional, default=True): If ``True``, performs
            affine transformation after normalization.
    """

    def __init__(self, num_features, eps=1e-5, channel_only=False, affine=True):
        super(LayerNorm2d, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.channel_only = channel_only

        if self.affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        r"""Normalize ``x`` and optionally apply the per-channel affine
        transformation.

        Args:
            x (tensor): Input tensor.
        Returns:
            (tensor): Normalized tensor with the same shape as ``x``.
        """
        if self.channel_only:
            mean = x.mean(1, keepdim=True)
            std = x.std(1, keepdim=True)
        else:
            # One statistic per sample, broadcastable against x.
            stat_shape = [-1] + [1] * (x.dim() - 1)
            # reshape (not view) so that non-contiguous inputs, e.g. the
            # result of a permute, are handled as well.
            flat = x.reshape(x.size(0), -1)
            mean = flat.mean(1).view(*stat_shape)
            std = flat.std(1).view(*stat_shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            # Per-channel scale and bias, broadcast over all other dims.
            affine_shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*affine_shape) + self.beta.view(*affine_shape)
        return x


class ScaleNorm(nn.Module):
    r"""Scale normalization from
    "Transformers without Tears: Improving the Normalization of Self-Attention".
    Modified from:
    https://github.com/tnq177/transformers_without_tears

    The input is multiplied by ``scale / sqrt(mean(x ** 2) + eps)`` where the
    mean is taken along ``dim``.

    Args:
        dim (int, optional, default=-1): Dimension along which the mean of
            squares is computed.
        learned_scale (bool, optional, default=True): If ``True``, ``scale``
            is a learnable scalar parameter initialized to 1. Otherwise it is
            the constant 1.
        eps (float, optional, default=1e-5): Value added to the mean of
            squares for numerical stability.
    """
    # NOTE(review): a class with the same name and body is defined again
    # further down in this module and rebinds the module-level name
    # ``ScaleNorm``; ``PixelNorm`` subclasses this first definition.

    def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(1.)) if learned_scale else 1.
        self.dim = dim
        self.eps = eps
        self.learned_scale = learned_scale

    def forward(self, x):
        r"""Rescale ``x`` by its inverse root-mean-square along ``self.dim``.

        Args:
            x (tensor): Input tensor.
        Returns:
            (tensor): Rescaled tensor with the same shape as ``x``.
        """
        mean_square = torch.mean(x ** 2, dim=self.dim, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        gain = self.scale * inv_rms
        return x * gain

    def extra_repr(self):
        r"""Return the extra information shown in the module's repr."""
        template = 'learned_scale={learned_scale}'
        return template.format(learned_scale=self.learned_scale)


class PixelNorm(ScaleNorm):
    r"""``ScaleNorm`` applied along dimension 1.

    Args:
        learned_scale (bool, optional, default=False): Forwarded to
            ``ScaleNorm``.
        eps (float, optional, default=1e-5): Forwarded to ``ScaleNorm``.
        **_kwargs: Accepted and ignored.
    """

    def __init__(self, learned_scale=False, eps=1e-5, **_kwargs):
        super().__init__(dim=1, learned_scale=learned_scale, eps=eps)


class SplitMeanStd(nn.Module):
    r"""Return the input unchanged together with its per-channel spatial mean
    and standard deviation.

    Args:
        num_features (int): Stored as ``self.num_features``; not used in the
            computation.
        eps (float, optional, default=1e-5): Value added to the variance
            before taking the square root.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.multiple_outputs = True

    def forward(self, x):
        r"""Compute the per-sample, per-channel statistics of ``x``.

        Args:
            x (N x C x H x W tensor): Input tensor.
        Returns:
            (tuple):
              - x (N x C x H x W tensor): The input itself, not normalized.
              - stats (N x 2C x 1 x 1 tensor): Means followed by standard
                deviations, concatenated along the channel dimension.
        """
        b, c, _, _ = x.size()
        # reshape (not view) so that non-contiguous inputs are supported.
        flat = x.reshape(b, c, -1)
        mean = flat.mean(-1)[:, :, None, None]
        var = flat.var(-1)[:, :, None, None]
        std = torch.sqrt(var + self.eps)

        # The statistics are only reported; x is deliberately returned as is.
        return x, torch.cat((mean, std), dim=1)


class ScaleNorm(nn.Module):
    r"""Scale normalization:
    "Transformers without Tears: Improving the Normalization of Self-Attention"
    Modified from:
    https://github.com/tnq177/transformers_without_tears

    Computes ``x * scale * rsqrt(mean(x ** 2, dim) + eps)``.

    Args:
        dim (int, optional, default=-1): Dimension that is reduced when
            computing the mean of squares.
        learned_scale (bool, optional, default=True): Whether ``scale`` is a
            learnable scalar parameter (initialized to 1) or the constant 1.
        eps (float, optional, default=1e-5): Added to the mean of squares
            before the reciprocal square root.
    """
    # NOTE(review): this redefines the ``ScaleNorm`` class declared earlier in
    # this module with the same behavior; from here on the module-level name
    # refers to this definition.

    def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
        super().__init__()
        self.learned_scale = learned_scale
        self.eps = eps
        self.dim = dim
        if not learned_scale:
            self.scale = 1.
        else:
            self.scale = nn.Parameter(torch.tensor(1.))

    def forward(self, x):
        r"""Scale ``x`` by ``self.scale`` over its root-mean-square.

        Args:
            x (tensor): Input tensor.
        Returns:
            (tensor): Tensor with the same shape as ``x``.
        """
        stabilized = torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps
        return x * (self.scale * torch.rsqrt(stabilized))

    def extra_repr(self):
        r"""Return the extra information shown in the module's repr."""
        return 'learned_scale={learned_scale}'.format(
            learned_scale=self.learned_scale)


class PixelLayerNorm(nn.Module):
    r"""Apply ``nn.LayerNorm`` to the channel vector of every pixel of a 4D
    tensor. Inputs that are not 4D are passed to ``nn.LayerNorm`` directly.

    Args:
        *args: Positional arguments forwarded to ``nn.LayerNorm``.
        **kwargs: Keyword arguments forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.norm = nn.LayerNorm(*args, **kwargs)

    def forward(self, x):
        r"""
        Args:
            x (tensor): Input tensor. A 4D input is treated as
                N x C x H x W.
        Returns:
            (tensor): Normalized tensor with the same shape as ``x``.
        """
        if x.dim() != 4:
            return self.norm(x)
        b, c, h, w = x.shape
        # Channels-last layout, one row per pixel. The permuted tensor is
        # not contiguous, so ``view`` would raise whenever both the batch
        # and channel sizes exceed 1; ``reshape`` copies when required.
        pixels = x.permute(0, 2, 3, 1).reshape(-1, c)
        normalized = self.norm(pixels)
        return normalized.reshape(b, h, w, c).permute(0, 3, 1, 2)


def get_activation_norm_layer(num_features, norm_type, input_dim, **norm_params):
    r"""Build the activation normalization layer selected by ``norm_type``.

    Args:
        num_features (int): Number of feature channels.
        norm_type (str):
            Type of activation normalization.
            ``'none'`` (or ``''``), ``'batch'``, ``'instance'``,
            ``'sync_batch'``, ``'layer'``, ``'layer_2d'``,
            ``'pixel_layer'``, ``'scale'``, ``'pixel'``, ``'group'``,
            ``'adaptive'``, ``'dual_adaptive'``, ``'spatially_adaptive'``
            or ``'hyper_spatially_adaptive'``.
        input_dim (int): Number of input dimensions.
        norm_params: Arbitrary keyword arguments that will be used to
            initialize the activation normalization.
    Returns:
        The constructed normalization module, or ``None`` when ``norm_type``
        is ``'none'`` or ``''``.
    Raises:
        ValueError: If ``norm_type`` is not recognized, or if a spatially
            adaptive type is requested with ``input_dim != 2``.
    """
    # Norm1d works with both 0d and 1d inputs.
    input_dim = max(input_dim, 1)

    if norm_type in ('none', ''):
        return None

    if norm_type == 'batch':
        batch_cls = getattr(nn, 'BatchNorm%dd' % input_dim)
        return batch_cls(num_features, **norm_params)

    if norm_type == 'instance':
        # Use affine=True unless the caller says otherwise.
        use_affine = norm_params.pop('affine', True)
        instance_cls = getattr(nn, 'InstanceNorm%dd' % input_dim)
        return instance_cls(num_features, affine=use_affine, **norm_params)

    if norm_type == 'sync_batch':
        return SyncBatchNorm(num_features, **norm_params)

    if norm_type == 'layer':
        return nn.LayerNorm(num_features, **norm_params)

    if norm_type == 'layer_2d':
        return LayerNorm2d(num_features, **norm_params)

    if norm_type == 'pixel_layer':
        # ``affine`` maps to nn.LayerNorm's ``elementwise_affine``; default
        # is True.
        use_affine = norm_params.pop('affine', True)
        return PixelLayerNorm(num_features,
                              elementwise_affine=use_affine,
                              **norm_params)

    if norm_type == 'scale':
        return ScaleNorm(**norm_params)

    if norm_type == 'pixel':
        pixel_norm = PixelNorm(**norm_params)
        # Imported here rather than at module level; the layer is scripted
        # only when the project-wide JIT switch is enabled.
        import imaginaire.config
        if imaginaire.config.USE_JIT:
            pixel_norm = torch.jit.script(pixel_norm)
        return pixel_norm

    if norm_type == 'group':
        group_count = norm_params.pop('num_groups', 4)
        return nn.GroupNorm(num_channels=num_features,
                            num_groups=group_count,
                            **norm_params)

    if norm_type == 'adaptive':
        return AdaptiveNorm(num_features, **norm_params)

    if norm_type == 'dual_adaptive':
        return DualAdaptiveNorm(num_features, **norm_params)

    if norm_type in ('spatially_adaptive', 'hyper_spatially_adaptive'):
        if input_dim != 2:
            raise ValueError('Spatially adaptive normalization layers '
                             'only supports 2D input')
        if norm_type == 'spatially_adaptive':
            return SpatiallyAdaptiveNorm(num_features, **norm_params)
        return HyperSpatiallyAdaptiveNorm(num_features, **norm_params)

    raise ValueError('Activation norm layer %s '
                     'is not recognized' % norm_type)