import time
import numpy as np
import torch
from PIL import Image
import glob
import sys
import argparse
import datetime
import json
from pathlib import Path

class MaskHookLogger(object):
    def __init__(self, model, device):
        self.current_layer = 0
        self.device = device
        self.attentions = []
        self.mlps = []
        self.post_ln_std = None
        self.post_ln_mean = None
        self.model = model
        
    @torch.no_grad()
    def compute_attentions(self, ret):
        if self.current_layer == self.layer_index:
            bias_term = self.model.visual.transformer.resblocks[self.current_layer].attn.out_proj.bias
            return_value = ret[:, 0]
            return_value = return_value +  bias_term[np.newaxis, np.newaxis] / (return_value.shape[1])# [b, n, d]
            self.attentions.append(return_value.detach()) 
        self.current_layer += 1
        return ret
    
    @torch.no_grad()
    def compute_mlps(self, ret):
        if self.current_layer == self.layer_index + 1:
            self.mlps.append(ret[:, 1:].detach()) # [b, n, d] 
        return ret
    
    @torch.no_grad()
    def log_post_ln_mean(self, ret):
        self.post_ln_mean = ret.detach() # [b, 1]
        return ret
        
    @torch.no_grad()
    def log_post_ln_std(self, ret):
        self.post_ln_std = ret.detach() # [b, 1] 
        return ret
    
    def _normalize_mlps(self):
        len_intermediates = self.current_layer * 2 - 1
        # This is just the normalization layer:
        mean_centered = (self.mlps - 
                         self.post_ln_mean[:, :, np.newaxis, np.newaxis] / len_intermediates)

        weighted_mean_centered = self.model.visual.ln_post.weight.detach() * mean_centered
        weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[:, :, np.newaxis, np.newaxis]

        bias_term = self.model.visual.ln_post.bias.detach() / len_intermediates
        post_ln = weighted_mean_by_std + bias_term
        return post_ln @ self.model.visual.proj.detach()
    
    def _normalize_attentions(self):
        len_intermediates = self.current_layer * 2 - 1 # 2*l + 1
        normalization_term = self.attentions.shape[2] * 1  # n * h, h=1
        # This is just the normalization layer:
        mean_centered = (self.attentions - 
                         self.post_ln_mean[:, :, np.newaxis, np.newaxis] / 
                         (len_intermediates * normalization_term))
        weighted_mean_centered = self.model.visual.ln_post.weight.detach() * mean_centered
        weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[:, :, np.newaxis, np.newaxis]
        bias_term = self.model.visual.ln_post.bias.detach() / (len_intermediates * normalization_term)
        post_ln = weighted_mean_by_std + bias_term
        return post_ln @ self.model.visual.proj.detach()
    
    @torch.no_grad()
    def finalize(self, representation):
        """We calculate the post-ln scaling, project it and normalize by the last norm."""
        self.attentions = torch.stack(self.attentions, axis=1) # [b, 1, n, d]
        self.mlps = torch.stack(self.mlps, axis=1) # [b, 1, n, d]
        projected_attentions = self._normalize_attentions()
        projected_mlps = self._normalize_mlps()
        norm = representation.norm(dim=-1).detach()
        return (projected_attentions / norm[:, np.newaxis, np.newaxis, np.newaxis], 
                projected_mlps / norm[:, np.newaxis, np.newaxis, np.newaxis])
        
    def reinit(self):
        self.current_layer = 0
        self.attentions = []
        self.mlps = []
        self.post_ln_mean = None
        self.post_ln_std = None
        torch.cuda.empty_cache()


def hook_prs_logger(model, device, layer_index = 23):
    """Hooks a projected residual stream logger to the model."""
    prs = MaskHookLogger(model, device)
    model.hook_manager.register('visual.transformer.resblocks.*.attn.out.post', 
                                prs.compute_attentions)

    model.hook_manager.register('visual.transformer.resblocks.*.post', 
                                prs.compute_mlps)
    model.hook_manager.register('visual.ln_pre_post', 
                                prs.compute_mlps)
    model.hook_manager.register('visual.ln_post.mean', 
                                prs.log_post_ln_mean)
    model.hook_manager.register('visual.ln_post.sqrt_var', 
                                prs.log_post_ln_std)
    prs.layer_index = layer_index

    return prs