#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Part of the code in this file is adapted from
# https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/eval.py and
# https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/train.py

# MIT License

# Copyright (c) 2022 Lorenzo Breschi

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import gradio as gr
import numpy as np
import time
from data import write_image_tensor, PatchDataModule, prepare_data, image2tensor, tensor2image
import torch
from tqdm import tqdm
from bigdl.nano.pytorch import InferenceOptimizer
from torch.utils.data import DataLoader
from pathlib import Path
from torch.utils.data import Dataset
import datetime
import huggingface_hub


# All inference in this demo runs on CPU in float32.
device = 'cpu'
dtype = torch.float32
MODEL_REPO = 'CVPR/FSPBT'
# Fetch the pretrained generator checkpoint file from the Hugging Face Hub.
ckpt_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, 'generator.pt')
# NOTE(review): torch.load unpickles a whole nn.Module (arbitrary code execution
# on untrusted files); acceptable only because MODEL_REPO is a fixed repo.
# On newer torch releases that default to weights_only=True this call may need
# weights_only=False -- confirm against the installed torch version.
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only host.
generator = torch.load(ckpt_path, map_location=device)
generator.eval()
generator.to(device, dtype)
# DataLoader settings for the single-image inference loaders below.
params = {'batch_size': 1,
          'num_workers': 0}


class ImageDataset(Dataset):
    """Map-style dataset holding exactly one image converted by ``image2tensor``.

    It exists so a single Gradio input image can be fed through a
    ``DataLoader`` like a regular batch.
    """

    def __init__(self, img):
        # img is presumably the numpy array Gradio passes in -- TODO confirm.
        self.imgs = [image2tensor(img)]

    def __getitem__(self, idx: int) -> "torch.Tensor":
        """Return the stored converted image at ``idx`` (only index 0 exists).

        The return annotation assumes ``image2tensor`` yields a tensor, as its
        name and the later ``inputs.to(device, dtype)`` suggest -- confirm.
        """
        return self.imgs[idx]

    def __len__(self) -> int:
        """Return the number of stored images (always 1)."""
        return len(self.imgs)


# Build the patch data module from the images under ./data; its train loader is
# used below only as calibration data for quantization.
data_path = Path('data')
train_image_dd = prepare_data(data_path)
dm = PatchDataModule(train_image_dd, patch_size=2**6,
                     batch_size=2**3, patch_num=2**6)

# Post-training quantization of the generator with BigDL-Nano, calibrated on
# the patch loader. accelerator=None presumably selects no external runtime
# (ONNX Runtime / OpenVINO) -- confirm against the BigDL-Nano quantize API.
train_loader = dm.train_dataloader()
quantized_model = InferenceOptimizer.quantize(generator,
                                              accelerator=None,
                                              calib_dataloader=train_loader)


def _fit_input_size(img, max_side=3000, multiple=4):
    """Shrink and crop an image so the generator can process it.

    Images whose height or width exceeds ``max_side`` are downscaled (keeping
    the aspect ratio) until both sides fit. The result is then cropped at the
    bottom/right so both sides are multiples of ``multiple``.

    Args:
        img: image array with spatial dims first, ``(H, W)`` or ``(H, W, C)``;
            presumably the uint8 RGB array Gradio passes in.
        max_side: maximum allowed height/width in pixels.
        multiple: both output sides are cropped to a multiple of this value.

    Returns:
        A C-contiguous array with the same dtype and channel count as ``img``.

    Raises:
        ValueError: if a side would be cropped down to zero pixels.
    """
    height, width = img.shape[:2]
    scale = min(1.0, max_side / height, max_side / width)
    if scale < 1.0:
        new_size = (max(int(height * scale), 1), max(int(width * scale), 1))
        # np.resize only truncates/repeats the flat buffer (scrambling pixels),
        # so resample properly with torch, which this app already depends on.
        pixels = np.ascontiguousarray(img)
        if pixels.ndim == 2:
            pixels = pixels[..., None]
        batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).float()
        # "area" averages source pixels, which avoids aliasing when shrinking.
        batch = torch.nn.functional.interpolate(batch, size=new_size, mode="area")
        resized = batch.squeeze(0).permute(1, 2, 0).numpy()
        if img.ndim == 2:
            resized = resized[..., 0]
        if np.issubdtype(img.dtype, np.integer):
            info = np.iinfo(img.dtype)
            resized = np.clip(np.rint(resized), info.min, info.max)
        img = resized.astype(img.dtype)
        height, width = img.shape[:2]
    # Crop (not np.resize) so the kept pixels stay in place.
    crop_h = height - height % multiple
    crop_w = width - width % multiple
    if crop_h == 0 or crop_w == 0:
        raise ValueError(
            "input image must be at least {0}x{0} pixels".format(multiple))
    return np.ascontiguousarray(img[:crop_h, :crop_w])


