import ast
from collections import OrderedDict

import torch
import torch.nn as nn


class IdentityLayer(nn.Module):
    """Pass-through candidate op; also used as the residual shortcut."""

    def __init__(self):
        super(IdentityLayer, self).__init__()

    def forward(self, x):
        return x

    @staticmethod
    def is_zero_layer():
        return False


class ZeroLayer(nn.Module):
    """Candidate op that outputs zeros, i.e. the searched edge is pruned."""

    def __init__(self, stride):
        super(ZeroLayer, self).__init__()
        self.stride = stride

    def forward(self, x):
        # Apply the same spatial reduction the replaced op would, so every
        # candidate on an edge produces a shape-compatible output.
        n, c, h, w = x.shape
        h //= self.stride[0]
        w //= self.stride[1]
        return torch.zeros(n, c, h, w, device=x.device, dtype=x.dtype)

    @staticmethod
    def is_zero_layer():
        return True

    def get_flops(self, x):
        # Producing zeros costs no multiply-adds.
        return 0, self.forward(x)
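

# Shape sketch for ZeroLayer: with stride (2, 2), an input of shape
# (N, C, H, W) yields zeros of shape (N, C, H // 2, W // 2), matching the
# downsampling of the op it stands in for:
#   ZeroLayer(stride=(2, 2))(torch.randn(1, 8, 16, 16)).shape
#   # torch.Size([1, 8, 8, 8])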


def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, \
            'invalid kernel size: %s' % str(kernel_size)
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size,
                      int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be an odd number'
    return kernel_size // 2
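

# A quick illustration of get_same_padding (values follow from the k // 2
# rule for odd kernels; stride-1 convs then keep their spatial size):
#   get_same_padding(3)       -> 1
#   get_same_padding(5)       -> 2
#   get_same_padding((3, 5))  -> (1, 2)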


class MBInvertedConvLayer(nn.Module):
    """MobileNetV2-style inverted bottleneck: 1x1 expansion -> depthwise
    conv -> 1x1 linear projection (no activation after the projection)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=(1, 1),
                 expand_ratio=6,
                 mid_channels=None):
        super(MBInvertedConvLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels

        # Hidden width of the block: expand in_channels unless an explicit
        # mid_channels is given.
        feature_dim = round(
            self.in_channels *
            self.expand_ratio) if mid_channels is None else mid_channels
        if self.expand_ratio == 1:
            # Nothing to expand; skip the 1x1 bottleneck.
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict([
                    ('conv',
                     nn.Conv2d(self.in_channels,
                               feature_dim,
                               1,
                               1,
                               0,
                               bias=False)),
                    ('bn', nn.BatchNorm2d(feature_dim)),
                    ('act', nn.ReLU6(inplace=True)),
                ]))
        pad = get_same_padding(self.kernel_size)
        self.depth_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim,
                           feature_dim,
                           kernel_size,
                           stride,
                           pad,
                           groups=feature_dim,
                           bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim)),
                ('act', nn.ReLU6(inplace=True)),
            ]))
        self.point_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels)),
            ]))

    def forward(self, x):
        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)  # 1x1 channel expansion
        x = self.depth_conv(x)  # depthwise spatial convolution
        x = self.point_conv(x)  # 1x1 linear projection
        return x

    @staticmethod
    def is_zero_layer():
        return False
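

# Example use of MBInvertedConvLayer (the shapes here are illustrative
# assumptions): expand ratio 6 widens 16 -> 96 channels internally, the
# stride-(2, 2) depthwise conv halves both spatial dims, and the pointwise
# projection emits 24 channels:
#   mb = MBInvertedConvLayer(16, 24, kernel_size=3, stride=(2, 2),
#                            expand_ratio=6)
#   mb(torch.randn(1, 16, 32, 32)).shape  # torch.Size([1, 24, 16, 16])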


def conv_func_by_name(name):
    """Map an op name such as '5x5_MBConv6' to a factory taking
    (in_channels, out_channels, stride)."""
    name2ops = {
        'Identity': lambda in_C, out_C, S: IdentityLayer(),
        'Zero': lambda in_C, out_C, S: ZeroLayer(stride=S),
    }
    # MBConv variants '<k>x<k>_MBConv<e>' for kernel sizes 3/5/7 and expand
    # ratios 1-6; default arguments bind k and e at each iteration.
    for k in (3, 5, 7):
        for e in range(1, 7):
            name2ops['%dx%d_MBConv%d' % (k, k, e)] = (
                lambda in_C, out_C, S, k=k, e=e: MBInvertedConvLayer(
                    in_C, out_C, k, S, e))
    return name2ops[name]
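

# Example (a minimal sketch): look up a factory by name, then call it with
# (in_channels, out_channels, stride):
#   op = conv_func_by_name('3x3_MBConv6')(32, 64, (2, 1))
#   op(torch.randn(1, 32, 16, 64)).shape  # torch.Size([1, 64, 8, 64])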


def build_candidate_ops(candidate_ops, in_channels, out_channels, stride,
                        ops_order):
    if candidate_ops is None:
        raise ValueError('please specify a candidate set')
    # Delegate to the shared name -> factory mapping; `ops_order` is kept
    # for API compatibility but is unused by the supported ops.
    return [
        conv_func_by_name(name)(in_channels, out_channels, stride)
        for name in candidate_ops
    ]
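

# Example: the candidate ops competing on one searched edge. All candidates
# produce shape-compatible outputs, which is what makes mixing them in a NAS
# supernet well-defined ('weight_bn_act' is an assumed placeholder value;
# ops_order is unused here):
#   ops = build_candidate_ops(['3x3_MBConv3', '5x5_MBConv6', 'Zero'],
#                             in_channels=32, out_channels=64,
#                             stride=(2, 1), ops_order='weight_bn_act')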


class MobileInvertedResidualBlock(nn.Module):

    def __init__(self, mobile_inverted_conv, shortcut):
        super(MobileInvertedResidualBlock, self).__init__()
        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut

    def forward(self, x):
        if self.mobile_inverted_conv.is_zero_layer():
            # Zero conv: the whole block degenerates to a skip connection.
            res = x
        elif self.shortcut is None or self.shortcut.is_zero_layer():
            # No usable shortcut (stride > 1 or channel change): plain conv.
            res = self.mobile_inverted_conv(x)
        else:
            # Standard residual: conv branch plus identity shortcut.
            conv_x = self.mobile_inverted_conv(x)
            skip_x = self.shortcut(x)
            res = skip_x + conv_x
        return res


class AutoSTREncoder(nn.Module):
    """Searched AutoSTR backbone for text recognition: a conv stem, fifteen
    fixed MBConv cells in five stages, and an optional two-layer BiLSTM
    over the width axis."""

    def __init__(self,
                 in_channels,
                 out_dim=256,
                 with_lstm=True,
                 stride_stages='[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)]',
                 n_cell_stages=[3, 3, 3, 3, 3],
                 conv_op_ids=[5, 5, 5, 5, 5, 5, 5, 6, 6, 5, 4, 3, 4, 6, 6],
                 **kwargs):
        super().__init__()
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      32,
                      kernel_size=(3, 3),
                      stride=1,
                      padding=1,
                      bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True))
        # Stage strides may arrive as a string from a config file; parse
        # them safely rather than with eval().
        if isinstance(stride_stages, str):
            stride_stages = ast.literal_eval(stride_stages)
        width_stages = [32, 64, 128, 256, 512]
        # The searched op vocabulary; each entry of conv_op_ids indexes into
        # this list to pick one op per cell.
        conv_candidates = [
            '5x5_MBConv1', '5x5_MBConv3', '5x5_MBConv6', '3x3_MBConv1',
            '3x3_MBConv3', '3x3_MBConv6', 'Zero'
        ]

        assert len(conv_op_ids) == sum(n_cell_stages), \
            'need exactly one op id per cell'
        blocks = []
        input_channel = 32
        for width, n_cell, s in zip(width_stages, n_cell_stages,
                                    stride_stages):
            for i in range(n_cell):
                # Only the first cell of each stage downsamples.
                stride = s if i == 0 else (1, 1)
                block_i = len(blocks)
                conv_op = conv_func_by_name(
                    conv_candidates[conv_op_ids[block_i]])(input_channel,
                                                           width, stride)
                # An identity shortcut is only valid when input and output
                # shapes match.
                if stride == (1, 1) and input_channel == width:
                    shortcut = IdentityLayer()
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(conv_op, shortcut))
                input_channel = width
        self.out_channels = input_channel

        self.blocks = nn.ModuleList(blocks)

        self.with_lstm = with_lstm
        if with_lstm:
            self.rnn = nn.LSTM(input_channel,
                               out_dim // 2,
                               bidirectional=True,
                               num_layers=2,
                               batch_first=True)
            self.out_channels = out_dim

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        # After five height-halving stages the feature map is one pixel
        # high; collapse it into a (batch, width, channels) sequence.
        assert x.size(2) == 1, \
            'expected feature height 1, got %d' % x.size(2)
        cnn_feat = x.squeeze(dim=2).transpose(2, 1)
        if self.with_lstm:
            rnn_feat, _ = self.rnn(cnn_feat)
            return rnn_feat
        return cnn_feat
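

if __name__ == '__main__':
    # Minimal smoke test; the 32x100 grayscale input geometry is an assumed
    # convention for text-recognition pipelines, not part of the original API.
    encoder = AutoSTREncoder(in_channels=1)
    images = torch.randn(2, 1, 32, 100)
    feats = encoder(images)
    # Five stages halve the height (32 -> 1) while only the first two halve
    # the width (100 -> 25); with the BiLSTM head the output is
    # (batch, width, out_dim).
    print(feats.shape)  # torch.Size([2, 25, 256])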