GradioApp / app.py
Manu101's picture
Update app.py
51ef4a8 verified
raw
history blame
2.78 kB
import torch
import torchvision
from torchvision import transforms
import gradio as gr
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18
# Build the network and restore the trained weights on the CPU.
model = ResNet18()
# NOTE(review): strict=False silently ignores missing/unexpected keys in the
# checkpoint -- confirm this is intentional, otherwise a mismatched file
# leaves some layers at their initial values without raising any error.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
# The app only runs inference: put the module in evaluation mode so that any
# train-time-only layers (e.g. BatchNorm, Dropout) behave deterministically
# and running statistics are not updated by user requests.
model.eval()
# Undoes a per-channel Normalize(mean=0.50, std=0.23): since Normalize computes
# (x - mean) / std, applying it with mean=-m/s and std=1/s maps values back.
# NOTE(review): presumably 0.50 / 0.23 match the training normalisation -- confirm.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)
# Human-readable label for each model output index (position == class id).
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
def resize_image_pil(image, new_width, new_height):
    """Fit an image inside a ``new_width`` x ``new_height`` canvas.

    The image is scaled uniformly (aspect ratio preserved) by the smaller of
    the two axis ratios so that it fits entirely inside the target size, then
    the result is cropped/padded from the top-left corner to exactly
    ``(new_width, new_height)``.

    Args:
        image: Anything ``np.array`` can convert into an array accepted by
            ``PIL.Image.fromarray`` (e.g. a NumPy array or a PIL image).
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A PIL image of size ``(new_width, new_height)``.
    """
    # Convert to a PIL image regardless of the input container type.
    img = Image.fromarray(np.array(image))
    width, height = img.size

    # Use the smaller ratio so the whole picture fits inside the target box.
    scale = min(new_width / width, new_height / height)

    # Clamp each side to at least 1 px: for very elongated inputs the
    # truncated size would otherwise be 0, which resize() rejects.
    fitted_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    resized = img.resize(size=fitted_size, resample=Image.NEAREST)

    # The crop box can extend past the fitted image; Pillow fills the area
    # outside the image with zeros, so this pads to the exact target size.
    return resized.crop((0, 0, new_width, new_height))
def inference(input_image, transparency=0.5, target_layer_number=-1):
    """Classify an image and overlay a Grad-CAM heatmap on it.

    Args:
        input_image: Image in any form accepted by ``resize_image_pil``.
        transparency: Weight of the original image in the overlay; forwarded
            to ``show_cam_on_image`` as ``image_weight``.
        target_layer_number: Index into ``model.layer2`` selecting the block
            whose activations are used for Grad-CAM.

    Returns:
        A tuple ``(class_name, overlay, confidences)``: the predicted class
        name, the image returned by ``show_cam_on_image``, and a dict mapping
        every class name to its softmax probability.
    """
    # Force three channels so grayscale / RGBA uploads do not break the
    # (32, 32, 3) layout the rest of the pipeline relies on.
    resized = resize_image_pil(input_image, 32, 32).convert("RGB")
    org_img = np.array(resized)

    # Use a distinct local name: rebinding `transforms` here would shadow the
    # imported module and raise UnboundLocalError.
    to_tensor = transforms.ToTensor()
    input_tensor = to_tensor(org_img).unsqueeze(0)  # add the batch dimension

    # Plain forward pass for the prediction; gradients are only needed by
    # Grad-CAM, which runs its own forward/backward pass below.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(p) for name, p in zip(classes, probabilities)}
    prediction = outputs.argmax(dim=1)

    # UI sliders may deliver the index as a float; sequence indexing needs int.
    target_layers = [model.layer2[int(target_layer_number)]]
    cam = GradCAM(model=model, target_layers=target_layers)
    # targets=None lets Grad-CAM explain the highest-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    visualization = show_cam_on_image(
        org_img.astype(np.float32) / 255,  # expects a float image in [0, 1]
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )
    return classes[prediction[0].item()], visualization, confidences
# Gradio UI: an image plus two sliders in; predicted label, Grad-CAM overlay
# and top-3 class confidences out. Input order must match the parameters of
# `inference`, output order must match its return tuple.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Overall opacity fo the overlay"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which GradCAM layer?"),
    ],
    outputs=[
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(num_top_classes=3),
    ],
    title="CIFAR10 trained on ResNet18 with GradCAM feature",
    description="A simple Gradio app for checking GradCAM outputs from results of ResNet18 model.",
    # Each example row supplies: image path, opacity, layer index.
    examples=[["cat.jpg", 0.5, -1], ["dog.jpg", 0.7, -2]],
)
demo.launch()