def _run_model(model, input_img):
    """Run ``model`` on one Gradio image and time the forward pass.

    Returns:
        ``(output_image, latency)``: the output as a numpy array and the
        forward-pass wall time formatted like ``"0.123s"``. Returns
        ``(None, None)`` when no image was provided (e.g. the input was cleared).
    """
    if input_img is None:
        return None, None
    print(datetime.datetime.now())
    print("input size: ", *input_img.shape[:2])
    input_img = _fit_input_size(input_img)
    loader = DataLoader(ImageDataset(input_img), **params)
    out_image, latency = None, None
    with torch.no_grad():
        # The dataset holds a single image, so this loop runs exactly once.
        for inputs in tqdm(loader):
            inputs = inputs.to(device, dtype)
            start = time.perf_counter()
            outputs = model(inputs)
            latency = "{:.3f}s".format(time.perf_counter() - start)
            out_image = np.array(tensor2image(outputs[0]))
    return out_image, latency


def original_transfer(input_img):
    """Stylize ``input_img`` with the unquantized PyTorch ``generator``.

    Returns ``(output_image, latency_string)`` for the Gradio outputs.
    """
    return _run_model(generator, input_img)


def nano_transfer(input_img):
    """Stylize ``input_img`` with the BigDL-Nano ``quantized_model``.

    Returns ``(output_image, latency_string)`` for the Gradio outputs.
    """
    return _run_model(quantized_model, input_img)


def clear():
    """Blank out both output images and both latency boxes in the UI."""
    return (None,) * 4
    

demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>BigDL-Nano inference demo</center></h1>")
    # Top row: overview text on the left, API walkthrough video on the right.
    with gr.Row().style(equal_height=False):
        with gr.Column():
            gr.Markdown('''
                <h2>Overview</h2>
                
                BigDL-Nano is a library in [BigDL 2.0](https://github.com/intel-analytics/BigDL) that allows the users to transparently accelerate their deep learning pipelines (including data processing, training and inference) by automatically integrating optimized libraries, best-known configurations, and software optimizations.
                
                The video on the right shows how the user can easily enable quantization using BigDL-Nano (with just a couple of lines of code); you may refer to our [CVPR 2022 demo paper](https://arxiv.org/abs/2204.01715) for more details.
                ''')
        with gr.Column():
            gr.Video(value="data/nano_quantize_api.mp4")
    gr.Markdown('''
            <h2>Demo</h2>
            
            This section uses an image stylization example to demonstrate the speedup of the above code when using quantization in BigDL-Nano (about 2~3x inference time speedup). 
            The demo is adapted from the original [FSPBT-Image-Translation code](https://github.com/rnwzd/FSPBT-Image-Translation),
            and the default image is from [the COCO dataset](https://cocodataset.org/#home).
            ''')
    # Input image plus one button per backend and a reset button.
    with gr.Row().style(equal_height=False):
        input_img = gr.Image(label="input image", value="data/COCO_image.jpg", source="upload")
        with gr.Column():
            ori_but = gr.Button("Standard PyTorch")
            nano_but = gr.Button("BigDL-Nano")
            clear_but = gr.Button("Clear Output")
    # Side-by-side latency and output image for each backend.
    with gr.Row().style(equal_height=False):
        with gr.Column():
            ori_time = gr.Text(label="Standard PyTorch latency")
            ori_image = gr.Image(label="Standard PyTorch output image")
        with gr.Column():
            nano_time = gr.Text(label="BigDL-Nano latency")
            nano_image = gr.Image(label="BigDL-Nano output image")

    # Each transfer function returns (image, latency) matching its outputs list.
    ori_but.click(original_transfer, inputs=input_img, outputs=[ori_image, ori_time])
    nano_but.click(nano_transfer, inputs=input_img, outputs=[nano_image, nano_time])
    clear_but.click(clear, inputs=None, outputs=[ori_image, ori_time, nano_image, nano_time])


demo.launch(share=True, enable_queue=True)