import gradio as gr
import json
import numpy as np
import datasets
import cv2
import matplotlib.pyplot as plt

# Validation splits of the two sample scenes, loaded once at import time.
# generate_annotations and update_slider use sample_dataset1 when the dropdown
# reads 'Western Scene' and sample_dataset2 for any other selection.
sample_dataset1, sample_dataset2 = [
    datasets.load_dataset(repo_id, split="validation")
    for repo_id in ("asgaardlab/SampleDataset", "asgaardlab/SampleDataset2")
]




def overlay_with_transparency(background, overlay, alpha_mask):
    """
    Add a weighted copy of ``overlay`` on top of ``background``.

    Computes ``background * 1 + overlay * alpha_mask + 0`` with cv2.addWeighted.
    ``background`` keeps its full intensity and ``overlay`` is added to it, so this
    is an additive blend rather than a convex alpha blend.

    Args:
    - background: The image on which the overlay will be added.
    - overlay: The image to add; cv2.addWeighted requires it to have the same
      size and channel count as ``background``.
    - alpha_mask: Scalar weight applied to ``overlay`` (cv2's ``beta``). Despite
      the name, it is a single number, not a per-pixel mask.

    Returns:
    - A new image; cv2.addWeighted saturates each result to the array's dtype range.
    """
    background_weight = 1
    brightness_offset = 0
    return cv2.addWeighted(background, background_weight, overlay, alpha_mask, brightness_offset)

def generate_overlay_image(buggy_image, objects, segmentation_image_rgb, font_scale=0.5, font_color=(0, 255, 255)):
    """
    Tint every annotated object on a copy of ``buggy_image`` and label it.

    For each object, the pixels of ``segmentation_image_rgb`` whose first three
    channels equal the object's colour (with its trailing component dropped) form
    the object's mask. That mask, painted in the object's colour, is added onto
    the image at weight 0.3, and the object's ``labelName`` is drawn at the mask's
    centroid. Objects whose colour does not occur in the segmentation image are
    skipped.

    Args:
    - buggy_image: The image to be overlaid. It is copied, not modified, and is
      assumed to share the segmentation image's height and width and to have as
      many channels as the colours.
    - objects: Parsed objects JSON; each entry needs "color" and "labelName".
    - segmentation_image_rgb: The segmentation image as an array with at least
      three channels.
    - font_scale: Scale factor for the font size.
    - font_color: Colour for the label text, in the image's own channel order.
      cv2.putText writes the values as given, so for an RGB array (e.g. one
      converted from a PIL image) this is RGB, not BGR.

    Returns:
    - The overlaid image.
    """
    overlaid_img = buggy_image.copy()

    for obj in objects:
        # The stored colour has one extra trailing component (presumably alpha)
        # that is not part of the comparison against the segmentation pixels.
        color = tuple(obj["color"])[:-1]
        mask = np.all(segmentation_image_rgb[:, :, :3] == np.array(color), axis=-1)

        mask_coords = np.argwhere(mask)
        if mask_coords.size == 0:
            # The object does not appear in the segmentation image, so there is
            # nothing to tint. np.mean of an empty array would also return NaN
            # (with a RuntimeWarning), and NaN cast to int is not a usable text
            # position.
            continue

        # Paint the object's pixels in its colour and blend them in at weight 0.3.
        colored_mask = np.zeros_like(overlaid_img)
        colored_mask[mask] = color
        overlaid_img = overlay_with_transparency(overlaid_img, colored_mask, 0.3)

        # Label the object at the centroid of its mask; argwhere yields (row, col).
        y_center, x_center = np.mean(mask_coords, axis=0).astype(int)
        cv2.putText(overlaid_img, obj["labelName"], (int(x_center), int(y_center)),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_color, 1, cv2.LINE_AA)

    return overlaid_img



