import torch.nn as nn

__all__ = ['build_decoder']


def build_decoder(config):
    """Build a text-recognition decoder from a config mapping.

    Args:
        config: Mapping with a ``'name'`` entry naming one of the supported
            decoder classes listed below; every remaining entry is passed to
            that class's constructor as a keyword argument. ``config`` itself
            is left unmodified (a shallow copy is consumed instead), so the
            same mapping can be used to build a decoder more than once.

    Returns:
        The constructed decoder instance.

    Raises:
        KeyError: If ``config`` has no ``'name'`` entry.
        AssertionError: If ``config['name']`` is not a supported decoder.
    """
    # rec decoder. Imported inside the function rather than at module level,
    # presumably to defer the import cost / avoid import cycles -- TODO confirm.
    from .abinet_decoder import ABINetDecoder
    from .aster_decoder import ASTERDecoder
    from .cdistnet_decoder import CDistNetDecoder
    from .cppd_decoder import CPPDDecoder
    from .rctc_decoder import RCTCDecoder
    from .ctc_decoder import CTCDecoder
    from .dan_decoder import DANDecoder
    from .igtr_decoder import IGTRDecoder
    from .lister_decoder import LISTERDecoder
    from .lpv_decoder import LPVDecoder
    from .mgp_decoder import MGPDecoder
    from .nrtr_decoder import NRTRDecoder
    from .parseq_decoder import PARSeqDecoder
    from .robustscanner_decoder import RobustScannerDecoder
    from .sar_decoder import SARDecoder
    from .smtr_decoder import SMTRDecoder
    from .smtr_decoder_nattn import SMTRDecoderNumAttn
    from .srn_decoder import SRNDecoder
    from .visionlan_decoder import VisionLANDecoder
    from .matrn_decoder import MATRNDecoder
    from .cam_decoder import CAMDecoder
    from .ote_decoder import OTEDecoder
    from .bus_decoder import BUSDecoder

    # Explicit name -> class registry: the configured name is only ever used
    # as a lookup key and is never evaluated as Python code.
    support_dict = {
        'CTCDecoder': CTCDecoder,
        'NRTRDecoder': NRTRDecoder,
        'CPPDDecoder': CPPDDecoder,
        'ABINetDecoder': ABINetDecoder,
        'CDistNetDecoder': CDistNetDecoder,
        'VisionLANDecoder': VisionLANDecoder,
        'PARSeqDecoder': PARSeqDecoder,
        'IGTRDecoder': IGTRDecoder,
        'SMTRDecoder': SMTRDecoder,
        'LPVDecoder': LPVDecoder,
        'SARDecoder': SARDecoder,
        'RobustScannerDecoder': RobustScannerDecoder,
        'SRNDecoder': SRNDecoder,
        'ASTERDecoder': ASTERDecoder,
        'RCTCDecoder': RCTCDecoder,
        'LISTERDecoder': LISTERDecoder,
        'GTCDecoder': GTCDecoder,
        'SMTRDecoderNumAttn': SMTRDecoderNumAttn,
        'MATRNDecoder': MATRNDecoder,
        'MGPDecoder': MGPDecoder,
        'DANDecoder': DANDecoder,
        'CAMDecoder': CAMDecoder,
        'OTEDecoder': OTEDecoder,
        'BUSDecoder': BUSDecoder,
        # Defined in this module, like GTCDecoder.
        'GTCDecoderTwo': GTCDecoderTwo,
    }

    # Shallow copy so popping 'name' does not mutate the caller's config.
    config = dict(config)
    module_name = config.pop('name')
    if module_name not in support_dict:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the AssertionError type is kept for
        # compatibility with callers of the assert-based check.
        raise AssertionError('decoder only support {}'.format(
            list(support_dict)))
    return support_dict[module_name](**config)


