import gradio as gr
import numpy as np
import torch
import cv2
import os
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import LineartDetector
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize

from NaRCan_model import Homography, Siren
from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting



def get_example():
    """Return the example inputs for the ``gr.Examples`` widget.

    Each entry is a one-element list holding the path of a bundled example
    video, in the order the examples are shown in the UI. A fresh list is
    built on every call.
    """
    example_paths = (
        'examples/bear.mp4',
        'examples/boat.mp4',
        'examples/woman-drink.mp4',
        'examples/corgi.mp4',
        'examples/yacht.mp4',
        'examples/koolshooters.mp4',
        'examples/overlook-the-ocean.mp4',
        'examples/rotate.mp4',
        'examples/shark-ocean.mp4',
        'examples/surf.mp4',
        'examples/cactus.mp4',
        'examples/gold-fish.mp4',
    )
    return [[path] for path in example_paths]


def set_default_prompt(video_name):
    """Look up the suggested edit prompt for a bundled example video.

    Args:
        video_name: Base file name of the video, e.g. ``'bear.mp4'``.

    Returns:
        The default prompt for that example, or ``''`` when the name is not
        one of the bundled examples.
    """
    default_prompts = {
        'bear.mp4': 'bear, Van Gogh Style',
        'boat.mp4': 'a burning boat sails on lava',
        'cactus.mp4': 'cactus, made of paper',
        'corgi.mp4': 'a hellhound',
        'gold-fish.mp4': 'Goldfish in the Milky Way',
        'koolshooters.mp4': 'Avatar',
        'overlook-the-ocean.mp4': 'ocean, pixel style',
        'rotate.mp4': 'turbine engine',
        'shark-ocean.mp4': 'A sleek shark, cartoon style',
        'surf.mp4': 'Sailing, The background is a large white cloud, sketch style',
        'woman-drink.mp4': 'a drinking zombie',
        'yacht.mp4': 'yacht, cyberpunk style',
    }
    try:
        return default_prompts[video_name]
    except KeyError:
        return ''


def update_prompt(input_video):
    """Return the default prompt for the video currently in the player.

    Args:
        input_video: File path of the selected video as delivered by the
            ``gr.Video`` change event; may be ``None``/empty when the player
            holds no video.

    Returns:
        The suggested prompt for a bundled example video, otherwise ``''``.
    """
    if not input_video:
        # No video selected: nothing to look up (and None has no path).
        return ''
    # basename also handles the platform's native separator, unlike split('/').
    video_name = os.path.basename(input_video)
    return set_default_prompt(video_name)


# Per-example assets, keyed by video file name. Each value is
# [canonical image path, checkpoint directory (homography_g.pth / mlp_g.pth),
#  extracted-frames directory], all derived from the same file stem.
video_to_image = {
    f'{stem}.mp4': [f'canonical/{stem}.png', f'pth_file/{stem}', f'examples_frames/{stem}']
    for stem in (
        'bear',
        'boat',
        'cactus',
        'corgi',
        'gold-fish',
        'koolshooters',
        'overlook-the-ocean',
        'rotate',
        'shark-ocean',
        'surf',
        'woman-drink',
        'yacht',
    )
}


