Spaces:
Sleeping
Sleeping
File size: 4,365 Bytes
8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e c9e9eb6 8f7598e 0a3a9fc c9e9eb6 0a3a9fc 8f7598e 9cb0704 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from safetensors.torch import load_model, save_model
from models import *
import os
class WasteClassifier:
    """Run a joint classification + segmentation model on a single image.

    The wrapped ``model`` is called on a normalized 384x384 batch of one image
    and must return a ``(class_logits, segmentation_mask)`` pair.
    """

    def __init__(self, model, class_names, device):
        """Store the model and build the preprocessing pipeline.

        Args:
            model: torch module returning ``(logits, seg_mask)`` when called.
            class_names: label for each logit index, in the model's output order.
            device: torch device the input tensor is moved to before inference.
        """
        self.model = model
        self.class_names = class_names
        self.device = device
        # Fixed network input size, then per-channel normalization with the
        # commonly used ImageNet mean/std (three channels -> RGB input only).
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify ``image`` and produce its segmentation mask.

        Args:
            image: a ``PIL.Image.Image`` or an array accepted by
                ``PIL.Image.fromarray``.

        Returns:
            dict with keys:
              - ``"predicted_class"``: name of the highest-probability class.
              - ``"confidence"``: probability of that class, as a Python float.
              - ``"class_probabilities"``: mapping class name -> float.
              - ``"segmentation_mask"``: float32 array of shape (H, W) matching
                the input image size; values are returned unthresholded.
        """
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize has three channel statistics, so RGBA / grayscale / palette
        # images would fail inside the transform; force a 3-channel image.
        image = image.convert("RGB")
        original_size = image.size  # (width, height), as Image.resize expects

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs, seg_mask = self.model(img_tensor)

        probs = F.softmax(outputs, dim=1)[0].cpu().numpy()
        best = int(np.argmax(probs))
        pred_class = self.class_names[best]
        confidence = float(probs[best])

        # First image of the batch, first mask channel, resized back to the
        # input resolution with nearest-neighbour sampling.
        mask = seg_mask[0, 0].cpu().numpy().astype(np.float32)
        mask = Image.fromarray(mask).resize(original_size, Image.NEAREST)
        mask = np.array(mask)

        return {
            "predicted_class": pred_class,
            "confidence": confidence,
            "class_probabilities": {
                class_name: float(prob)
                for class_name, prob in zip(self.class_names, probs)
            },
            "segmentation_mask": mask,
        }
def interface(classifier, mask_threshold=0.2):
    """Build the Gradio demo around ``classifier``.

    Args:
        classifier: object exposing ``predict(image) -> dict`` with the keys
            ``predicted_class``, ``confidence``, ``class_probabilities`` and
            ``segmentation_mask``.
        mask_threshold: pixels whose mask value is below this are blacked out
            in the "Segmented Object" output.

    Returns:
        A ``gr.Interface`` instance (not launched).
    """

    def process_image(image):
        """Return [report text, masked image, mask visualization] for Gradio."""
        if image is None:
            # Gradio passes None when the user submits without an upload.
            return ["Please upload an image.", None, None]

        results = classifier.predict(image)
        image_np = np.array(image) if isinstance(image, Image.Image) else image
        mask = results["segmentation_mask"]

        overlay = image_np.copy()
        overlay[mask < mask_threshold] = 0

        sorted_probs = sorted(
            results["class_probabilities"].items(), key=lambda x: x[1], reverse=True
        )
        lines = [
            f"Predicted Class: {results['predicted_class']}",
            f"Confidence: {results['confidence'] * 100:.2f}%",
            "",
            "Class Probabilities:",
        ]
        lines.extend(f"{name}: {prob * 100:.2f}%" for name, prob in sorted_probs)
        output_str = "\n".join(lines) + "\n"

        # Clip before scaling so out-of-range mask values cannot wrap around
        # when cast to uint8.
        mask_viz = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        return [output_str, overlay, mask_viz]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description=(
            "\n"
            "Upload an image of waste to classify it into different categories.\n"
            "The model will predict the type of waste, show confidence scores for each category,\n"
            "and display the segmented object along with its mask.\n"
        ),
        examples=(
            [["example1.jpg"], ["example2.jpg"], ["example3.jpg"]]
            if os.path.exists("example1.jpg")
            else None
        ),
        analytics_enabled=False,
        theme="default",
    )
    return demo
# Build the model, load the trained weights and assemble the Gradio app.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Index i names output logit i of the model (used with argmax / zip in predict).
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

# Resolve the weights next to this script rather than the working directory.
weights_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "3q7y4e.safetensors"
)
best_model = ResNet101UNet(num_classes=len(class_names)).to(device)
load_model(best_model, weights_path)

classifier = WasteClassifier(best_model, class_names, device)
# Kept at module level so tools that import this file can find `demo`.
demo = interface(classifier)

if __name__ == "__main__":
    demo.launch()
|