class GTCDecoder(nn.Module):
    """Decoder pairing a CTC head with an optional GTC head.

    The CTC head (``ctc_decoder``) is always built. The GTC head
    (``gtc_decoder``) is built only when ``infer_gtc`` is true; otherwise
    ``self.gtc_decoder`` is ``None`` and ``forward`` returns only the CTC
    prediction, in training as well as in evaluation.

    Args:
        in_channels: Passed as ``in_channels`` to every sub-decoder built.
        gtc_decoder: Config mapping for the GTC head, handed to
            ``build_decoder``. Only used when ``infer_gtc`` is true.
        ctc_decoder: Config mapping for the CTC head, handed to
            ``build_decoder``.
        detach: If true, the CTC head receives ``x.detach()``, so gradients
            from the CTC branch do not propagate back into ``x``.
        infer_gtc: If true, build the GTC head and also run it in
            ``forward`` when the module is not in training mode.
        out_channels: With ``infer_gtc`` true, a sequence indexed as
            ``out_channels[0]`` (GTC head) and ``out_channels[1]`` (CTC
            head). Otherwise passed unchanged as the CTC head's
            ``out_channels``.
        **kwargs: Accepted and ignored.

    The ``gtc_decoder`` / ``ctc_decoder`` mappings are copied before the
    channel entries are filled in, so the caller's config is not modified.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 detach=True,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super().__init__()
        self.detach = detach
        self.infer_gtc = infer_gtc
        if infer_gtc:
            gtc_out, ctc_out = out_channels[0], out_channels[1]
            # Build the GTC head before the CTC head: construction order
            # fixes submodule registration (state_dict key) order and,
            # presumably, the order of random weight-init draws.
            self.gtc_decoder = build_decoder(
                dict(gtc_decoder, in_channels=in_channels,
                     out_channels=gtc_out))
        else:
            ctc_out = out_channels
            # No GTC head: forward() returns only the CTC prediction.
            self.gtc_decoder = None
        self.ctc_decoder = build_decoder(
            dict(ctc_decoder, in_channels=in_channels, out_channels=ctc_out))

    def forward(self, x, data=None):
        """Decode encoder features ``x``.

        Args:
            x: Encoder output tensor. The GTC head receives
                ``x.flatten(2).transpose(1, 2)``.
            data: Forwarded unchanged to each sub-decoder.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` when a GTC head exists and
            the module is training or ``infer_gtc`` is set; otherwise the
            CTC head's output alone.
        """
        ctc_input = x.detach() if self.detach else x
        ctc_pred = self.ctc_decoder(ctc_input, data=data)
        if self.gtc_decoder is None or not (self.training or self.infer_gtc):
            return ctc_pred
        # Merge all dims after dim 1 into one and swap it with dim 1,
        # (B, D1, D2, ...) -> (B, D2*..., D1).
        # assumes x is channels-first, e.g. (B, C, H, W) -- TODO confirm
        gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2), data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}


class GTCDecoderTwo(nn.Module):
    """GTC/CTC decoder pair where each head gets its own input features.

    Both heads are always built, and ``forward`` expects ``x`` to be a pair
    ``(x_ctc, x_gtc)``.

    Args:
        in_channels: Passed as ``in_channels`` to both sub-decoders.
        gtc_decoder: Config mapping for the GTC head, handed to
            ``build_decoder``.
        ctc_decoder: Config mapping for the CTC head, handed to
            ``build_decoder``.
        infer_gtc: If true, also run the GTC head when not in training mode.
        out_channels: Sequence indexed as ``out_channels[0]`` (GTC head) and
            ``out_channels[1]`` (CTC head). The default ``0`` is not
            indexable, so a real value must be supplied.
        **kwargs: Accepted and ignored.

    The sub-config mappings are copied before the channel entries are filled
    in, so the caller's config is not modified.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super().__init__()
        self.infer_gtc = infer_gtc
        gtc_out, ctc_out = out_channels[0], out_channels[1]
        # GTC head is built first, then the CTC head (registration order).
        self.gtc_decoder = build_decoder(
            dict(gtc_decoder, in_channels=in_channels, out_channels=gtc_out))
        self.ctc_decoder = build_decoder(
            dict(ctc_decoder, in_channels=in_channels, out_channels=ctc_out))

    def forward(self, x, data=None):
        """Decode a ``(x_ctc, x_gtc)`` pair of feature tensors.

        Args:
            x: Two-element sequence; ``x[0]`` feeds the CTC head and ``x[1]``
                (after ``flatten(2).transpose(1, 2)``) feeds the GTC head.
            data: Forwarded unchanged to each sub-decoder.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` when training or
            ``infer_gtc`` is set; otherwise the CTC head's output alone.
        """
        x_ctc, x_gtc = x
        ctc_pred = self.ctc_decoder(x_ctc, data=data)
        if not (self.training or self.infer_gtc):
            return ctc_pred
        # (B, D1, D2, ...) -> (B, D2*..., D1); assumes x_gtc is
        # channels-first, e.g. (B, C, H, W) -- TODO confirm
        gtc_tokens = x_gtc.flatten(2).transpose(1, 2)
        gtc_pred = self.gtc_decoder(gtc_tokens, data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}