import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v1.keras.backend as K1
tf1.disable_eager_execution()

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Activation, Dense, Lambda, Layer, Concatenate

def get_TrR_weights(filename):
  weights = [np.squeeze(w) for w in np.load(filename, allow_pickle=True)]
  # remove weights for beta-beta pairing
  del weights[-4:-2]
  return weights

def get_TrR(blocks=12, trainable=False, weights=None, name="TrR"):
  ex = {"trainable":trainable}
  # custom layer(s)
  class PSSM(Layer):
    # modified from MRF to only output tiled 1D features
    def __init__(self, diag=0.4, use_entropy=False):
      super(PSSM, self).__init__()
      self.diag = diag
      self.use_entropy = use_entropy
    def call(self, inputs):
      x,y = inputs
      _,_,L,A = [tf.shape(y)[k] for k in range(4)]
      with tf.name_scope('1d_features'):
        # sequence
        x_i = x[0,0,:,:20]
        # pssm
        f_i = y[0,0]
        # entropy
        if self.use_entropy:
          h_i = K.sum(-f_i * K.log(f_i + 1e-8), axis=-1, keepdims=True)
        else:
          h_i = tf.zeros((L,1))
        # tile and combined 1D features
        feat_1D = tf.concat([x_i,f_i,h_i], axis=-1)
        feat_1D_tile_A = tf.tile(feat_1D[:,None,:], [1,L,1])
        feat_1D_tile_B = tf.tile(feat_1D[None,:,:], [L,1,1])

      with tf.name_scope('2d_features'):
        ic = self.diag * tf.eye(L*A)
        ic = tf.reshape(ic,(L,A,L,A))
        ic = tf.transpose(ic,(0,2,1,3))
        ic = tf.reshape(ic,(L,L,A*A))
        i0 = tf.zeros([L,L,1])
        feat_2D = tf.concat([ic,i0], axis=-1)

      feat = tf.concat([feat_1D_tile_A, feat_1D_tile_B, feat_2D],axis=-1)
      return tf.reshape(feat, [1,L,L,442+2*42])
      
  class instance_norm(Layer):
    def __init__(self, axes=(1,2),trainable=True):
      super(instance_norm, self).__init__()
      self.axes = axes
      self.trainable = trainable
    def build(self, input_shape):
      self.beta  = self.add_weight(name='beta',shape=(input_shape[-1],),
                                  initializer='zeros',trainable=self.trainable)
      self.gamma = self.add_weight(name='gamma',shape=(input_shape[-1],),
                                  initializer='ones',trainable=self.trainable)
    def call(self, inputs):
      mean, variance = tf.nn.moments(inputs, self.axes, keepdims=True)
      return tf.nn.batch_normalization(inputs, mean, variance, self.beta, self.gamma, 1e-6)

  ## INPUT ##
  inputs = Input((None,None,21),batch_size=1)
  A = PSSM()([inputs,inputs])
  A = Dense(64, **ex)(A)
  A = instance_norm(**ex)(A)
  A = Activation("elu")(A)

  ## RESNET ##
  def resnet(X, dilation=1, filters=64, win=3):
    Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(X)
    Y = instance_norm(**ex)(Y)
    Y = Activation("elu")(Y)
    Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(Y)
    Y = instance_norm(**ex)(Y)
    return Activation("elu")(X+Y)

  for _ in range(blocks):
    for dilation in [1,2,4,8,16]:
      A = resnet(A, dilation)
  A = resnet(A, dilation=1)
  
  ## OUTPUT ##
  A_input   = Input((None,None,64))
  p_theta   = Dense(25, activation="softmax", **ex)(A_input)
  p_phi     = Dense(13, activation="softmax", **ex)(A_input)
  A_sym     = Lambda(lambda x: (x + tf.transpose(x,[0,2,1,3]))/2)(A_input)
  p_dist    = Dense(37, activation="softmax", **ex)(A_sym)
  p_omega   = Dense(25, activation="softmax", **ex)(A_sym)
  A_model   = Model(A_input,Concatenate()([p_theta,p_phi,p_dist,p_omega]))

  ## MODEL ##
  model = Model(inputs, A_model(A),name=name)
  if weights is not None: model.set_weights(weights)
  return model

