# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def FeedForward(dim, mult=4):
    """Build a pre-norm two-layer MLP.

    Layout: LayerNorm(dim) -> Linear(dim -> int(dim * mult), no bias) -> GELU
    -> Linear(back to dim, no bias). No residual is added here; callers add
    it themselves.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)


class PerceiverAttention(nn.Module):
    """Multi-head cross-attention from latent queries onto input tokens.

    Queries are projected from ``latents``. Keys and values are projected from
    ``x``, or from ``x`` concatenated with ``latents`` along the token axis
    when ``concat_kv_latents`` is True. Both inputs are layer-normalised
    first, attention runs through ``F.scaled_dot_product_attention``, and the
    result is projected back to ``dim``. No residual is added here.
    """

    def __init__(
        self, *, dim, dim_head=64, heads=8, dropout_p=0.05, concat_kv_latents=True
    ):
        super().__init__()
        inner_dim = heads * dim_head
        # Not passed to SDPA; its default scale of 1/sqrt(head_dim) is the
        # same value because each head has ``dim_head`` channels.
        self.scale = dim_head**-0.5
        self.heads = heads

        self.norm_x = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Keys and values share one projection, split with ``chunk`` later.
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # Attention-probability dropout; only applied in training mode.
        self.dropout_p = dropout_p
        self.concat_kv_latents = concat_kv_latents

    def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Reshape (B, N, C) into (B, num_heads, N, C // num_heads)."""
        batch, tokens, channels = x.shape
        per_head = channels // num_heads
        split = x.reshape(batch, tokens, num_heads, per_head)
        return split.transpose(1, 2)

    def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_separate_heads``: (B, H, N, C_h) -> (B, N, H * C_h)."""
        batch, num_heads, tokens, per_head = x.shape
        token_major = x.transpose(1, 2)
        return token_major.reshape(batch, tokens, num_heads * per_head)

    def forward(self, latents, x, pos=None):
        latents = self.norm_latents(latents)
        x = self.norm_x(x)

        # Unlike the original Perceiver, the latents may also contribute
        # keys / values alongside the input tokens.
        kv_source = torch.cat((x, latents), dim=-2) if self.concat_kv_latents else x

        heads = self.heads
        q = self._separate_heads(self.to_q(latents), heads)
        k, v = (
            self._separate_heads(t, heads)
            for t in self.to_kv(kv_source).chunk(2, dim=-1)
        )

        if pos is not None:
            # ``pos`` is added element-wise to keys and values, so it is only
            # allowed when they come from ``x`` alone.
            assert not self.concat_kv_latents
            pos_heads = self._separate_heads(pos, heads)
            k = k + pos_heads
            v = v + pos_heads

        drop = self.dropout_p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=drop
        )
        return self.to_out(self._recombine_heads(attended))


class Attention(nn.Module):
    """Pre-norm multi-head self-attention over a token sequence.

    The input is layer-normalised and projected (without bias) to queries,
    keys and values. It is attended with ``F.scaled_dot_product_attention``
    and projected back to ``dim``. No residual is added here.
    """

    def __init__(self, *, dim, dim_head=64, heads=8, dropout_p=0.05):
        super().__init__()
        inner_dim = heads * dim_head
        # Not passed to SDPA; its default scale of 1/sqrt(head_dim) is the
        # same value because each head has ``dim_head`` channels.
        self.scale = dim_head**-0.5
        self.heads = heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Keys and values share one projection, split with ``chunk`` later.
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # Attention-probability dropout; only applied in training mode.
        self.dropout_p = dropout_p

    def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Reshape (B, N, C) into (B, num_heads, N, C // num_heads)."""
        batch, tokens, channels = x.shape
        per_head = channels // num_heads
        split = x.reshape(batch, tokens, num_heads, per_head)
        return split.transpose(1, 2)

    def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_separate_heads``: (B, H, N, C_h) -> (B, N, H * C_h)."""
        batch, num_heads, tokens, per_head = x.shape
        token_major = x.transpose(1, 2)
        return token_major.reshape(batch, tokens, num_heads * per_head)

    def forward(self, x):
        normed = self.norm(x)
        k, v = self.to_kv(normed).chunk(2, dim=-1)
        q, k, v = (
            self._separate_heads(t, self.heads) for t in (self.to_q(normed), k, v)
        )

        drop = self.dropout_p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=drop
        )
        return self.to_out(self._recombine_heads(attended))


class PerceiverEncoderLayer(nn.Module):
    """One Perceiver block operating on a set of latents.

    Stages (each residual):
      1. cross-attention from ``latents`` to ``x``, followed by dropout applied
         to the residual sum;
      2. an MLP on the latents;
      3. optionally (``use_self_attn``) self-attention among the latents and a
         second MLP.
    """

    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        hidden_dropout_p=0.0,
        attention_dropout_p=0.0,
        concat_kv_latents=False,
        use_self_attn=False,
    ):
        super().__init__()
        # Arguments shared by the cross- and self-attention sub-modules.
        attn_kwargs = dict(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            dropout_p=attention_dropout_p,
        )
        self.attn = PerceiverAttention(
            **attn_kwargs, concat_kv_latents=concat_kv_latents
        )
        self.ff = FeedForward(dim=dim, mult=ff_mult)
        self.dropout = nn.Dropout(hidden_dropout_p)
        self.use_self_attn = use_self_attn
        if use_self_attn:
            self.self_attn = Attention(**attn_kwargs)
            self.self_ff = FeedForward(dim=dim, mult=ff_mult)

    def forward(self, latents, x, pos=None):
        cross = self.attn(latents, x, pos)
        latents = self.dropout(cross + latents)
        latents = self.ff(latents) + latents
        if not self.use_self_attn:
            return latents
        latents = self.self_attn(latents) + latents
        return self.self_ff(latents) + latents


def window_partition(x, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C); H and W are assumed to be multiples of ``window_size``
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered by
        batch index, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    grid = x.view(
        batch,
        height // window_size,
        window_size,
        width // window_size,
        window_size,
        channels,
    )
    # Move the two window-grid axes next to the batch axis so each window's
    # pixels form one contiguous (window_size, window_size, C) tile.
    tiles = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels)


def window_reverse(windows, window_size, H, W):
    """
    Reassemble non-overlapping windows into full feature maps.

    This undoes the window ordering produced by ``window_partition``: batch
    index, then window row, then window column.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image (assumed to be a multiple of ``window_size``)
        W (int): Width of image (assumed to be a multiple of ``window_size``)

    Returns:
        x: (B, H, W, C)
    """
    rows, cols = H // window_size, W // window_size
    # The batch size is an exact count of images; compute it with integer
    # division instead of a float round-trip followed by truncation.
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class PerceiverResampler(nn.Module):
    """Compress a feature map into a fixed number of latent tokens.

    Two optional latent groups share the same stack of encoder layers:

    * ``num_latents`` > 0: learned global ("1d") latents that cross-attend to
      every spatial token of the input (``forward_1d``).
    * ``num_latents_2d`` > 0: one learned latent per spatial window. The input
      is split into a ``isqrt(num_latents_2d)`` x ``isqrt(num_latents_2d)``
      grid of windows, and each latent attends only to its own window
      (``forward_2d``).

    ``forward`` concatenates the enabled groups along the token dimension,
    1d latents first.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=1,
        num_latents=-1,
        num_latents_2d=-1,
        ff_mult=4,
        hidden_dropout_p=0.1,
        attention_dropout_p=0.05,
        pos_enc_at_key_value=False,
        concat_kv_latents=False,
        position_encoding=None,
        use_self_attn=False,
        **kwargs,
    ):
        """
        Args:
            dim: channel width of the latents and of the input tokens.
            depth: number of ``PerceiverEncoderLayer``s (shared by both groups).
            num_latents: number of global latents; group disabled when <= 0.
            num_latents_2d: number of windowed latents; group disabled when
                <= 0. Treated as a square grid by ``forward_2d``.
            pos_enc_at_key_value: stored on the module but not read by any
                method of this class; ``forward_1d`` uses ``pos`` whenever it
                is given.
            position_encoding: callable applied to the 2d latent grid to make
                its positional encoding; must be set when ``num_latents_2d`` > 0.
            **kwargs: accepted and ignored.
            Remaining arguments are forwarded to each ``PerceiverEncoderLayer``.
        """
        super().__init__()
        self.num_latents = num_latents
        self.num_latents_2d = num_latents_2d

        if num_latents > 0:
            self.latents = nn.Parameter(torch.randn(num_latents, dim))
        if num_latents_2d > 0:
            self.latents_2d = nn.Parameter(torch.randn(num_latents_2d, dim))
        self.position_encoding = position_encoding

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                PerceiverEncoderLayer(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    hidden_dropout_p=hidden_dropout_p,
                    attention_dropout_p=attention_dropout_p,
                    concat_kv_latents=concat_kv_latents,
                    use_self_attn=use_self_attn,
                ),
            )

        self.norm = nn.LayerNorm(dim)
        self.pos_enc_at_key_value = pos_enc_at_key_value

    def forward(self, x, pos=None):
        """Run the enabled latent groups and concatenate their outputs.

        Args:
            x: feature map; the permutes below treat it as (B, C, H, W).
            pos: optional positional encoding with the same layout as ``x``.

        Returns:
            ``(latents, pos)``. ``latents`` is the token-wise concatenation of
            the enabled groups. ``pos`` is None when no input ``pos`` was
            given. Otherwise it is the matching concatenation: zeros for the 1d
            latents and ``position_encoding`` output for the 2d latents.
        """
        out_latents = []
        out_pos = []
        if self.num_latents > 0:
            latents_1d, pos_1d = self.forward_1d(x, pos)
            out_latents.append(latents_1d)
            out_pos.append(pos_1d)
        if self.num_latents_2d > 0:
            latents_2d, pos_2d = self.forward_2d(x)
            out_latents.append(latents_2d)
            out_pos.append(pos_2d)

        latents = torch.concat(out_latents, dim=1)
        # The per-group positional encodings are only returned when the caller
        # supplied ``pos``.
        if pos is not None:
            pos = torch.concat(out_pos, dim=1)

        return latents, pos

    def forward_1d(self, x, pos):
        """Cross-attend the global latents to every spatial token of ``x``."""
        latents = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)
        x = x.permute(0, 2, 3, 1).flatten(1, 2)  # (B, H*W, C)

        # NOTE(review): ``self.pos_enc_at_key_value`` is not consulted; ``pos``
        # is added at the keys/values whenever it is provided. Confirm whether
        # the flag was meant to gate this before changing it, because doing so
        # would alter outputs for existing configs.
        _pos = pos.permute(0, 2, 3, 1).flatten(1, 2) if pos is not None else None

        for layer in self.layers:
            latents = layer(latents, x, _pos)

        # The global latents carry no positional information of their own.
        if pos is not None:
            pos = torch.zeros_like(latents)

        latents = self.norm(latents)
        return latents, pos

    def forward_2d(self, x):
        """Attend one latent per spatial window of ``x``.

        Assumes H == W and that H is divisible by ``isqrt(num_latents_2d)``.
        Otherwise the number of windows will not match the number of latents
        (TODO confirm this is guaranteed by callers).
        """
        B, C, H, W = x.shape

        # One query token per window: (B * num_latents_2d, 1, C). ``reshape``
        # is required here: the expanded tensor has stride 0 on the batch dim,
        # so ``view`` cannot merge it with the latent dim when B > 1.
        latents_2d = self.latents_2d.unsqueeze(0).expand(B, -1, -1).reshape(-1, 1, C)

        num_window = math.isqrt(self.num_latents_2d)
        window_size = H // num_window
        x = x.permute(0, 2, 3, 1)  # (B, H, W, C)

        x = window_partition(x, window_size)  # (num_windows*B, ws, ws, C)
        x = x.flatten(1, 2)  # (num_windows*B, ws*ws, C)

        for layer in self.layers:
            latents_2d = layer(latents_2d, x)

        # Back to a (B, C, num_window, num_window) grid for the position encoder.
        latents_2d = latents_2d.view(B, num_window, num_window, C).permute(0, 3, 1, 2)

        pos_2d = self.position_encoding(latents_2d)
        pos_2d = pos_2d.permute(0, 2, 3, 1).flatten(1, 2)

        latents_2d = latents_2d.permute(0, 2, 3, 1).flatten(1, 2)

        latents_2d = self.norm(latents_2d)

        return latents_2d, pos_2d