import os
import torch
import random

import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from safetensors import safe_open

from diffusers import AutoencoderKL
from diffusers import EulerDiscreteScheduler, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationFreeInitPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from diffusers.training_utils import set_seed

from animatediff.utils.freeinit_utils import get_freq_filter
from collections import namedtuple

# Folder with the Stable Diffusion v1.5 components loaded via `from_pretrained`
# (subfolders: tokenizer, text_encoder, vae, unet).
pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
# AnimateDiff inference config; provides `unet_additional_kwargs` and
# `noise_scheduler_kwargs`.
inference_config_path = "configs/inference/inference-v1.yaml"

# Compact square style for the dice (random seed) button, which opts in via
# elem_classes="toolbutton". The property was previously misspelled
# ("margin-buttom"), so browsers silently dropped it.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

# Per-example values that differ between the shipped examples:
# (personalized model, prompt, negative prompt, seed).
_EXAMPLE_SPECS = [
    # 0-RealisticVision
    (
        "realisticVisionV51_v20Novae.safetensors",
        "A panda standing on a surfboard in the ocean under moonlight.",
        "worst quality, low quality, nsfw, logo",
        "2005563494988190",
    ),
    # 1-ToonYou
    (
        "toonyou_beta3.safetensors",
        "(best quality, masterpiece), 1girl, looking at viewer, blurry background, upper body, contemporary, dress",
        "(worst quality, low quality)",
        "478028150728261",
    ),
    # 2-Lyriel
    (
        "lyriel_v16.safetensors",
        "hypercars cyberpunk moving, muted colors, swirling color smokes, legend, cityscape, space",
        "3d, cartoon, anime, sketches, worst quality, low quality, nsfw, logo",
        "1566149281915957",
    ),
    # 3-RCNZ
    (
        "rcnzCartoon3d_v10.safetensors",
        "A cute raccoon playing guitar in a boat on the ocean",
        "worst quality, low quality, nsfw, logo",
        "1566149281915957",
    ),
    # 4-MajicMix
    (
        "majicmixRealistic_v5Preview.safetensors",
        "1girl, reading book",
        "(ng_deepnegative_v1_75t:1.2), (badhandv4:1), (worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands, watermark, moles",
        "2005563494988190",
    ),
]

# Full example rows in the order of the `inputs` list built in ui():
# base model, motion module, prompt, negative prompt, width, height, seed,
# filter type, d_s, d_t, FreeInit iterations, speed-up options.
# Every row shares the same motion module, resolution and FreeInit settings.
# Each row gets its own fresh speed-up list.
examples = [
    [
        base_model, "mm_sd_v14.ckpt", prompt, negative_prompt,
        512, 512, seed,
        "butterworth", 0.25, 0.25, 3,
        ["use_fp16"],
    ]
    for base_model, prompt, negative_prompt, seed in _EXAMPLE_SPECS
]

# clean unrelated ckpts
# ckpts = [
#     "realisticVisionV40_v20Novae.safetensors",
#     "majicmixRealistic_v5Preview.safetensors",
#     "rcnzCartoon3d_v10.safetensors",
#     "lyriel_v16.safetensors",
#     "toonyou_beta3.safetensors"
# ]

# for path in glob(os.path.join("models", "DreamBooth_LoRA", "*.safetensors")):
#     for ckpt in ckpts:
#         if path.endswith(ckpt): break
#     else:
#         print(f"### Cleaning {path} ...")
#         os.system(f"rm -rf {path}")

# os.system(f"rm -rf {os.path.join('models', 'DreamBooth_LoRA', '*.safetensors')}")

# os.system(f"bash download_bashscripts/1-ToonYou.sh")
# os.system(f"bash download_bashscripts/2-Lyriel.sh")
# os.system(f"bash download_bashscripts/3-RcnzCartoon.sh")
# os.system(f"bash download_bashscripts/4-MajicMix.sh")
# os.system(f"bash download_bashscripts/5-RealisticVision.sh")

# # clean Gradio cache
# print(f"### Cleaning cached examples ...")
# os.system(f"rm -rf gradio_cached_examples/")


