import torch
import os
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from functools import partial

import Utils.Pneumonia_Utils as PU
import Utils.CT_Scan_Utils as CSU
import Utils.Covid19_Utils as C19U
import Utils.DR_Utils as DRU

# Checkpoint files for each diagnostic model, relative to the working directory.
CANCER_MODEL_PATH = 'cs_models/EfficientNet_CT_Scans.pth.tar'
DIABETIC_RETINOPATHY_MODEL_PATH = 'cs_models/model_DR_9.pth.tar'
PNEUMONIA_MODEL_PATH = 'cs_models/DenseNet_Pneumonia.pth.tar'
COVID_MODEL_PATH = 'cs_models/DenseNet_Covid.pth.tar'

# Human-readable class names passed to the prediction helpers, in the order
# the helpers index them.
CANCER_CLASS_LABELS = [
    'adenocarcinoma',
    'large.cell.carcinoma',
    'normal',
    'squamous.cell.carcinoma',
]
DIABETIC_RETINOPATHY_CLASS_LABELS = [
    'No DR',
    'Mild',
    'Moderate',
    'Severe',
    'Proliferative DR',
]
PNEUMONIA_CLASS_LABELS = ['Normal', 'Pneumonia']
COVID_CLASS_LABELS = ['Normal', 'Covid19']

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def cancer_page(image, test_model):
    """Classify a chest CT scan and return a Grad-CAM heatmap.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.
        test_model: Grad-CAM wrapper model (built with ``CSU.ModelGradCam``
            in ``__main__``) bound through ``functools.partial``.

    Returns:
        Tuple ``(heatmap, pred_label, pred_conf)`` from ``CSU.plot_grad_cam``.
        The heatmap is clipped to [0, 1] so ``gr.Image`` renders it. The
        confidence value feeds the ``gr.Label`` output component.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio calls the handler with None for an empty Image input. Fail
        # with a readable message instead of an opaque transform error.
        raise gr.Error("Please upload a CT scan image.")

    ct_image = CSU.transform_image(image, CSU.val_transform).to(device)
    heatmap, pred_label, pred_conf = CSU.plot_grad_cam(
        test_model,
        ct_image,
        CANCER_CLASS_LABELS,
        normalized=True,
    )
    # Overlay arithmetic can overshoot the valid float range for images.
    heatmap = np.clip(heatmap, 0, 1)
    return heatmap, pred_label, pred_conf


def covid_page(image, test_model):
    """Classify a chest X-ray as Normal/Covid19 and return a Grad-CAM heatmap.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.
        test_model: Grad-CAM wrapper model (built with ``C19U.ModelGradCam``
            in ``__main__``) bound through ``functools.partial``.

    Returns:
        Tuple ``(heatmap, pred_label, pred_conf)`` from ``C19U.plot_grad_cam``.
        The heatmap is clipped to [0, 1] so ``gr.Image`` renders it. The
        confidence value feeds the ``gr.Label`` output component.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio calls the handler with None for an empty Image input.
        raise gr.Error("Please upload a chest X-ray image.")

    x_ray_image = C19U.transform_image(image, C19U.val_transform).to(device)
    heatmap, pred_label, pred_conf = C19U.plot_grad_cam(
        test_model,
        x_ray_image,
        COVID_CLASS_LABELS,
        normalized=True,
    )
    # Keep pixel values inside the range gr.Image expects for float arrays.
    heatmap = np.clip(heatmap, 0, 1)
    return heatmap, pred_label, pred_conf


