import gradio as gr

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation

# Hugging Face Hub id shared by the preprocessor and the model. Keeping it in
# one place guarantees both are always loaded from the same checkpoint.
MODEL_CHECKPOINT = "nvidia/segformer-b5-finetuned-ade-640-640"

# Both objects are created once at import time and reused by every request.
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = TFSegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)

def ade_palette():
    """Return the ADE20K palette as a list of ``[R, G, B]`` lists.

    The entry at position ``i`` is the color used for class id ``i``. A new
    outer list with new inner lists is built on every call, so callers may
    mutate the result freely.
    """
    colors = (
        (120, 120, 120),
        (180, 120, 120),
        (6, 230, 230),
        (80, 50, 50),
        (4, 200, 3),
        (120, 120, 80),
        (140, 140, 140),
        (204, 5, 255),
        (230, 230, 230),
        (4, 250, 7),
        (224, 5, 255),
        (235, 255, 7),
        (150, 5, 61),
        (120, 120, 70),
        (8, 255, 51),
        (255, 6, 82),
        (143, 255, 140),
        (204, 255, 4),
        (255, 51, 7),
        (204, 70, 3),
        (0, 102, 200),
        (61, 230, 250),
        (255, 6, 51),
        (11, 102, 255),
        (255, 7, 71),
        (255, 9, 224),
        (9, 7, 230),
        (220, 220, 220),
        (255, 9, 92),
        (112, 9, 255),
        (8, 255, 214),
        (7, 255, 224),
        (255, 184, 6),
        (10, 255, 71),
        (255, 41, 10),
        (7, 255, 255),
        (224, 255, 8),
        (102, 8, 255),
        (255, 61, 6),
        (255, 194, 7),
        (255, 122, 8),
        (0, 255, 20),
        (255, 8, 41),
        (255, 5, 153),
        (6, 51, 255),
        (235, 12, 255),
        (160, 150, 20),
        (0, 163, 255),
        (140, 140, 140),
        (250, 10, 15),
        (20, 255, 0),
        (31, 255, 0),
        (255, 31, 0),
        (255, 224, 0),
        (153, 255, 0),
        (0, 0, 255),
        (255, 71, 0),
        (0, 235, 255),
        (0, 173, 255),
        (31, 0, 255),
        (11, 200, 200),
        (255, 82, 0),
        (0, 255, 245),
        (0, 61, 255),
        (0, 255, 112),
        (0, 255, 133),
        (255, 0, 0),
        (255, 163, 0),
        (255, 102, 0),
        (194, 255, 0),
        (0, 143, 255),
        (51, 255, 0),
        (0, 82, 255),
        (0, 255, 41),
        (0, 255, 173),
        (10, 0, 255),
        (173, 255, 0),
        (0, 255, 153),
        (255, 92, 0),
        (255, 0, 255),
        (255, 0, 245),
        (255, 0, 102),
        (255, 173, 0),
        (255, 0, 20),
        (255, 184, 184),
        (0, 31, 255),
        (0, 255, 61),
        (0, 71, 255),
        (255, 0, 204),
        (0, 255, 194),
        (0, 255, 82),
        (0, 10, 255),
        (0, 112, 255),
        (51, 0, 255),
        (0, 194, 255),
        (0, 122, 255),
        (0, 255, 163),
        (255, 153, 0),
        (0, 255, 10),
        (255, 112, 0),
        (143, 255, 0),
        (82, 0, 255),
        (163, 255, 0),
        (255, 235, 0),
        (8, 184, 170),
        (133, 0, 255),
        (0, 255, 92),
        (184, 0, 255),
        (255, 0, 31),
        (0, 184, 255),
        (0, 214, 255),
        (255, 0, 112),
        (92, 255, 0),
        (0, 224, 255),
        (112, 224, 255),
        (70, 184, 160),
        (163, 0, 255),
        (153, 0, 255),
        (71, 255, 0),
        (255, 0, 163),
        (255, 204, 0),
        (255, 0, 143),
        (0, 255, 235),
        (133, 255, 0),
        (255, 0, 235),
        (245, 0, 255),
        (255, 0, 122),
        (255, 245, 0),
        (10, 190, 212),
        (214, 255, 0),
        (0, 204, 255),
        (20, 0, 255),
        (255, 255, 0),
        (0, 153, 255),
        (0, 41, 255),
        (0, 255, 204),
        (41, 0, 255),
        (41, 255, 0),
        (173, 0, 255),
        (0, 245, 255),
        (71, 0, 255),
        (122, 0, 255),
        (0, 255, 184),
        (0, 92, 255),
        (184, 255, 0),
        (0, 133, 255),
        (255, 214, 0),
        (25, 194, 194),
        (102, 255, 0),
        (92, 0, 255),
    )
    # Copy each row into a list so the return type matches a list-of-lists
    # literal and no state is shared between calls.
    return [list(rgb) for rgb in colors]