class AnimateController:
    """State and callbacks behind the FreeInit Gradio demo.

    A single instance owns the Stable Diffusion v1.5 components (tokenizer,
    text encoder, VAE and the AnimateDiff 3D UNet). It tracks which
    personalized model, motion module and FreeInit filter settings are
    currently applied, and exposes the methods wired to the UI.
    """

    # Frames per generated video. This is also the temporal size of the FreeInit
    # frequency filter, so the two must agree.
    video_length = 16

    def __init__(self):
        """Scan the model folders, load the base models onto CUDA and apply defaults.

        Raises:
            FileNotFoundError: if no personalized ``*.safetensors`` model or no
                motion-module ``*.ckpt`` file is found under ``models/``.
        """
        # config dirs
        self.basedir                = os.getcwd()
        self.stable_diffusion_dir   = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir      = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir                = os.path.join(self.basedir, "samples")
        os.makedirs(self.savedir, exist_ok=True)

        self.base_model_list    = []
        self.motion_module_list = []
        self.filter_type_list = [
            "butterworth",
            "gaussian",
            "box",
            "ideal"
        ]

        # Currently applied settings; animate() compares against these to skip
        # redundant weight reloads / filter rebuilds.
        self.selected_base_model    = None
        self.selected_motion_module = None
        self.selected_filter_type = None
        self.set_width = None
        self.set_height = None
        self.set_d_s = None
        self.set_d_t = None

        self.refresh_motion_module()
        self.refresh_personalized_model()
        # Fail fast with a clear message instead of an IndexError after the
        # slow model loading below.
        if not self.base_model_list:
            raise FileNotFoundError(
                f"No personalized models (*.safetensors) found in {self.personalized_model_dir}"
            )
        if not self.motion_module_list:
            raise FileNotFoundError(
                f"No motion modules (*.ckpt) found in {self.motion_module_dir}"
            )

        # config models
        self.inference_config      = OmegaConf.load(inference_config_path)

        self.tokenizer             = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
        self.text_encoder          = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
        self.vae                   = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
        self.unet                  = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()

        self.freq_filter = None

        # Default to the second-to-last model in sorted order (the historical
        # default). Fall back to the only model when just one is installed.
        if len(self.base_model_list) >= 2:
            default_base_model = self.base_model_list[-2]
        else:
            default_base_model = self.base_model_list[0]
        self.update_base_model(default_base_model)
        self.update_motion_module(self.motion_module_list[0])
        self.update_filter(512, 512, self.filter_type_list[0], 0.25, 0.25)

    def refresh_motion_module(self):
        """Re-scan ``motion_module_dir`` and store the sorted ``*.ckpt`` file names."""
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = sorted([os.path.basename(p) for p in motion_module_list])

    def refresh_personalized_model(self):
        """Re-scan ``personalized_model_dir`` and store the sorted ``*.safetensors`` file names."""
        base_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.base_model_list = sorted([os.path.basename(p) for p in base_model_list])

    @staticmethod
    def _checked_path(directory, filename, known_files):
        """Join ``filename`` onto ``directory`` after checking it is a discovered file.

        Dropdown values arrive from the browser. Only names found by the
        refresh_* scans are accepted, so a crafted request cannot point
        ``torch.load`` (which unpickles) or ``safe_open`` at an arbitrary path.

        Raises:
            ValueError: if ``filename`` is not in ``known_files``.
        """
        if filename not in known_files:
            raise ValueError(f"Unknown model file {filename!r}; expected one of {list(known_files)}")
        return os.path.join(directory, filename)

    @staticmethod
    def _resolve_seed(seed_textbox):
        """Return the seed typed by the user, or a fresh random seed.

        A value that parses to a positive integer is used as-is. Anything else
        yields a random integer in ``[1, 10**16]``. That covers 0, negative
        values, empty text and non-numeric text.
        """
        try:
            seed = int(seed_textbox)
        except (TypeError, ValueError):
            seed = -1
        if seed > 0:
            return seed
        # animate() calls set_seed(42), which reseeds the global `random` module.
        # Drawing from it would return the same "random" seed on every request,
        # so use an independently (OS-entropy) seeded generator. The integer
        # upper bound is required: randint rejects floats such as 1e16 on
        # Python >= 3.12.
        return random.Random().randint(1, 10**16)

    def update_base_model(self, base_model_dropdown):
        """Load a personalized ``.safetensors`` checkpoint into the VAE, UNet and text encoder.

        Args:
            base_model_dropdown: file name (not a path) of one of ``self.base_model_list``.

        Returns:
            A no-op Gradio update for the dropdown that triggered the call.

        Raises:
            ValueError: if the name is not one of the discovered models.
        """
        model_path = self._checked_path(self.personalized_model_dir, base_model_dropdown, self.base_model_list)

        with safe_open(model_path, framework="pt", device="cpu") as f:
            base_model_state_dict = {key: f.get_tensor(key) for key in f.keys()}

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)

        # strict=False: the 3D UNet also carries motion-module weights (loaded
        # separately in update_motion_module), which an image checkpoint lacks.
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # Replaces the text encoder object. It is moved to CUDA when animate()
        # builds the pipeline with .to("cuda").
        self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)

        # Record the selection only once every component has loaded, so a failed
        # load is retried on the next request instead of being skipped.
        self.selected_base_model = base_model_dropdown
        return gr.update()

    def update_motion_module(self, motion_module_dropdown):
        """Load a motion-module ``.ckpt`` state dict into the 3D UNet.

        Args:
            motion_module_dropdown: file name (not a path) of one of ``self.motion_module_list``.

        Returns:
            A no-op Gradio update for the dropdown that triggered the call.

        Raises:
            ValueError: if the name is not one of the discovered motion modules.
            RuntimeError: if the checkpoint contains keys the UNet does not have.
        """
        module_path = self._checked_path(self.motion_module_dir, motion_module_dropdown, self.motion_module_list)
        # NOTE: torch.load unpickles the file. Only files discovered under
        # models/Motion_Module get here (checked above).
        motion_module_state_dict = torch.load(module_path, map_location="cpu")
        _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
        if unexpected:
            raise RuntimeError(
                f"Motion module {motion_module_dropdown!r} has {len(unexpected)} keys the UNet "
                f"does not expect, e.g. {list(unexpected)[:5]}"
            )
        self.selected_motion_module = motion_module_dropdown
        return gr.update()

    def update_filter(self, width_slider, height_slider, filter_type_dropdown, d_s_slider, d_t_slider):
        """Rebuild the FreeInit frequency filter for a video size and cutoff setting.

        Args:
            width_slider, height_slider: output video size in pixels.
            filter_type_dropdown: one of ``self.filter_type_list``.
            d_s_slider, d_t_slider: spatial / temporal stop frequencies (0.0-1.0).
        """
        # Spatial downsampling factor of the VAE (8 for the usual 4-level SD VAE).
        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        # Latent-space shape. Assumed layout (batch, channels, frames, h, w) —
        # confirm in get_freq_filter. int() guards against float slider values,
        # since tensor sizes must be integers.
        shape = [
            1,
            4,
            self.video_length,
            int(height_slider) // vae_scale_factor,
            int(width_slider) // vae_scale_factor,
        ]
        self.freq_filter = get_freq_filter(
            shape,
            device="cuda",
            filter_type=filter_type_dropdown,
            n=4,  # presumably the Butterworth order; unchanged from the original
            d_s=d_s_slider,
            d_t=d_t_slider
        )

        # Record the settings only after the filter was built successfully.
        self.set_width = width_slider
        self.set_height = height_slider
        self.selected_filter_type = filter_type_dropdown
        self.set_d_s = d_s_slider
        self.set_d_t = d_t_slider

    def animate(
        self,
        base_model_dropdown,
        motion_module_dropdown,
        prompt_textbox,
        negative_prompt_textbox,
        width_slider,
        height_slider,
        seed_textbox,
        # freeinit params
        filter_type_dropdown,
        d_s_slider,
        d_t_slider,
        num_iters_slider,
        # speed up
        speed_up_options
    ):
        """Generate one video with plain AnimateDiff and one with FreeInit.

        Both videos are written to ``self.savedir`` (``sample_orig.mp4`` and
        ``sample.mp4``). Those files are overwritten on every call.

        Returns:
            Gradio updates for the AnimateDiff video, the FreeInit video and a
            JSON summary of the settings that were used.
        """
        # set global seed
        set_seed(42)

        d_s = float(d_s_slider)
        d_t = float(d_t_slider)
        num_iters = int(num_iters_slider)
        speed_up_options = speed_up_options or []
        use_fp16 = "use_fp16" in speed_up_options
        use_coarse_to_fine_sampling = "use_coarse_to_fine_sampling" in speed_up_options

        # Reload weights / rebuild the filter only when the selection changed.
        if self.selected_base_model != base_model_dropdown:
            self.update_base_model(base_model_dropdown)
        if self.selected_motion_module != motion_module_dropdown:
            self.update_motion_module(motion_module_dropdown)

        requested_filter = (width_slider, height_slider, filter_type_dropdown, d_s, d_t)
        current_filter = (self.set_width, self.set_height, self.selected_filter_type, self.set_d_s, self.set_d_t)
        if requested_filter != current_filter:
            self.update_filter(*requested_filter)

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationFreeInitPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        # (freeinit) frequency filter used for noise reinitialization
        pipeline.freq_filter = self.freq_filter

        seed = self._resolve_seed(seed_textbox)
        # The pipeline is not given a generator, so its noise presumably comes
        # from torch's global RNG. Reseed it (after set_seed(42)) with the
        # chosen seed.
        torch.manual_seed(seed)

        sample_output = pipeline(
            prompt_textbox,
            negative_prompt     = negative_prompt_textbox,
            num_inference_steps = 25,
            guidance_scale      = 7.5,
            width               = width_slider,
            height              = height_slider,
            video_length        = self.video_length,
            num_iters           = num_iters,
            use_fast_sampling   = use_coarse_to_fine_sampling,
            save_intermediate   = False,
            return_orig         = True,
            use_fp16            = use_fp16
        )
        orig_sample = sample_output.orig_videos
        sample = sample_output.videos

        save_sample_path = os.path.join(self.savedir, "sample.mp4")
        save_videos_grid(sample, save_sample_path)

        save_orig_sample_path = os.path.join(self.savedir, "sample_orig.mp4")
        save_videos_grid(orig_sample, save_orig_sample_path)

        json_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "width": width_slider,
            "height": height_slider,
            "seed": seed,
            "base_model": base_model_dropdown,
            "motion_module": motion_module_dropdown,
            "filter_type": filter_type_dropdown,
            "d_s": d_s,
            "d_t": d_t,
            "num_iters": num_iters,
            "use_fp16": use_fp16,
            "use_coarse_to_fine_sampling": use_coarse_to_fine_sampling
        }
        print(json_config)

        # gr.update works on both Gradio 3.x and 4.x (Component.update was removed in 4.0).
        return (
            gr.update(value=save_orig_sample_path),
            gr.update(value=save_sample_path),
            gr.update(value=json_config),
        )
        