def generate_annotations(selected_dataset, image_index):
    """
    Gather everything the UI displays for one sample of the selected dataset.

    Args:
    - selected_dataset: Dropdown value. 'Western Scene' selects sample_dataset1;
      any other value selects sample_dataset2.
    - image_index: Row index into the selected dataset (converted with int()).

    Returns:
    - A tuple in the order of the Visualize button's outputs:
      1. the buggy image;
      2. the correct image;
      3. (correct image, [(mask, label), ...]) for gr.AnnotatedImage;
      4. the buggy image with the correct-scene segmentation overlaid;
      5. the parsed objects JSON;
      6. the number of annotated objects;
      7. the victim name;
      8. the bug tag.
    """
    bugs_ds = sample_dataset1 if selected_dataset == 'Western Scene' else sample_dataset2

    # Fetch the row once: every ``bugs_ds[i]`` access decodes all of the row's
    # columns, including each of its images.
    sample = bugs_ds[int(image_index)]

    objects = json.loads(sample["Objects JSON (Correct)"])
    segmentation_image_rgb = np.array(sample["Segmentation Image (Correct)"])

    annotations = []
    for obj in objects:
        # Drop the colour's trailing component (presumably alpha) before matching
        # it against the first three channels of the segmentation image.
        color = tuple(obj["color"])[:-1]
        mask = np.all(segmentation_image_rgb[:, :, :3] == np.array(color), axis=-1).astype(np.float32)
        annotations.append((mask, obj["labelName"]))

    bug_image = sample["Buggy Image"]
    correct_image = sample["Correct Image"]

    overlaid_image = generate_overlay_image(np.array(bug_image), objects, segmentation_image_rgb)

    return (
        bug_image,
        correct_image,
        (correct_image, annotations),
        overlaid_image,
        objects,
        # Number of annotated objects in the correct-scene JSON.
        len(objects),
        sample["Victim Name"],
        sample["Tag"],
    )

def update_slider(selected_dataset):
    """
    Resize the image-index slider to fit the newly selected dataset.

    Args:
    - selected_dataset: Dropdown value. 'Western Scene' maps to sample_dataset1;
      anything else maps to sample_dataset2.

    Returns:
    - A gr.update that sets the slider range to 0..len(dataset) - 1 in steps of 1.
    """
    if selected_dataset == 'Western Scene':
        chosen = sample_dataset1
    else:
        chosen = sample_dataset2
    last_index = len(chosen) - 1
    return gr.update(minimum=0, maximum=last_index, step=1)

# Set up the Gradio interface using the Blocks API.
with gr.Blocks() as demo:
    gr.Markdown(
        "Enter the image index and click **Visualize** to view the segmentation annotations."
    )

    with gr.Row():
        # Start on 'Western Scene' so the slider's initial range below matches the
        # dataset generate_annotations reads. Without a default, the dropdown's
        # None value would silently select the other dataset, and the change
        # event that resizes the slider would not fire until the user picks one.
        selected_dataset = gr.Dropdown(
            ['Western Scene', 'Viking Village'], value='Western Scene', label="Dataset"
        )
        input_slider = gr.Slider(
            minimum=0, maximum=len(sample_dataset1) - 1, step=1, value=0, label="Image Index"
        )
        btn = gr.Button("Visualize")
    with gr.Row():
        bug_image = gr.Image(label="Buggy Image")
        correct_image = gr.Image(label="Correct Image")
    with gr.Row():
        seg_img = gr.AnnotatedImage(label="Correct Image Annotations")
        overlaid_img = gr.Image(label="Buggy Image with Segmentation Overlay")
    with gr.Row():
        object_count = gr.Number(label="Object Count")
        victim_name = gr.Textbox(label="Victim Name")
        bug_type = gr.Textbox(label="Bug Type")

    with gr.Row():
        json_data = gr.JSON(label="Objects JSON (Correct)")

    # The output order must match the tuple returned by generate_annotations.
    btn.click(
        fn=generate_annotations,
        inputs=[selected_dataset, input_slider],
        outputs=[bug_image, correct_image, seg_img, overlaid_img, json_data, object_count, victim_name, bug_type],
    )

    selected_dataset.change(
        fn=update_slider,
        inputs=[selected_dataset],
        outputs=[input_slider]
    )

# Launch only when run as a script, so importing this module (e.g. by Gradio's
# reload mode or by tests) does not start a server.
if __name__ == "__main__":
    demo.launch()