def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode frames into an H.264 (libx264) video at ``output_path``.

    Args:
        image_list: Iterable of frames (PIL images or arrays) that
            ``np.array`` can convert; each frame is cast to ``uint8``.
        output_path: Destination video file.
        fps: Frame rate of the written video.
        max_frames: Only the first ``max_frames`` frames are written
            (default 20, the previous hard-coded limit). ``None`` writes
            every frame.
    """
    # Trim first so frames that will be dropped are never converted.
    frames = [np.array(img).astype(np.uint8) for img in list(image_list)[:max_frames]]

    # The context manager closes and finalizes the file even if a write fails.
    with imageio.get_writer(output_path, fps=fps, codec='libx264') as writer:
        for frame in frames:
            writer.append_data(frame)


def _load_narcan_models(pth_path):
    """Restore the two NaRCan deformation networks from ``pth_path``.

    Args:
        pth_path: Directory holding ``homography_g.pth`` and ``mlp_g.pth``.

    Returns:
        ``(g_old, g)``: the ``Homography`` network and the ``Siren`` MLP whose
        output is added to the homography warp. Both are on CUDA in eval mode.
    """
    checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"))
    checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"))
    g_old = Homography(hidden_features=256, hidden_layers=2).cuda()
    g = Siren(in_features=3, out_features=2, hidden_features=256,
              hidden_layers=5, outermost_linear=True).cuda()
    g_old.load_state_dict(checkpoint_g_old)
    g.load_state_dict(checkpoint_g)
    g_old.eval()
    g.eval()
    return g_old, g


def NaRCan_make_video(edit_canonical, pth_path, frames_path):
    """Warp an edited canonical image into every frame and encode a video.

    Args:
        edit_canonical: Edited canonical image (PIL image; converted to RGB).
        pth_path: Directory with this video's ``homography_g.pth`` /
            ``mlp_g.pth`` checkpoints.
        frames_path: Directory of the video's frames, read by
            ``TestVideoFitting``; its number of entries is used as the frame
            count.

    Returns:
        Path of the written video (``'NaRCan_fps_10.mp4'``, relative to the
        working directory). NOTE(review): the path is fixed, so concurrent
        requests would overwrite each other's output.

    Raises:
        ValueError: If ``frames_path`` contains no entries.
    """
    g_old, g = _load_narcan_models(pth_path)

    transform = Compose([
        Resize(512),
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
    ])
    v = TestVideoFitting(frames_path, transform)
    videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)

    # One row per sampled coordinate; the last column is the time value
    # (see the xy / t split below). The ground truth is not needed here.
    model_input, _ = next(iter(videoloader))
    model_input = model_input[0].cuda()
    num_coords = len(model_input)

    data_len = len(os.listdir(frames_path))
    if data_len == 0:
        raise ValueError(f"No frames found in {frames_path!r}")

    # The edited canonical is identical for every step: convert it once to a
    # (1, 3, H_c, W_c) float tensor, the (N, C, H, W) layout grid_sample expects.
    canonical = torch.from_numpy(np.array(edit_canonical.convert('RGB'))).float().cuda()
    canonical = canonical.unsqueeze(0).permute(0, 3, 1, 2)

    chunks = []
    with torch.no_grad():
        batch_size = v.H * v.W
        for step in range(data_len):
            start = (step * batch_size) % num_coords
            end = min(start + batch_size, num_coords)
            xyt = model_input[start:end]
            xy, t = xyt[:, :-1], xyt[:, [-1]]

            # Deformed coordinates = homography warp + MLP offset.
            h_old = apply_homography(xy, g_old(t))
            xy_ = h_old + g(xyt)

            # Swap the two columns into grid_sample's (x, y) order and rescale.
            # NOTE(review): the 1.5 / 2.0 divisors presumably map the deformed
            # coordinates onto the canonical's [-1, 1] extent — confirm.
            grid = xy_.clone()
            grid[..., 1] = xy_[..., 0] / 1.5
            grid[..., 0] = xy_[..., 1] / 2.0

            sampled = torch.nn.functional.grid_sample(
                canonical,
                grid.unsqueeze(1).unsqueeze(0),
                mode='bilinear',
                padding_mode='border')
            # (1, 3, N, 1) -> (N, 3): one RGB value per coordinate in the chunk.
            chunks.append(sampled.squeeze().permute(1, 0))

    # Concatenate once instead of re-copying a growing tensor every step.
    output = torch.cat(chunks)

    # NOTE(review): assumes the coordinate grid is ordered so the flattened
    # output reshapes to (512, 512, frames, 3), i.e. 512x512 frames; this
    # breaks if TestVideoFitting yields another resolution. TODO confirm.
    frames = output.reshape(512, 512, data_len, 3).permute(2, 0, 1, 3).cpu().numpy()
    # Canonical pixels are 0-255 floats; clip before truncating to uint8.
    images = [Image.fromarray(np.uint8(np.clip(frame, 0, 255))) for frame in frames]

    edit_video_path = 'NaRCan_fps_10.mp4'
    images_to_video(images, edit_video_path)

    return edit_video_path


def _load_controlnet_pipeline(controlnet_id):
    """Load Stable Diffusion v1.5 with the given ControlNet in fp16 on CUDA."""
    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
    )
    pipe.to("cuda")
    return pipe


def _lineart_control_image(image_path, size_=768):
    """Return a lineart map of the canonical image at the image's own size.

    The detector runs on a ``size_ x size_`` copy; its output is resized back
    to the canonical image's original size with bilinear resampling.
    """
    processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
    processor_partial = partial(processor, coarse=False)
    # Close the file handle once the pixels have been read by resize().
    with Image.open(image_path) as canonical_image:
        ori_size = canonical_image.size
        resized = canonical_image.resize((size_, size_))
    image = processor_partial(resized, detect_resolution=size_, image_resolution=size_)
    return image.resize(ori_size, resample=Image.BILINEAR)


def _canny_control_image(image_path):
    """Return a 3-channel Canny edge map (thresholds 100 / 200) of the image."""
    canonical_image = cv2.imread(image_path)
    if canonical_image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read canonical image {image_path!r}")
    canonical_image = cv2.cvtColor(canonical_image, cv2.COLOR_BGR2RGB)
    edges = cv2.Canny(cv2.cvtColor(canonical_image, cv2.COLOR_RGB2GRAY), 100, 200)
    edges = edges[:, :, None]
    return Image.fromarray(np.concatenate([edges, edges, edges], axis=2))


def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt, control_type="Lineart"):
    """Edit an example video's canonical image with ControlNet and re-render it.

    Args:
        input_video: Path of the selected video; only the bundled examples
            listed in ``video_to_image`` are supported.
        prompt: Text prompt for the edit.
        num_steps: Number of diffusion inference steps.
        guidance_scale: Classifier-free guidance scale.
        seed: RNG seed; ``-1`` means no fixed seed.
        n_prompt: Negative prompt.
        control_type: ``"Lineart"`` for lineart conditioning; any other value
            uses Canny edges.

    Returns:
        Path of the re-rendered video, or ``None`` when no video is selected
        or it is not a bundled example.
    """
    if not input_video:
        return None
    video_name = os.path.basename(input_video)
    if video_name not in video_to_image:
        return None
    image_path, pth_path, frames_path = video_to_image[video_name]

    # NOTE(review): the pipeline is reloaded on every call; caching it would
    # speed up edits at the cost of keeping it resident in GPU memory.
    if control_type == "Lineart":
        pipe = _load_controlnet_pipeline("lllyasviel/control_v11p_sd15_lineart")
        image = _lineart_control_image(image_path)
    else:
        pipe = _load_controlnet_pipeline("lllyasviel/control_v11p_sd15_canny")
        image = _canny_control_image(image_path)

    generator = torch.manual_seed(seed) if seed != -1 else None
    output_images = pipe(
        prompt=prompt,
        image=image,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        negative_prompt=n_prompt,
        generator=generator
    ).images

    # Propagate the first edited canonical image to all frames.
    return NaRCan_make_video(output_images[0], pth_path, frames_path)


########
# demo #
########


# HTML header rendered at the top of the demo page. font-weight must lie in
# 1-1000 to be valid CSS; the former 1400 was silently ignored by browsers.
intro = """
<div style="text-align:center">
<h1 style="font-weight: 900; text-align: center; margin-bottom: 7px;">
   NaRCan - <small>Natural Refined Canonical Image</small>
