# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, List, Union

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin


from .attention import LinearTransformerBlock, t2i_modulate
from .lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder


def cross_norm(hidden_states, controlnet_input):
    """Rescale `controlnet_input` (N x T x C) so its per-sample mean/std match those of `hidden_states`."""
    mean_hidden_states = hidden_states.mean(dim=(1, 2), keepdim=True)
    std_hidden_states = hidden_states.std(dim=(1, 2), keepdim=True)
    mean_controlnet_input = controlnet_input.mean(dim=(1, 2), keepdim=True)
    std_controlnet_input = controlnet_input.std(dim=(1, 2), keepdim=True)
    controlnet_input = (controlnet_input - mean_controlnet_input) * (
        std_hidden_states / (std_controlnet_input + 1e-12)
    ) + mean_hidden_states
    return controlnet_input
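
# A minimal usage sketch of `cross_norm` (shapes are illustrative assumptions): after
# rescaling, the controlnet features carry the per-sample mean/std of the hidden states.
#
#     hidden = torch.randn(2, 128, 1536)            # N x T x C
#     control = torch.randn(2, 128, 1536) * 5 + 3   # mismatched statistics
#     control = cross_norm(hidden, control)
#     # control.mean(dim=(1, 2)) ~ hidden.mean(dim=(1, 2))
#     # control.std(dim=(1, 2))  ~ hidden.std(dim=(1, 2))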


# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
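
# A small sketch of how this rotary cache is queried inside the model (shapes are
# illustrative assumptions): `forward` returns cos/sin tables of shape (seq_len, dim)
# that the attention blocks apply to queries and keys.
#
#     rope = Qwen2RotaryEmbedding(dim=64, max_position_embeddings=32768, base=1_000_000.0)
#     x = torch.randn(2, 24, 512, 64)          # bs x heads x seq_len x head_dim
#     cos, sin = rope(x, seq_len=x.shape[-2])  # each of shape (512, 64)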


class T2IFinalLayer(nn.Module):
    """
    The final layer, adapted from Sana's T2IFinalLayer: modulate, project, and unpatchify the
    transformer output back into a latent of the requested length.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        # unpatchify: (N, T, patch_h * patch_w * C) -> (N, C, patch_h, T * patch_w)
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
        ).contiguous()
        if width > new_width:
            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
        elif width < new_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output
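
# Shape sketch for `T2IFinalLayer` (illustrative, assuming patch_size=[16, 1] and
# out_channels=8 as used by the model below): tokens of width patch_h * patch_w * C
# are folded back into a (N, C, patch_h, output_length) latent.
#
#     layer = T2IFinalLayer(hidden_size=1536, patch_size=[16, 1], out_channels=8)
#     x = torch.randn(2, 256, 1536)              # N x T x hidden_size
#     t = torch.randn(2, 1536)                   # embedded timestep
#     out = layer(x, t, output_length=256)       # (2, 8, 16, 256)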


class PatchEmbed(nn.Module):
    """2D latent to patch embedding: splits an N x C x H x W latent into patches and projects them to tokens."""

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
    ):
        super().__init__()
        patch_size_h, patch_size_w = patch_size
        self.early_conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias),
            torch.nn.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True),
            nn.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        )
        self.patch_size = patch_size
        self.height, self.width = height // patch_size_h, width // patch_size_w
        self.base_size = self.width

    def forward(self, latent):
        # early convolutions: (N, C, H, W) -> (N, embed_dim, H / patch_h, W / patch_w)
        latent = self.early_conv_layers(latent)
        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return latent
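
# Shape sketch for `PatchEmbed` (illustrative, default arguments): an 8-channel latent of
# size 16 x W is cut into (16, 1) patches and projected to one embed_dim token per width step.
#
#     embed = PatchEmbed(height=16, width=4096, patch_size=(16, 1), in_channels=8, embed_dim=1536)
#     latent = torch.randn(2, 8, 16, 512)        # N x C x H x W
#     tokens = embed(latent)                     # (2, 512, 1536)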


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`ACEStepTransformer2DModel`].

    Args:
        sample (`torch.FloatTensor`):
            The predicted latent sample.
        proj_losses (`tuple`, *optional*):
            `(ssl_name, loss)` pairs of SSL projection losses, populated when `ssl_hidden_states` are provided.
    """

    sample: torch.FloatTensor
    proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None


class ACEStepTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        **kwargs,
    ):
        super().__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
        )

        # Define input layers
        self.in_channels = in_channels

        # Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )
        self.num_layers = num_layers

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))

        # speaker
        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)

        # genre
        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)

        # lyric
        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)

        projector_dim = 2 * self.inner_dim

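        # SSL projection heads: map intermediate transformer hidden states into each SSL
        # feature space (one head per entry in `ssl_latent_dims`, e.g. MERT / m-hubert),
        # used for the cosine projection losses computed in `decode`.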
        self.projectors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.inner_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, ssl_dim),
            ) for ssl_dim in ssl_latent_dims
        ])

        self.ssl_latent_dims = ssl_latent_dims
        self.ssl_encoder_depths = ssl_encoder_depths

        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
        self.ssl_names = ssl_names

        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            bias=True,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
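
        Example:
            A minimal sketch (illustrative arguments, not tuned recommendations):

                model.enable_forward_chunking(chunk_size=1, dim=1)  # chunk the feed-forward layers along the sequence axis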
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

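    # Lyric branch sketch (illustrative shapes): lyric token ids (N x T_lyric) are embedded,
    # encoded by the conformer `LyricEncoder`, and projected to `inner_dim`, yielding
    # (N, T_lyric, inner_dim) states for cross-attention.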
    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx)
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out

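    # Conditioning layout sketch (shapes are illustrative assumptions): `encode` concatenates
    # [speaker (1 token)] + [genre/text tokens] + [lyric tokens] along the sequence axis,
    # together with a matching 0/1 mask. E.g. with 32 text tokens and 128 lyric tokens:
    #
    #     encoder_hidden_states: (N, 1 + 32 + 128, inner_dim)
    #     encoder_hidden_mask:   (N, 1 + 32 + 128)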
    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        """Build the cross-attention conditioning sequence and mask from speaker, genre/text, and lyric inputs."""
        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device
        
        # speaker embedding
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
        speaker_mask = torch.ones(bs, 1, device=device)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
        encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
        return encoder_hidden_states, encoder_hidden_mask

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """Run the diffusion transformer over the latent tokens with the given conditioning and timestep."""
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        inner_hidden_states = []

        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    # forward positional *and* keyword arguments, since the checkpointed
                    # call below passes the block inputs as keyword arguments
                    def custom_forward(*inputs, **kwargs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict, **kwargs)
                        else:
                            return module(*inputs, **kwargs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                )

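            # collect intermediate hidden states for the SSL projection losses
            # (depths may repeat, giving one entry per SSL target/projector)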
            for ssl_encoder_depth in self.ssl_encoder_depths:
                if index_block == ssl_encoder_depth:
                    inner_hidden_states.append(hidden_states)

        proj_losses = []
        if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:

            for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
                if ssl_hidden_state is None:
                    continue
                # project intermediate hidden states (N x T x inner_dim) into the SSL feature space (N x T x ssl_dim)
                est_ssl_hidden_state = projector(inner_hidden_state)
                # cosine projection loss, averaged over the batch
                bs = inner_hidden_state.shape[0]
                proj_loss = 0.0
                for z, z_tilde in zip(ssl_hidden_state, est_ssl_hidden_state):
                    # interpolate the estimate along time to match the SSL target length
                    z_tilde = F.interpolate(z_tilde.unsqueeze(0).transpose(1, 2), size=len(z), mode='linear', align_corners=False).transpose(1, 2).squeeze(0)

                    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
                    z = torch.nn.functional.normalize(z, dim=-1)
                    # target of ones: maximize cosine similarity between z and z_tilde
                    target = torch.ones(z.shape[0], device=z.device)
                    proj_loss += self.cosine_loss(z, z_tilde, target)
                proj_losses.append((ssl_name, proj_loss / bs))

        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        if not return_dict:
            return (output, proj_losses)

        return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)

    # @torch.compile
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        timestep: Optional[torch.Tensor] = None,
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        """
        Encode the speaker, genre/text, and lyric conditioning, then run the transformer over the latent
        `hidden_states` and return the projected output (plus SSL projection losses when `ssl_hidden_states` are provided).
        """
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            ssl_hidden_states=ssl_hidden_states,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
            return_dict=return_dict,
        )

        return output
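
# End-to-end usage sketch (all shapes and values below are illustrative assumptions,
# not settings from a released ACE-Step checkpoint):
#
#     model = ACEStepTransformer2DModel()
#     latents = torch.randn(1, 8, 16, 512)            # N x C x H x W audio latents
#     attn_mask = torch.ones(1, 512)
#     text = torch.randn(1, 32, 768)                  # genre/text encoder states
#     text_mask = torch.ones(1, 32)
#     speaker = torch.randn(1, 512)
#     lyrics = torch.randint(0, 6681, (1, 128))
#     lyric_mask = torch.ones(1, 128, dtype=torch.long)
#     timestep = torch.tensor([0.5])
#     out = model(
#         hidden_states=latents, attention_mask=attn_mask,
#         encoder_text_hidden_states=text, text_attention_mask=text_mask,
#         speaker_embeds=speaker, lyric_token_idx=lyrics, lyric_mask=lyric_mask,
#         timestep=timestep,
#     ).sample                                        # (1, 8, 16, 512)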