def pneumonia_page(image, test_model):
    """Classify a chest X-ray as Normal/Pneumonia and return a Grad-CAM heatmap.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.
        test_model: Grad-CAM wrapper model (built with ``PU.ModelGradCam``
            in ``__main__``) bound through ``functools.partial``.

    Returns:
        Tuple ``(heatmap, pred_label, pred_conf)`` from ``PU.plot_grad_cam``.
        The heatmap is clipped to [0, 1] so ``gr.Image`` renders it. The
        confidence value feeds the ``gr.Label`` output component.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio calls the handler with None for an empty Image input.
        raise gr.Error("Please upload a chest X-ray image.")

    x_ray_image = PU.transform_image(image, PU.val_transform).to(device)
    heatmap, pred_label, pred_conf = PU.plot_grad_cam(
        test_model,
        x_ray_image,
        PNEUMONIA_CLASS_LABELS,
        normalized=True,
    )
    # Keep pixel values inside the range gr.Image expects for float arrays.
    heatmap = np.clip(heatmap, 0, 1)
    return heatmap, pred_label, pred_conf

def diabetic_retinopathy_page(image_1, image_2, test_model):
    """Grade diabetic retinopathy for a pair of images.

    Args:
        image_1: First PIL image from the Gradio input, or ``None`` if empty.
        image_2: Second PIL image from the Gradio input, or ``None`` if empty.
        test_model: Ensemble model (``DRU.EnsembleModel`` in ``__main__``)
            bound through ``functools.partial``.

    Returns:
        Tuple ``(pred_label_1, pred_label_2)`` from ``DRU.Inf_predict_image``,
        one prediction per input image.

    Raises:
        gr.Error: if either image is missing.
    """
    if image_1 is None or image_2 is None:
        # The DRU transform consumes both images together, so both are required.
        raise gr.Error("Please upload both images.")

    # NOTE(review): unlike the other pages, the transformed images are not
    # moved to `device` here; presumably DRU.Inf_predict_image handles device
    # placement -- confirm, otherwise this fails when the model is on CUDA.
    images = DRU.transform_image(image_1, image_2, DRU.val_transform)
    pred_label_1, pred_label_2 = DRU.Inf_predict_image(
        test_model,
        images,
        DIABETIC_RETINOPATHY_CLASS_LABELS,
    )
    return pred_label_1, pred_label_2

if __name__ == "__main__":
    import warnings

    def sample_examples(directory, count):
        """Return up to ``count`` randomly chosen files from ``directory``.

        Each path is wrapped in a one-element list: Gradio expects one row per
        example with one entry per input component.
        """
        files = os.listdir(directory)
        # Cap the sample size: np.random.choice(replace=False) raises
        # ValueError when asked for more items than the folder contains.
        chosen = np.random.choice(files, size=min(count, len(files)), replace=False)
        return [[os.path.join(directory, name)] for name in chosen]

    def load_checkpoint(model, path, key=None):
        """Load weights from ``path`` into ``model`` and switch it to eval mode.

        The checkpoint is mapped to CPU; ``load_state_dict`` copies tensors onto
        the device ``model`` already lives on. If ``key`` is given, the state
        dict is read from ``checkpoint[key]``. Loading stays non-strict, but
        mismatched keys are reported instead of being silently ignored.
        """
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        state_dict = checkpoint if key is None else checkpoint[key]
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"{path}: {len(result.missing_keys)} missing and "
                f"{len(result.unexpected_keys)} unexpected keys while loading weights"
            )
        # Modules start in training mode; inference needs BatchNorm/Dropout in
        # eval mode. Harmless if the prediction helpers also call eval().
        model.eval()
        return model

    def heatmap_interface(page_fn, test_model, examples, title):
        """Build a single-image tab that shows a heatmap, label text and probabilities."""
        return gr.Interface(
            fn=partial(page_fn, test_model=test_model),
            inputs=gr.Image(type="pil", label="Image"),
            outputs=[
                gr.Image(type="numpy", label="Heatmap Image"),
                gr.Textbox(label="Labels Present"),
                gr.Label(label="Probabilities", show_label=False),
            ],
            examples=examples,
            cache_examples=False,
            allow_flagging="never",
            title=title,
        )

    CSU_model = CSU.Efficient().to(device)
    load_checkpoint(CSU_model, CANCER_MODEL_PATH)
    CSU_test_model = CSU.ModelGradCam(CSU_model).to(device)
    CSU_examples = sample_examples("TESTS/CHEST_CT_SCANS", 4)

    C19U_model = C19U.DenseNet().to(device)
    load_checkpoint(C19U_model, COVID_MODEL_PATH)
    C19U_test_model = C19U.ModelGradCam(C19U_model).to(device)
    C19U_examples = sample_examples("TESTS/COVID19", 2) + sample_examples("TESTS/NORMAL", 2)

    # The class must be instantiated: calling `.to` on the class itself fails.
    PU_model = PU.DenseNet().to(device)
    load_checkpoint(PU_model, PNEUMONIA_MODEL_PATH)
    PU_test_model = PU.ModelGradCam(PU_model).to(device)
    PU_examples = sample_examples("TESTS/PNEUMONIA", 2) + sample_examples("TESTS/NORMAL", 2)

    DRU_cnn_model = DRU.ConvolutionNeuralNetwork().to(device)
    DRU_eff_b3 = DRU.Efficient().to(device)
    DRU_ensemble = DRU.EnsembleModel(DRU_cnn_model, DRU_eff_b3).to(device)
    # This checkpoint nests its weights under "state_dict".
    load_checkpoint(DRU_ensemble, DIABETIC_RETINOPATHY_MODEL_PATH, key="state_dict")
    DRU_test_model = DRU_ensemble
    DRU_examples = [['TESTS/DR_1/10030_left._aug_0._aug_6.jpeg', 'TESTS/DR_0/10031_right._aug_17.jpeg']]

    cancer_interface = heatmap_interface(
        cancer_page, CSU_test_model, CSU_examples, "Chest Cancer Detection System"
    )
    covid_interface = heatmap_interface(
        covid_page, C19U_test_model, C19U_examples, "Covid Detection System"
    )
    pneumonia_interface = heatmap_interface(
        pneumonia_page, PU_test_model, PU_examples, "Pneumonia Detection System"
    )

    diabetic_retinopathy_interface = gr.Interface(
        fn=partial(diabetic_retinopathy_page, test_model=DRU_test_model),
        # Distinct labels so users can tell which output belongs to which input.
        inputs=[gr.Image(type="pil", label="Image 1"), gr.Image(type="pil", label="Image 2")],
        outputs=[
            gr.Textbox(label="Labels Present (Image 1)"),
            gr.Textbox(label="Labels Present (Image 2)"),
        ],
        examples=DRU_examples,
        cache_examples=False,
        allow_flagging="never",
        title="Diabetic Retinopathy System",
    )

    demo = gr.TabbedInterface(
        [cancer_interface, covid_interface, pneumonia_interface, diabetic_retinopathy_interface],
        ["Chest Cancer", "Covid19", "Pneumonia", "Diabetic Retinopathy"],
    )

    # share=True also publishes the app through a public Gradio share link.
    demo.launch(share=True)