# Class names for the segmentation legend. The position in this list is the
# class id used to index it in draw_plot; alternative names for one class are
# joined with ';' inside a single string.
labels_list = [
    "wall",
    "building;edifice",
    "sky",
    "floor;flooring",
    "tree",
    "ceiling",
    "road;route",
    "bed",
    "windowpane;window",
    "grass",
    "cabinet",
    "sidewalk;pavement",
    "person;individual;someone;somebody;mortal;soul",
    "earth;ground",
    "door;double;door",
    "table",
    "mountain;mount",
    "plant;flora;plant;life",
    "curtain;drape;drapery;mantle;pall",
    "chair",
    "car;auto;automobile;machine;motorcar",
    "water",
    "painting;picture",
    "sofa;couch;lounge",
    "shelf",
    "house",
    "sea",
    "mirror",
    "rug;carpet;carpeting",
    "field",
    "armchair",
    "seat",
    "fence;fencing",
    "desk",
    "rock;stone",
    "wardrobe;closet;press",
    "lamp",
    "bathtub;bathing;tub;bath;tub",
    "railing;rail",
    "cushion",
    "base;pedestal;stand",
    "box",
    "column;pillar",
    "signboard;sign",
    "chest;of;drawers;chest;bureau;dresser",
    "counter",
    "sand",
    "sink",
    "skyscraper",
    "fireplace;hearth;open;fireplace",
    "refrigerator;icebox",
    "grandstand;covered;stand",
    "path",
    "stairs;steps",
    "runway",
    "case;display;case;showcase;vitrine",
    "pool;table;billiard;table;snooker;table",
    "pillow",
    "screen;door;screen",
    "stairway;staircase",
    "river",
    "bridge;span",
    "bookcase",
    "blind;screen",
    "coffee;table;cocktail;table",
    "toilet;can;commode;crapper;pot;potty;stool;throne",
    "flower",
    "book",
    "hill",
    "bench",
    "countertop",
    "stove;kitchen;stove;range;kitchen;range;cooking;stove",
    "palm;palm;tree",
    "kitchen;island",
    "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
    "swivel;chair",
    "boat",
    "bar",
    "arcade;machine",
    "hovel;hut;hutch;shack;shanty",
    "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
    "towel",
    "light;light;source",
    "truck;motortruck",
    "tower",
    "chandelier;pendant;pendent",
    "awning;sunshade;sunblind",
    "streetlight;street;lamp",
    "booth;cubicle;stall;kiosk",
    "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
    "airplane;aeroplane;plane",
    "dirt;track",
    "apparel;wearing;apparel;dress;clothes",
    "pole",
    "land;ground;soil",
    "bannister;banister;balustrade;balusters;handrail",
    "escalator;moving;staircase;moving;stairway",
    "ottoman;pouf;pouffe;puff;hassock",
    "bottle",
    "buffet;counter;sideboard",
    "poster;posting;placard;notice;bill;card",
    "stage",
    "van",
    "ship",
    "fountain",
    "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
    "canopy",
    "washer;automatic;washer;washing;machine",
    "plaything;toy",
    "swimming;pool;swimming;bath;natatorium",
    "stool",
    "barrel;cask",
    "basket;handbasket",
    "waterfall;falls",
    "tent;collapsible;shelter",
    "bag",
    "minibike;motorbike",
    "cradle",
    "oven",
    "ball",
    "food;solid;food",
    "step;stair",
    "tank;storage;tank",
    "trade;name;brand;name;brand;marque",
    "microwave;microwave;oven",
    "pot;flowerpot",
    "animal;animate;being;beast;brute;creature;fauna",
    "bicycle;bike;wheel;cycle",
    "lake",
    "dishwasher;dish;washer;dishwashing;machine",
    "screen;silver;screen;projection;screen",
    "blanket;cover",
    "sculpture",
    "hood;exhaust;hood",
    "sconce",
    "vase",
    "traffic;light;traffic;signal;stoplight",
    "tray",
    "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
    "fan",
    "pier;wharf;wharfage;dock",
    "crt;screen",
    "plate",
    "monitor;monitoring;device",
    "bulletin;board;notice;board",
    "shower",
    "radiator",
    "glass;drinking;glass",
    "clock",
    "flag",
]

