from tqdm import tqdm
from typing import Dict, Iterable, List, Optional, Union
from transformers import PreTrainedModel, PreTrainedTokenizer
import torch
from torch import nn
from sklearn.linear_model import LinearRegression
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def extract_token_i_hidden_states(
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        inputs: Union[str, List[str]],
        token_idx_to_extract: int = -1,
        batch_size: int = 1,
        layers_to_extract: Optional[List[int]] = None,
        return_dict: bool = True,
        verbose: bool = True,
) -> Union[Dict[int, torch.Tensor], torch.Tensor]:
    device = model.device
    model.eval()

    if isinstance(inputs, str):
        inputs = [inputs]

    if layers_to_extract is None:
        layers_to_extract = list(range(1, model.config.num_hidden_layers + 1))  # extract all but initial embeddings
    all_hidden_states = {layer: [] for layer in layers_to_extract}

    with torch.no_grad():
        for i in tqdm(range(0, len(inputs), batch_size), desc="Extracting hidden states", unit="batch", disable=not verbose):
            # Note: no padding is applied, so with batch_size > 1 all inputs in a batch
            # must tokenize to the same length.
            input_ids = tokenizer(inputs[i:i+batch_size], return_tensors="pt", return_attention_mask=False)['input_ids']
            outputs = model(input_ids.to(device), output_hidden_states=True)
            for layer in layers_to_extract:
                # hidden_states: (batch, seq_len, hidden_dim); keep only the requested token position
                hidden_states = outputs.hidden_states[layer]
                all_hidden_states[layer].append(hidden_states[:, token_idx_to_extract, :].detach().cpu())
    for layer in all_hidden_states:
        all_hidden_states[layer] = torch.concat(all_hidden_states[layer], dim=0)

    if not return_dict:
        # flatten the per-layer dict into a single (num_layers * num_inputs, hidden_dim) tensor
        all_hidden_states = torch.concat([all_hidden_states[layer] for layer in layers_to_extract], dim=0)

    return all_hidden_states
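
# Example usage (a minimal sketch; "gpt2" is an arbitrary illustrative checkpoint, not assumed by this module):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   hs = extract_token_i_hidden_states(model, tokenizer, ["The capital of France is", "2 + 2 ="])
#   # hs is a dict: hs[layer] has shape (num_inputs, hidden_dim), taken at the last token of each prompt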


def extract_vocab_hidden_states(
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        tokens_ids_to_extract: Optional[Iterable[int]] = None,
        prompt: str = "{target}",
        prompt_target: str = "{target}",
        batch_size: int = 128,
        layers_to_extract: Optional[List[int]] = None
) -> Dict[int, torch.Tensor]:
    device = model.device
    model.eval()

    if layers_to_extract is None:
        layers_to_extract = list(range(1, model.config.num_hidden_layers + 1))  # extract all but initial embeddings
    all_hidden_states = {layer: [] for layer in layers_to_extract}
    tokens_ids_to_extract = tokens_ids_to_extract if tokens_ids_to_extract is not None else range(tokenizer.vocab_size)
    tokens_to_extract = [tokenizer.decode(tok_id) for tok_id in tokens_ids_to_extract]

    # add a pad token if necessary and left-pad so that position -1 is always the true final prompt token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    with torch.no_grad():
        for i in tqdm(range(0, len(tokens_to_extract), batch_size), desc="Extracting hidden states", unit="batch"):
            prompts = [prompt.replace(prompt_target, target) for target in tokens_to_extract[i:i+batch_size]]
            batch = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
            # pass the attention mask along with input_ids so padded positions are ignored
            outputs = model(**batch, output_hidden_states=True)
            for layer in layers_to_extract:
                hidden_states = outputs.hidden_states[layer]
                all_hidden_states[layer].append(hidden_states[:, -1, :].detach().cpu())

    for layer in all_hidden_states:
        all_hidden_states[layer] = torch.concat(all_hidden_states[layer], dim=0)

    return all_hidden_states
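
# Example usage (a minimal sketch; the prompt template and min_word_len are illustrative choices):
#
#   token_ids = get_vocab_tokens(tokenizer, min_word_len=3)
#   vocab_hs = extract_vocab_hidden_states(
#       model, tokenizer,
#       tokens_ids_to_extract=token_ids,
#       prompt="The word '{target}' means",
#   )
#   # vocab_hs[layer] has shape (len(token_ids), hidden_dim): the hidden state at the final
#   # prompt position for each vocabulary token substituted into the template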


def get_vocab_tokens(tokenizer: PreTrainedTokenizer, min_word_len: Optional[int] = None):
    vocab_size = tokenizer.vocab_size
    tokens = list(range(vocab_size))
    if min_word_len:
        tokens_str = [tokenizer.decode(i) for i in tokens]
        tokens_len = [len(x) for x in tokens_str]
        tokens = [tok for tok, tok_len in zip(tokens, tokens_len) if tok_len >= min_word_len]
    return tokens
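
# Example (sketch): keep only token ids whose decoded string is at least 3 characters long,
# e.g. to filter out punctuation and short byte-level fragments.
#
#   long_tokens = get_vocab_tokens(tokenizer, min_word_len=3)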


def learn_linear_map(X: torch.Tensor, Y: torch.Tensor, fit_intercept=False):
    input_dtype = X.dtype
    linear_reg = LinearRegression(fit_intercept=fit_intercept).fit(X.cpu().to(float).numpy(), Y.cpu().to(float).numpy())
    linear_map = nn.Linear(X.size(1), Y.size(1), bias=fit_intercept)
    with torch.no_grad():
        # sklearn's coef_ has shape (out_features, in_features), which already matches nn.Linear's weight layout
        linear_map.weight.data = torch.tensor(linear_reg.coef_, dtype=torch.float32)
        if fit_intercept:
            linear_map.bias.data = torch.tensor(linear_reg.intercept_, dtype=torch.float32)
    linear_map = linear_map.to(input_dtype)
    return linear_map
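
# Example usage (a minimal sketch; layers 5 and 20 and the prompts are arbitrary choices):
#
#   hs = extract_token_i_hidden_states(model, tokenizer, prompts)  # prompts: list of strings
#   X, Y = hs[5], hs[20]                                           # map layer-5 states to layer-20 states
#   linear_map = learn_linear_map(X, Y)                            # closed-form least-squares fit
#   Y_pred = linear_map(X)                                         # linear_map is cast back to X's dtype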


def train_model(
    model,
    dataloader,
    optimizer,
    loss_func="mse",
    scheduler=None,
    num_epochs=5,
    gradient_accumulation_steps=1,
    max_grads_norm=1.0,
):
    """
    Trains a two-layer MLP to map hidden states from X to Y.

    Parameters:
        X (torch.Tensor): Input tensor of shape (N, D).
        Y (torch.Tensor): Target tensor of shape (N, D).
        activation_func (nn.Module): Activation function for the hidden layer. Default is SiLU.
        lr (float): Learning rate. Default is 0.001.
        weight_decay (float): Weight decay for the optimizer. Default is 0.0.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        lr_schedule (str): Learning rate schedule. Default is 'linear'.
        num_epochs (int): Number of training epochs. Default is 20.
        batch_size (int): Batch size for DataLoader. Default is 32.
        gradient_accumulation_steps (int): Number of steps to accumulate gradients. Default is 1.

    Returns:
        nn.Module: Trained MLP model.
    """
    # run on whatever device the model already lives on
    device = next(model.parameters()).device

    # Select loss function
    if loss_func == "mse":
        criterion = nn.MSELoss()
    elif loss_func == "huber":
        criterion = nn.HuberLoss()
    elif loss_func == "cosine":
        criterion = nn.CosineEmbeddingLoss()
    else:
        raise ValueError("Unsupported loss function. Choose from 'mse', 'huber', or 'cosine'.")

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for i, (x_batch, y_batch) in enumerate(dataloader):
            outputs = model(x_batch.to(device))
            if loss_func == "cosine":
                # Cosine loss requires an additional target tensor of 1s
                loss = criterion(outputs, y_batch.to(device), torch.ones(x_batch.size(0), device=device))
            else:
                loss = criterion(outputs, y_batch.to(device))

            loss = loss / gradient_accumulation_steps
            loss.backward()

            if max_grads_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_grads_norm)

            if (i + 1) % gradient_accumulation_steps == 0 or (i + 1) == len(dataloader):
                optimizer.step()
                optimizer.zero_grad()
                if scheduler:
                    scheduler.step()

            epoch_loss += loss.item() * gradient_accumulation_steps

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(dataloader):.6f}")

    return model.cpu()
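
# Example usage (a minimal sketch; X and Y are hidden-state tensors as produced above,
# and the batch size / learning rate are arbitrary):
#
#   mapper = nn.Linear(X.size(1), Y.size(1))                       # any nn.Module mapping X to Y works here
#   loader = DataLoader(TensorDataset(X, Y), batch_size=64, shuffle=True)
#   opt = optim.AdamW(mapper.parameters(), lr=1e-3)
#   mapper = train_model(mapper, loader, opt, loss_func="huber", num_epochs=3)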


def learn_mlp(
    X: torch.Tensor, Y: torch.Tensor,
    activation_func=nn.SiLU,
    batch_size=128,
    lr=0.001,
    weight_decay=0.0,
    loss_func="mse",
    lr_schedule="linear",
    expansion_alpha=1.0,
    num_epochs=5,
    gradient_accumulation_steps=1,
    max_grads_norm=1.0,
):
    """
    Trains a two-layer MLP to map hidden states from X to Y.

    Parameters:
        X (torch.Tensor): Input tensor of shape (N, D).
        Y (torch.Tensor): Target tensor of shape (N, D).
        activation_func (nn.Module): Activation function for the hidden layer. Default is SiLU.
        lr (float): Learning rate. Default is 0.001.
        weight_decay (float): Weight decay for the optimizer. Default is 0.0.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        lr_schedule (str): Learning rate schedule. Default is 'linear'.
        num_epochs (int): Number of training epochs. Default is 20.
        batch_size (int): Batch size for DataLoader. Default is 32.
        gradient_accumulation_steps (int): Number of steps to accumulate gradients. Default is 1.

    Returns:
        nn.Module: Trained MLP model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X.shape[1]
    hidden_dim = int(input_dim * expansion_alpha)
    output_dim = Y.shape[1]
    model = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                activation_func(),
                nn.Linear(hidden_dim, output_dim)
    ).to(device)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # DataLoader setup
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler
    if lr_schedule == "linear":
        total_steps = (len(dataloader) * num_epochs) // gradient_accumulation_steps
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1 - step / total_steps)
    else:
        scheduler = None

    return train_model(
        model,
        dataloader,
        optimizer,
        loss_func=loss_func,
        scheduler=scheduler,
        num_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_grads_norm=max_grads_norm,
    )
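
# Example usage (a minimal sketch; X and Y are hidden-state tensors, hyperparameters are arbitrary):
#
#   mlp = learn_mlp(X.float(), Y.float(), expansion_alpha=2.0, num_epochs=10, loss_func="mse")
#   Y_pred = mlp(X.float())                    # the trained MLP is returned on CPU in float32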


class FFN(nn.Module):
    """Gated residual map: SiLU(gate_proj(x)) * x + map_proj(x); input and output share the same dimension."""

    def __init__(self, input_dim):
        super(FFN, self).__init__()
        self.gate_proj = nn.Linear(input_dim, input_dim)
        self.activation = nn.SiLU()
        self.map_proj = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return (self.activation(self.gate_proj(x)) * x) + self.map_proj(x)
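
# Shape check (sketch; 768 is an arbitrary hidden size): the gated residual FFN maps (N, D) -> (N, D).
#
#   ffn = FFN(input_dim=768)
#   out = ffn(torch.randn(4, 768))             # out.shape == (4, 768)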


def learn_ffn(
    X: torch.Tensor, Y: torch.Tensor,
    activation_func=nn.SiLU,
    batch_size=128,
    lr=0.001,
    weight_decay=0.0,
    loss_func="mse",
    lr_schedule="linear",
    num_epochs=5,
    gradient_accumulation_steps=1,
    max_grads_norm=1.0,
):
    """
    Trains a two-layer MLP to map hidden states from X to Y.

    Parameters:
        X (torch.Tensor): Input tensor of shape (N, D).
        Y (torch.Tensor): Target tensor of shape (N, D).
        activation_func (nn.Module): Activation function for the hidden layer. Default is SiLU.
        lr (float): Learning rate. Default is 0.001.
        weight_decay (float): Weight decay for the optimizer. Default is 0.0.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        lr_schedule (str): Learning rate schedule. Default is 'linear'.
        num_epochs (int): Number of training epochs. Default is 20.
        batch_size (int): Batch size for DataLoader. Default is 32.
        gradient_accumulation_steps (int): Number of steps to accumulate gradients. Default is 1.

    Returns:
        nn.Module: Trained MLP model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X.shape[1]
    model = FFN(input_dim).to(device)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # DataLoader setup
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler
    if lr_schedule == "linear":
        total_steps = (len(dataloader) * num_epochs) // gradient_accumulation_steps
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1 - step / total_steps)
    else:
        scheduler = None

    return train_model(
        model,
        dataloader,
        optimizer,
        loss_func=loss_func,
        scheduler=scheduler,
        num_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_grads_norm=max_grads_norm,
    )
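
# Example usage (a minimal sketch; X and Y are same-dimension hidden-state tensors, hyperparameters are arbitrary):
#
#   ffn_map = learn_ffn(X.float(), Y.float(), lr=5e-4, num_epochs=10, loss_func="cosine")
#   Y_pred = ffn_map(X.float())                # the trained FFN is returned on CPU in float32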