import collections
import logging

import torch
from torch.nn import BCEWithLogitsLoss, Dropout, Linear
from transformers import AutoModel, XLNetModel, LongformerConfig
from transformers.models.longformer.modeling_longformer import LongformerEncoder
from huggingface_hub import PyTorchModelHubMixin
from LongLAT.models.utils import initial_code_title_vectors

logger = logging.getLogger("lwat")


class CodingModelConfig:
    def __init__(self,
                 transformer_model_name_or_path,
                 transformer_tokenizer_name,
                 transformer_layer_update_strategy,
                 num_chunks,
                 max_seq_length,
                 dropout,
                 dropout_att,
                 d_model,
                 label_dictionary,
                 num_labels,
                 use_code_representation,
                 code_max_seq_length,
                 code_batch_size,
                 multi_head_att,
                 chunk_att,
                 linear_init_mean,
                 linear_init_std,
                 document_pooling_strategy,
                 multi_head_chunk_attention,
                 num_hidden_layers):
        super(CodingModelConfig, self).__init__()
        self.transformer_model_name_or_path = transformer_model_name_or_path
        self.transformer_tokenizer_name = transformer_tokenizer_name
        self.transformer_layer_update_strategy = transformer_layer_update_strategy
        self.num_chunks = num_chunks
        self.max_seq_length = max_seq_length
        self.dropout = dropout
        self.dropout_att = dropout_att
        self.d_model = d_model
        # label_dictionary is a dataframe with columns: icd9_code, long_title
        self.label_dictionary = label_dictionary
        self.num_labels = num_labels
        self.use_code_representation = use_code_representation
        self.code_max_seq_length = code_max_seq_length
        self.code_batch_size = code_batch_size
        self.multi_head_att = multi_head_att
        self.chunk_att = chunk_att
        self.linear_init_mean = linear_init_mean
        self.linear_init_std = linear_init_std
        self.document_pooling_strategy = document_pooling_strategy
        self.multi_head_chunk_attention = multi_head_chunk_attention
        self.num_hidden_layers = num_hidden_layers


class LableWiseAttentionLayer(torch.nn.Module):
    def __init__(self, coding_model_config, args):
        super(LableWiseAttentionLayer, self).__init__()

        self.config = coding_model_config
        self.args = args

        # layers
        self.l1_linear = torch.nn.Linear(self.config.d_model,
                                         self.config.d_model, bias=False)
        self.tanh = torch.nn.Tanh()
        self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
        self.softmax = torch.nn.Softmax(dim=1)

        # Initialize weights. When use_code_representation is enabled, the l2 weights start from
        # mean-pooled transformer representations of the code titles (the initial code vectors).
        self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)

    def _init_linear_weights(self, mean, std):
        # initialize the l1 weights
        torch.nn.init.normal_(self.l1_linear.weight, mean, std)
        if self.l1_linear.bias is not None:
            self.l1_linear.bias.data.fill_(0)
        # initialize the l2
        if self.config.use_code_representation:
            code_vectors = initial_code_title_vectors(self.config.label_dictionary,
                                                      self.config.transformer_model_name_or_path,
                                                      self.config.transformer_tokenizer_name
                                                      if self.config.transformer_tokenizer_name
                                                      else self.config.transformer_model_name_or_path,
                                                      self.config.code_max_seq_length,
                                                      self.config.code_batch_size,
                                                      self.config.d_model,
                                                      self.args.device)

            self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
        else:
            torch.nn.init.normal_(self.l2_linear.weight, mean, std)
        if self.l2_linear.bias is not None:
            self.l2_linear.bias.data.fill_(0)

    def forward(self, x):
        # input shape: (batch_size, max_seq_length, transformer_hidden_size)
        # output shape: (batch_size, num_labels, transformer_hidden_size)
        # Z = Tan(WH)
        l1_output = self.tanh(self.l1_linear(x))
        # softmax(UZ)
        # l2_linear output shape: (batch_size, max_seq_length, num_labels)
        # attention_weight shape: (batch_size, num_labels, max_seq_length)
        attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
        # attention_output shape: (batch_size, num_labels, transformer_hidden_size)
        attention_output = torch.matmul(attention_weight, x)

        return attention_output, attention_weight

class ChunkAttentionLayer(torch.nn.Module):
    def __init__(self, coding_model_config, args):
        super(ChunkAttentionLayer, self).__init__()

        self.config = coding_model_config
        self.args = args

        # layers
        self.l1_linear = torch.nn.Linear(self.config.d_model,
                                         self.config.d_model, bias=False)
        self.tanh = torch.nn.Tanh()
        self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
        self.softmax = torch.nn.Softmax(dim=1)

        self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)

    def _init_linear_weights(self, mean, std):
        # initialize the l1
        torch.nn.init.normal_(self.l1_linear.weight, mean, std)
        if self.l1_linear.bias is not None:
            self.l1_linear.bias.data.fill_(0)
        # initialize the l2
        torch.nn.init.normal_(self.l2_linear.weight, mean, std)
        if self.l2_linear.bias is not None:
            self.l2_linear.bias.data.fill_(0)

    def forward(self, x):
        # input shape: (batch_size, num_chunks, transformer_hidden_size)
        # output shape: (batch_size, 1, transformer_hidden_size)
        # Z = Tan(WH)
        l1_output = self.tanh(self.l1_linear(x))
        # softmax(UZ)
        # l2_linear output shape: (batch_size, num_chunks, 1)
        # attention_weight shape: (batch_size, 1, num_chunks)
        attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
        # attention_output shape: (batch_size, 1, transformer_hidden_size)
        attention_output = torch.matmul(attention_weight, x)

        return attention_output, attention_weight

# define the model class
class CodingModel(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(self, coding_model_config, args, **kwargs):
        super(CodingModel, self).__init__()
        self.coding_model_config = coding_model_config
        self.args = args
        # layers
        self.transformer_layer = AutoModel.from_pretrained(self.coding_model_config.transformer_model_name_or_path)
        if isinstance(self.transformer_layer, XLNetModel):
            self.transformer_layer.config.use_mems_eval = False
        self.dropout = Dropout(p=self.coding_model_config.dropout)

        # Longformer encoder applied over the concatenated chunk representations.
        # Built once here so its parameters are registered with the module and updated during training.
        longformer_config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
        longformer_config.num_hidden_layers = self.coding_model_config.num_hidden_layers
        longformer_config.hidden_size = self.coding_model_config.d_model
        longformer_config.attention_window = [512] * longformer_config.num_hidden_layers
        self.longformer_layer = LongformerEncoder(longformer_config)

        if self.coding_model_config.multi_head_att:
            # initialize one label-wise attention head per chunk
            self.label_wise_attention_layer = torch.nn.ModuleList(
                [LableWiseAttentionLayer(coding_model_config, args)
                 for _ in range(self.coding_model_config.num_chunks)])
        else:
            self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
        self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)

        # initialize chunk attention
        if self.coding_model_config.chunk_att:
            if self.coding_model_config.multi_head_chunk_attention:
                self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
                                                                  for _ in range(self.coding_model_config.num_labels)])
            else:
                self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)

            self.classifier_layer = Linear(self.coding_model_config.d_model,
                                           self.coding_model_config.num_labels)
        else:
            if self.coding_model_config.document_pooling_strategy == "flat":
                self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
                                               self.coding_model_config.num_labels)
            else: # max or mean pooling
                self.classifier_layer = Linear(self.coding_model_config.d_model,
                                               self.coding_model_config.num_labels)
        self.sigmoid = torch.nn.Sigmoid()

        if self.coding_model_config.transformer_layer_update_strategy == "no":
            self.freeze_all_transformer_layers()
        elif self.coding_model_config.transformer_layer_update_strategy == "last":
            self.freeze_all_transformer_layers()
            self.unfreeze_transformer_last_layers()

        # initialize the weights of classifier
        self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)

    def _init_linear_weights(self, mean, std):
        torch.nn.init.normal_(self.classifier_layer.weight, mean, std)

    def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
        # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
        # (global_attention_mask + 1) => 1 for local attention, 2 for global attention
        # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
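        # worked example (illustrative):
        #   attention_mask = [1, 1, 0], global_attention_mask = [1, 0, 0]
        #   -> merged mask = [2, 1, 0]: global attention on token 0, local attention on token 1,
        #      and token 2 stays masked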
        if attention_mask is not None:
            attention_mask = attention_mask * (global_attention_mask + 1)
        else:
            # simply use `global_attention_mask` as `attention_mask`
            # if no `attention_mask` is given
            attention_mask = global_attention_mask + 1
        return attention_mask
    
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
        # input_ids/attention_mask/token_type_ids shape: (batch_size, num_chunks, max_seq_length)
        # targets shape: (batch_size, num_labels)
        transformer_output = []

        # pass the chunks one at a time through the transformer layer
        # per-chunk input shape: (batch_size, max_seq_length)
        for i in range(self.coding_model_config.num_chunks):
            l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
                                               attention_mask=attention_mask[:, i, :],
                                               token_type_ids=token_type_ids[:, i, :])
            # output hidden state shape: (batch_size, sequence_length, hidden_size)
            transformer_output.append(l1_output[0])

        # stack the per-chunk outputs and swap the chunk and batch dimensions:
        # (batch_size, num_chunks, max_seq_length, hidden_size)
        transformer_output = torch.stack(transformer_output)
        transformer_output = transformer_output.transpose(0, 1)
        # dropout transformer output
        l2_dropout = self.dropout(transformer_output)

        # Longformer encoder over the concatenated chunks
        # flatten the chunk dimension: (batch_size, num_chunks * max_seq_length, hidden_size)
        l2_dropout = l2_dropout.reshape(l2_dropout.shape[0],
                                        l2_dropout.shape[1] * l2_dropout.shape[2],
                                        l2_dropout.shape[3])
        attention_mask = attention_mask.reshape(attention_mask.shape[0],
                                                attention_mask.shape[1] * attention_mask.shape[2])

        # global attention on each chunk's [CLS] token
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_positions = [chunk_index * self.coding_model_config.max_seq_length
                                      for chunk_index in range(self.coding_model_config.num_chunks)]
        global_attention_mask[:, global_attention_positions] = 1
        attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)

        output = self.longformer_layer(l2_dropout, attention_mask=attention_mask, output_attentions=True)
        l2_dropout = self.dropout_att(output[0])
        # restore the chunk dimension: (batch_size, num_chunks, max_seq_length, hidden_size)
        l2_dropout = l2_dropout.reshape(l2_dropout.shape[0],
                                        self.coding_model_config.num_chunks,
                                        self.coding_model_config.max_seq_length,
                                        self.coding_model_config.d_model)

        # Label-wise attention layers
        # output: (batch_size, num_chunks, num_labels, hidden_size)
        attention_output = []
        attention_weights = []

        for i in range(self.coding_model_config.num_chunks):
            # input: (batch_size, max_seq_length, transformer_hidden_size)
            if self.coding_model_config.multi_head_att:
                attention_layer = self.label_wise_attention_layer[i]
            else:
                attention_layer = self.label_wise_attention_layer
            l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
            # l3_attention shape: (batch_size, num_labels, hidden_size)
            # attention_weight: (batch_size, num_labels, max_seq_length)
            attention_output.append(l3_attention)
            attention_weights.append(attention_weight)

        attention_output = torch.stack(attention_output)
        attention_output = attention_output.transpose(0, 1)
        attention_weights = torch.stack(attention_weights)
        attention_weights = attention_weights.transpose(0, 1)

        l3_dropout = self.dropout_att(attention_output)

        if self.coding_model_config.chunk_att:
            # Chunk attention layers
            # output: (batch_size, num_labels, hidden_size)
            chunk_attention_output = []
            chunk_attention_weights = []

            for i in range(self.coding_model_config.num_labels):
                if self.coding_model_config.multi_head_chunk_attention:
                    chunk_attention = self.chunk_attention_layer[i]
                else:
                    chunk_attention = self.chunk_attention_layer
                l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
                chunk_attention_output.append(l4_chunk_attention.squeeze(dim=1))
                chunk_attention_weights.append(l4_chunk_attention_weights.squeeze(dim=1))

            chunk_attention_output = torch.stack(chunk_attention_output)
            chunk_attention_output = chunk_attention_output.transpose(0, 1)
            chunk_attention_weights = torch.stack(chunk_attention_weights)
            chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
            # output shape: (batch_size, num_labels, hidden_size)
            l4_dropout = self.dropout_att(chunk_attention_output)
        else:
            # l4_dropout shape after transpose: (batch_size, num_labels, num_chunks, hidden_size);
            # pooled below to (batch_size, num_labels, num_chunks * hidden_size) for "flat"
            # or (batch_size, num_labels, hidden_size) for "max"/"mean"
            l4_dropout = l3_dropout.transpose(1, 2)
            if self.coding_model_config.document_pooling_strategy == "flat":
                # Flatten layer. concatenate representation by labels
                l4_dropout = torch.flatten(l4_dropout, start_dim=2)
            elif self.coding_model_config.document_pooling_strategy == "max":
                l4_dropout = torch.amax(l4_dropout, 2)
            elif self.coding_model_config.document_pooling_strategy == "mean":
                l4_dropout = torch.mean(l4_dropout, 2)
            else:
                raise ValueError(
                    f"Unsupported document pooling strategy: {self.coding_model_config.document_pooling_strategy}")

        # classifier layer: each label has its own linear scoring function
        # logits[b, j] = <classifier_layer.weight[j], l4_dropout[b, j]> + classifier_layer.bias[j]
        logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)

        loss = None
        if targets is not None:
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits, targets)

        return {
            "loss": loss,
            "logits": logits,
            "label_attention_weights": attention_weights,
            "chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
        }

    def freeze_all_transformer_layers(self):
        """
        Freeze all layer weight parameters. They will not be updated during training.
        """
        for param in self.transformer_layer.parameters():
            param.requires_grad = False

    def unfreeze_all_transformer_layers(self):
        """
        Unfreeze all layer weight parameters. They will be updated during training.
        """
        for param in self.transformer_layer.parameters():
            param.requires_grad = True

    def unfreeze_transformer_last_layers(self):
        # unfreeze only the final encoder layer and the pooler
        # (the "layer.11" pattern assumes a 12-layer encoder such as BERT-base)
        for name, param in self.transformer_layer.named_parameters():
            if "layer.11" in name or "pooler" in name:
                param.requires_grad = True
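

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal smoke test wiring the classes
# above together. The encoder name and every hyperparameter below are
# placeholder assumptions chosen for a quick run, not the values used in the
# original experiments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    coding_model_config = CodingModelConfig(
        transformer_model_name_or_path="bert-base-uncased",  # assumed chunk encoder
        transformer_tokenizer_name="bert-base-uncased",
        transformer_layer_update_strategy="no",  # freeze the chunk encoder for the smoke test
        num_chunks=2,
        max_seq_length=512,
        dropout=0.1,
        dropout_att=0.1,
        d_model=768,
        label_dictionary=None,  # dataframe of icd9_code/long_title; unused when use_code_representation=False
        num_labels=5,
        use_code_representation=False,
        code_max_seq_length=64,
        code_batch_size=8,
        multi_head_att=False,
        chunk_att=False,
        linear_init_mean=0.0,
        linear_init_std=0.03,
        document_pooling_strategy="mean",
        multi_head_chunk_attention=False,
        num_hidden_layers=2,
    )
    args = argparse.Namespace(device="cpu")
    model = CodingModel(coding_model_config, args)
    model.eval()

    batch_size = 2
    input_ids = torch.randint(0, 1000, (batch_size, coding_model_config.num_chunks,
                                        coding_model_config.max_seq_length))
    attention_mask = torch.ones_like(input_ids)
    token_type_ids = torch.zeros_like(input_ids)
    targets = torch.zeros(batch_size, coding_model_config.num_labels)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        token_type_ids=token_type_ids, targets=targets)
    print(outputs["logits"].shape)  # expected: torch.Size([2, 5]) == (batch_size, num_labels)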