# Module-level instance shared by every UI callback built in ui(). Constructing
# it loads the Stable Diffusion components onto CUDA at import time.
controller = AnimateController()


def ui():
    """Build the FreeInit Gradio ``Blocks`` app wired to the module-level ``controller``.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    # The dice button needs its own RNG: controller.animate() calls set_seed(42),
    # which resets the global `random` module, so the dice would otherwise
    # repeat the same value after every generation.
    seed_rng = random.Random()

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            <div align="center">
            <h1>FreeInit</h1>
            </div>
            """
        )
        gr.Markdown(
            """
            <p align="center">
                    <a title="Project Page" href="https://tianxingwu.github.io/pages/FreeInit/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                        <img src="https://img.shields.io/badge/Project-Website-5B7493?logo=googlechrome&logoColor=5B7493">
                    </a>
                    <a title="arXiv" href="https://arxiv.org/abs/2312.07537" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                        <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=b31b1b">
                    </a>
                    <a title="GitHub" href="https://github.com/TianxingWu/FreeInit" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                        <img src="https://img.shields.io/github/stars/TianxingWu/FreeInit?label=GitHub%20%E2%98%85&&logo=github" alt="badge-github-stars">
                    </a>
                    <a title="Video" href="https://youtu.be/lS5IYbAqriI" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                        <img src="https://img.shields.io/badge/YouTube-Video-red?logo=youtube&logoColor=red">
                    </a>
                    <a title="Visitor" href="https://hits.seeyoufarm.com" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                        <img src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fhuggingface.co%2Fspaces%2FTianxingWu%2FFreeInit&count_bg=%23678F74&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false">
                    </a>
            </p>
            """
        )
        gr.Markdown(
            """
            Official Gradio Demo for ***FreeInit: Bridging Initialization Gap in Video Diffusion Models***.
            FreeInit improves time consistency of diffusion-based video generation at inference time. In this demo, we apply FreeInit on [AnimateDiff v1](https://github.com/guoyww/AnimateDiff) as an example. Sampling time: ~ 80s.<br>
            """
        )

        with gr.Row():
            # Left column: prompts and settings.
            with gr.Column():
                prompt_textbox          = gr.Textbox( label="Prompt",          lines=3, placeholder="Enter your prompt here")
                negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")

                gr.Markdown(
                    """
                    *Prompt Tips:*

                    For each personalized model in `Model Settings`, you can refer to their webpage on CivitAI to learn how to write good prompts for them:
                    - [`realisticVisionV51_v20Novae.safetensors`](https://civitai.com/models/4201?modelVersionId=130072)
                    - [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775)
                    - [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel)
                    - [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009)
                    - [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068)   
                    """
                )

                with gr.Accordion("Model Settings", open=False):
                    gr.Markdown(
                        """
                        Select personalized model and motion module for AnimateDiff.
                        """
                        )
                    # Defaults mirror whatever the controller actually loaded at startup.
                    base_model_dropdown     = gr.Dropdown( label="Base DreamBooth Model", choices=controller.base_model_list,    value=controller.selected_base_model,    interactive=True,
                                                          info="Select personalized text-to-image model from community")
                    motion_module_dropdown  = gr.Dropdown( label="Motion Module",  choices=controller.motion_module_list, value=controller.selected_motion_module, interactive=True,
                                                          info="Select motion module. Recommend mm_sd_v14.ckpt for larger movements.")

                # Load new weights as soon as the selection changes.
                base_model_dropdown.change(fn=controller.update_base_model,       inputs=[base_model_dropdown],    outputs=[base_model_dropdown])
                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])

                with gr.Accordion("FreeInit Params", open=False):
                    gr.Markdown(
                        """
                        Adjust to control the smoothness.
                        """
                    )
                    filter_type_dropdown    = gr.Dropdown( label="Filter Type",  choices=controller.filter_type_list, value=controller.filter_type_list[0], interactive=True,
                                                          info="Default as Butterworth. To fix large inconsistencies, consider using Gaussian.")
                    d_s_slider             = gr.Slider( label="d_s",  value=0.25, minimum=0, maximum=1, step=0.125,
                                                       info="Stop frequency for spatial dimensions (0.0-1.0)")
                    d_t_slider             = gr.Slider( label="d_t",  value=0.25, minimum=0, maximum=1, step=0.125,
                                                       info="Stop frequency for temporal dimension (0.0-1.0)")
                    num_iters_slider        = gr.Slider( label="FreeInit Iterations", value=3, minimum=2, maximum=5, step=1,
                                                        info="Larger value leads to smoother results & longer inference time.")

                with gr.Accordion("Advanced", open=False):
                    with gr.Row():
                        width_slider  = gr.Slider(  label="Width",  value=512, minimum=256, maximum=1024, step=64 )
                        height_slider = gr.Slider(  label="Height", value=512, minimum=256, maximum=1024, step=64 )
                    with gr.Row():
                        seed_textbox = gr.Textbox( label="Seed",  value=2005563494988190)
                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # Integer bound required: randint rejects floats (e.g. 1e16) on Python >= 3.12.
                        seed_button.click(fn=lambda: gr.update(value=seed_rng.randint(1, 10**16)), inputs=[], outputs=[seed_textbox])
                    with gr.Row():
                        speed_up_options = gr.CheckboxGroup(
                            ["use_fp16", "use_coarse_to_fine_sampling"],
                            label="Speed-Up Options",
                            value=["use_fp16"]
                        )

                generate_button = gr.Button( value="Generate", variant='primary' )

            # Right column: side-by-side results and the settings used.
            with gr.Column():
                with gr.Row():
                    orig_video = gr.Video( label="AnimateDiff", interactive=False )
                    freeinit_video = gr.Video( label="AnimateDiff + FreeInit", interactive=False )
                with gr.Row():
                    json_config  = gr.Json( label="Config", value=None )

            # Order must match the positional parameters of controller.animate().
            inputs  = [base_model_dropdown, motion_module_dropdown,
                       prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox,
                       filter_type_dropdown, d_s_slider, d_t_slider, num_iters_slider,
                       speed_up_options
                       ]
            outputs = [orig_video, freeinit_video, json_config]

            generate_button.click( fn=controller.animate, inputs=inputs, outputs=outputs )

        gr.Examples( fn=controller.animate, examples=examples, inputs=inputs, outputs=outputs, cache_examples="lazy")

    return demo


if __name__ == "__main__":
    # Keep the `demo` global: Gradio's reload mode looks it up by that name.
    demo = ui()
    # Allow at most 20 pending requests in the queue, then start serving.
    # Blocks.queue returns the Blocks itself, so launch() can be chained.
    demo.queue(max_size=20).launch()