</h1>
<span>[<a target="_blank" href="https://koi953215.github.io/NaRCan_page/">Project page</a>], [<a target="_blank" href="https://huggingface.co/papers/2406.06523">Paper</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Each edit takes ~10 sec </div>
</div>
"""



with gr.Blocks(css="style.css") as demo:
    # Components are laid out in creation order, so the statement order below
    # defines the page layout.
    gr.HTML(intro)

    with gr.Row():
        input_video = gr.Video(label="Input Video", interactive=False, elem_id="input_video", value='examples/bear.mp4')
        output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
        # NOTE(review): `.style()` is the Gradio 3.x sizing API; Gradio 4
        # removed it in favour of height/width constructor arguments.
        input_video.style(height=365, width=365)
        output_video.style(height=365, width=365)

    with gr.Row():
        prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1,
            value="bear, Van Gogh Style"
        )

    with gr.Row():
        run_button = gr.Button("Edit your video!", visible=True)

    with gr.Accordion('Advanced options', open=False):
        control_type = gr.Dropdown(
            ["Canny", "Lineart"],
            label="Control Type",
            info="Canny or Lineart",
            value="Lineart"
        )
        num_steps = gr.Slider(label='Steps',
                              minimum=1,
                              maximum=100,
                              value=20,
                              step=1)
        guidance_scale = gr.Slider(label='Guidance Scale',
                                   minimum=0.1,
                                   maximum=30.0,
                                   value=9.0,
                                   step=0.1)
        # -1 tells edit_with_pnp not to fix the seed.
        seed = gr.Slider(label='Seed',
                         minimum=-1,
                         maximum=2147483647,
                         step=1,
                         randomize=True)
        n_prompt = gr.Textbox(
            label='Negative Prompt',
            value=""
        )

    # Fill the prompt box with the example's default prompt when the video changes.
    input_video.change(
        fn=update_prompt,
        inputs=[input_video],
        outputs=[prompt],
        queue=False)

    run_button.click(fn=edit_with_pnp,
                     inputs=[input_video,
                             prompt,
                             num_steps,
                             guidance_scale,
                             seed,
                             n_prompt,
                             control_type,
                             ],
                     outputs=[output_video]
                     )

    gr.Examples(
        examples=get_example(),
        label='Examples',
        inputs=[input_video],
        outputs=[output_video],
        examples_per_page=8
    )

# Route events through Gradio's request queue.
demo.queue()

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    # share=True additionally requests a public share link.
    demo.launch(share=True)