def label_to_color_image(label):
    """Map a 2-D array of class ids to their ADE20K palette colors.

    Args:
        label: A 2-D NumPy array with integer dtype, storing the
            segmentation class id of every pixel.

    Returns:
        An integer array of shape ``label.shape + (3,)``. The last axis holds
        the ``ade_palette()`` entry of the class id found at the same
        position in ``label``.

    Raises:
        ValueError: If ``label`` is not of rank 2, or if any of its values
            lies outside ``[0, len(ade_palette()))``.
    """
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")

    colormap = np.asarray(ade_palette())

    # An empty map has no min/max to validate; the lookup below then simply
    # produces an empty result instead of failing inside the reduction.
    if label.size:
        # Negative ids would silently wrap around to the end of the palette.
        if np.min(label) < 0:
            raise ValueError("label value must be non-negative.")
        if np.max(label) >= len(colormap):
            raise ValueError("label value too large.")

    # Integer-array indexing: each id is replaced by its palette row.
    return colormap[label]

def draw_plot(pred_img, seg):
    """Draw the prediction image next to a legend of the classes it contains.

    Args:
        pred_img: Image array to show in the left panel (passed to ``imshow``).
        seg: Map of predicted class ids. Either an object exposing
            ``.numpy()`` (e.g. a TensorFlow tensor) or anything
            ``np.asarray`` accepts.

    Returns:
        The matplotlib figure holding both panels.
    """
    # All drawing goes through explicit Figure/Axes objects rather than
    # pyplot's implicit "current figure", so calls that overlap in time
    # cannot draw onto each other's figure.
    # NOTE(review): the figure is still created via pyplot and therefore stays
    # registered there until someone closes it — confirm the caller does so.
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the image itself, without axes decoration.
    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    # One color per known class, stacked as a (num_classes, 1, 3) column.
    label_names = np.asarray(labels_list)
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    seg_ids = seg.numpy() if hasattr(seg, "numpy") else np.asarray(seg)
    # A wide integer type keeps ids >= 256 from wrapping around.
    unique_labels = np.unique(seg_ids.astype(np.int64))

    # Right panel: only the classes that occur in `seg`, one swatch per row,
    # with the class name as the tick label on the right-hand side.
    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(
        full_color_map[unique_labels].astype(np.uint8), interpolation="nearest"
    )
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(unique_labels)))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig

def sepia(input_img):
    """Segment an image and return a figure with the colored overlay.

    Despite its name, this function performs semantic segmentation: it runs
    the module-level ``model`` on the image, blends the per-class colors over
    the input at 50% opacity and hands the result to ``draw_plot``.

    Args:
        input_img: Image as a NumPy array accepted by ``PIL.Image.fromarray``
            (presumably an RGB array delivered by the Gradio image input —
            TODO confirm).

    Returns:
        The figure produced by ``draw_plot`` (blended image plus legend).
    """
    input_img = Image.fromarray(input_img)

    inputs = feature_extractor(images=input_img, return_tensors="tf")
    outputs = model(**inputs)
    logits = outputs.logits

    # Move axis 1 (the class scores) to the end, the layout
    # tf.image.resize expects: (batch, height, width, channels).
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # PIL's `size` is (width, height); resize wants (height, width).
    logits = tf.image.resize(logits, input_img.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]

    # Look the colors up with the same helper the legend uses, and keep them
    # in palette (RGB) order: the input image and the legend are both RGB, so
    # flipping the overlay to BGR would make its colors disagree with the
    # legend entries.
    color_seg = label_to_color_image(seg.numpy())

    # Show image + mask, equally weighted.
    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    return draw_plot(pred_img, seg)

# Gradio UI: one image input, one plot output, and a bundled example image.
# NOTE(review): `shape=(200, 200)` presumably resizes the uploaded image before
# it reaches `sepia` — confirm against the installed Gradio version.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.Image(shape=(200, 200)),
    outputs=["plot"],
    examples=["ADE_val_00000001.jpeg"],
)

# Start the web server as soon as this module is executed.
demo.launch()