def get_TrR_model(L=None, exclude_theta=False, use_idx=False, use_bkg=False, models_path="models"):
  def gather_idx(x):
    idx = x[1][0]
    return tf.gather(tf.gather(x[0],idx,axis=-2),idx,axis=-3)

  def get_cce_loss(x, eps=1e-8, only_dist=False):
    if only_dist:
      true_x = split_feat(x[0])["dist"]
      pred_x = split_feat(x[1])["dist"]
      loss = -tf.reduce_mean(tf.reduce_sum(true_x*tf.math.log(pred_x + eps),-1),[-1,-2])
      return loss * 4
      
    elif exclude_theta:
      true_x = split_feat(x[0])
      pred_x = split_feat(x[1])
      true_x = tf.concat([true_x[k] for k in ["phi","dist","omega"]],-1)
      pred_x = tf.concat([pred_x[k] for k in ["phi","dist","omega"]],-1)
      loss = -tf.reduce_mean(tf.reduce_sum(true_x*tf.math.log(pred_x + eps),-1),[-1,-2])
      return loss * 4/3

    else:
      return -tf.reduce_mean(tf.reduce_sum(x[0]*tf.math.log(x[1] + eps),-1),[-1,-2])
  
  def get_bkg_loss(x, eps=1e-8):
    return -tf.reduce_mean(tf.reduce_sum(x[1]*(tf.math.log(x[1]+eps)-tf.math.log(x[0]+eps)),-1),[-1,-2])      

  def prep_seq(x_logits):
    x_soft = tf.nn.softmax(x_logits,-1)
    x_hard = tf.one_hot(tf.argmax(x_logits,-1),20)
    x = tf.stop_gradient(x_hard - x_soft) + x_soft
    x = tf.pad(x,[[0,0],[0,0],[0,1]])
    return x[None]

  I_seq_logits = Input((L,20),name="seq_logits")
  seq = Lambda(prep_seq,name="seq")(I_seq_logits)
  I_true = Input((L,L,100),name="true")
  if use_bkg:
    I_bkg = Input((L,L,100),name="bkg")
  if use_idx:
    I_idx = Input((None,),dtype=tf.int32,name="idx")
    I_idx_true = Input((None,),dtype=tf.int32,name="idx_true")
  
  pred = []
  for nam in ["xaa","xab","xac","xad","xae"]:
    print(nam)
    TrR = get_TrR(weights=get_TrR_weights(f"{models_path}/model_{nam}.npy"),name=nam)
    pred.append(TrR(seq))
  pred = sum(pred)/len(pred)

  if use_idx:
    pred_sub = Lambda(gather_idx, name="pred_sub")([pred,I_idx])
    true_sub = Lambda(gather_idx, name="true_sub")([I_true,I_idx_true])
  else:
    pred_sub = pred
    true_sub = I_true
  
  cce_loss = Lambda(get_cce_loss,name="cce_loss")([true_sub, pred_sub])
  if use_bkg:
    bkg_loss = Lambda(get_bkg_loss,name="bkg_loss")([I_bkg, pred])
    loss = Lambda(lambda x: x[0]+0.1*x[1])([cce_loss,bkg_loss])
  else:
    loss = cce_loss
  grad = Lambda(lambda x: tf.gradients(x[0],x[1]), name="grad")([loss,I_seq_logits])

  # setup model
  inputs = [I_seq_logits, I_true]
  outputs = [cce_loss]
  if use_bkg:
    inputs += [I_bkg]
    outputs += [bkg_loss]
  if use_idx: inputs += [I_idx, I_idx_true]  
  model = Model(intputs, outputs + [grad, pred], name="TrR_model")
  
  TrR_model(seq, true, **kwargs):
    i = [seq[None],true[None]]
    if use_bkg:
      i += [kwargs["bkg"][None]]
    if use_idx:
      pos_idx = kwargs["pos_idx"]
      if "pos_idx_ref" not in kwargs or kwargs["pos_idx_ref"] is None:
        pos_idx_ref = pos_idx
      else:
        pos_idx_ref = kwargs["pos_idx_ref"]
      i += [pos_idx[None],pos_idx_ref[None]]
    
    *o = model.predict(i)
    r = {"cce_loss":o[0][0],"grad":o[-1][0],"pred":o[-2][0]}
    if use_bkg: r["bkg_loss"] = o[1][0]
    return r
  
  return TrR_model