from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from collections import deque
from einops import rearrange
from timm.models.layers import trunc_normal_
from IPython import embed
from torch import Tensor

from utils import (
    is_context_parallel_initialized,
    get_context_parallel_group,
    get_context_parallel_world_size,
    get_context_parallel_rank,
    get_context_parallel_group_rank,
)

from .context_parallel_ops import (
    conv_scatter_to_context_parallel_region,
    conv_gather_from_context_parallel_region,
    cp_pass_from_previous_rank,
)


def divisible_by(num, den):
    """Return whether ``num`` leaves no remainder when divided by ``den``."""
    remainder = num % den
    return remainder == 0

def cast_tuple(t, length = 1):
    """Return ``t`` unchanged if it is a tuple, else a tuple of ``length`` copies of ``t``."""
    if isinstance(t, tuple):
        return t
    return (t,) * length

def is_odd(n):
    """Return whether ``n`` is not a multiple of two."""
    even = divisible_by(n, 2)
    return not even


class CausalGroupNorm(nn.GroupNorm):
    """GroupNorm for ``b c t h w`` inputs that normalizes each frame separately.

    The time axis is folded into the batch axis before the regular
    ``nn.GroupNorm`` forward pass, so every frame is normalized as its own
    sample, and the original layout is restored afterwards.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Remember the frame count so the folded batch axis can be split again.
        num_frames = x.shape[2]
        per_frame = rearrange(x, 'b c t h w -> (b t) c h w')
        normalized = super().forward(per_frame)
        return rearrange(normalized, '(b t) c h w -> b c t h w', t=num_frames)


class CausalConv3d(nn.Module):
    """3D convolution over ``b c t h w`` inputs that is causal along time (dim 2).

    Spatial axes are padded symmetrically by half the kernel size, while the
    time axis is padded only at the front, so an output frame never depends on
    later input frames. Padding is applied explicitly with ``F.pad``; the
    wrapped ``nn.Conv3d`` itself uses ``padding=0``.

    Three execution paths exist in ``forward``:

    * context parallel (when ``is_context_parallel_initialized()`` is true),
    * whole-clip (default),
    * chunked inference (``temporal_chunk=True``), where the trailing frames of
      the previous chunk's conv input are cached and used as front context.

    Args:
        in_channels: Number of input channels, forwarded to ``nn.Conv3d``.
        out_channels: Number of output channels, forwarded to ``nn.Conv3d``.
        kernel_size: Int (used for all three axes) or ``(time, height, width)``.
            Height and width kernel sizes must be odd.
        stride: Int or ``(time, height, width)``. An int is applied to the
            time axis only; the spatial stride is then 1.
        pad_mode: ``F.pad`` mode used by the non-context-parallel paths.
        **kwargs: Extra ``nn.Conv3d`` arguments. ``dilation`` may be an int or
            a per-axis sequence ``(time, height, width)``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size: Union[int, Tuple[int, int, int]],
            stride: Union[int, Tuple[int, int, int]] = 1,
            pad_mode: str ='constant',
            **kwargs
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        self.time_kernel_size = time_kernel_size
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
        dilation = kwargs.pop('dilation', 1)
        self.pad_mode = pad_mode

        if isinstance(stride, int):
            # An integer stride only strides over time; space keeps stride 1.
            stride = (stride, 1, 1)

        # nn.Conv3d accepts a per-axis dilation; only the temporal entry is
        # relevant for the causal padding. Multiplying a tuple by an int would
        # repeat the tuple instead of scaling it, so pick the entry explicitly.
        if isinstance(dilation, (tuple, list)):
            time_dilation = dilation[0]
        else:
            time_dilation = dilation

        time_pad = time_dilation * (time_kernel_size - 1)
        # NOTE(review): spatial padding ignores dilation, so a spatial
        # dilation > 1 shrinks the output — confirm this is intended.
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.temporal_stride = stride[0]
        self.time_pad = time_pad
        # F.pad order is (w_left, w_right, h_top, h_bottom, t_front, t_back).
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
        self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)
        # Holds the trailing frames of the previous chunk during chunked inference.
        self.cache_front_feat = deque()

    def _clear_context_parallel_cache(self):
        """Drop any cached front-context frames by installing a fresh deque."""
        self.cache_front_feat = deque()

    def _init_weights(self, m):
        """Initialize module ``m``: truncated normal for linear/conv, identity for norms.

        Not invoked anywhere inside this class; presumably meant to be used
        with ``Module.apply`` by external code.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def context_parallel_forward(self, x):
        """Run the convolution when context parallelism is initialized.

        Temporal context comes from ``cp_pass_from_previous_rank`` (presumably
        the trailing frames of the previous rank, or padding on the first rank
        — confirm against that helper); only spatial padding is added here,
        always in ``'constant'`` mode.
        """
        x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)

        x = F.pad(x, self.time_uncausal_padding, mode='constant')

        cp_rank = get_context_parallel_rank()
        # For the stride-2 / kernel-3 case, ranks other than the first drop
        # their leading frame before the convolution.
        if cp_rank != 0 and self.temporal_stride == 2 and self.time_kernel_size == 3:
            x = x[:, :, 1:]

        x = self.conv(x)
        return x

    def forward(self, x, is_init_image=True, temporal_chunk=False):
        """Apply the causal convolution to ``x`` (layout ``b c t h w``).

        Args:
            x: Input tensor; time is dim 2.
            is_init_image: In chunked mode, whether ``x`` is the first chunk
                of a clip. The first chunk is causally padded and resets the
                cache; later chunks reuse the cached frames as front context.
            temporal_chunk: Enable chunked inference with the feature cache.
                Ignored when context parallelism is initialized.

        Returns:
            The output of the wrapped ``nn.Conv3d``.

        Raises:
            AssertionError: If ``temporal_chunk`` is used in training mode.
            IndexError: If a non-initial chunk is processed while the cache is
                empty (no initial chunk was processed before it).
        """
        if is_context_parallel_initialized():
            return self.context_parallel_forward(x)

        # Use the configured mode only when the clip has more frames than the
        # temporal padding; otherwise fall back to constant padding.
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'

        if not temporal_chunk:
            x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        else:
            assert not self.training, "The feature cache should not be used in training"
            if is_init_image:
                # First chunk: pad causally and start a fresh cache.
                x = F.pad(x, self.time_causal_padding, mode=pad_mode)
                self._clear_context_parallel_cache()
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())
            else:
                # Later chunks: pad space only; time context comes from the cache.
                x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
                if not self.cache_front_feat:
                    raise IndexError(
                        "CausalConv3d feature cache is empty: process the first chunk "
                        "with is_init_image=True before passing later chunks"
                    )
                video_front_context = self.cache_front_feat.pop()
                self._clear_context_parallel_cache()

                # Only kernel size 3 with temporal stride 1 or 2 prepends
                # cached frames; other configurations run without context.
                if self.temporal_stride == 1 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context, x], dim=2)
                elif self.temporal_stride == 2 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context[:, :, -1:], x], dim=2)

                # Keep the last two frames of this conv input for the next chunk.
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())

        x = self.conv(x)
        return x