File size: 4,365 Bytes
8f7598e
 
 
 
 
c9e9eb6
8f7598e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9e9eb6
8f7598e
 
 
 
 
 
c9e9eb6
 
 
 
 
 
 
 
 
 
 
8f7598e
 
 
 
 
 
 
c9e9eb6
8f7598e
 
 
 
 
 
 
 
 
c9e9eb6
 
 
 
 
 
 
 
 
 
8f7598e
 
 
 
 
 
 
 
 
 
 
c9e9eb6
 
 
8f7598e
 
 
 
c9e9eb6
 
 
 
 
8f7598e
 
 
c9e9eb6
 
8f7598e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9e9eb6
 
8f7598e
0a3a9fc
 
c9e9eb6
0a3a9fc
8f7598e
 
 
 
9cb0704
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from safetensors.torch import load_model, save_model
from models import *
import os


class WasteClassifier:
    """Run single-image inference with a classification + segmentation model.

    The wrapped model is called once per image and is expected to return a
    pair ``(class_scores, segmentation_mask)``.
    """

    def __init__(self, model, class_names, device):
        """Store the model and build the preprocessing pipeline.

        Args:
            model: Network called as ``model(batch)``; ``predict`` unpacks its
                result into ``(class_scores, segmentation_mask)``.
            class_names: Sequence of labels, indexed by the position of the
                highest class score.
            device: Device the input tensor is moved to before the forward
                pass.
        """
        self.model = model
        self.class_names = class_names
        self.device = device
        # Fixed 384x384 input, normalized with the standard ImageNet
        # per-channel mean/std values (three channels -> RGB input required).
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify one image and return its segmentation mask.

        Args:
            image: A ``PIL.Image.Image`` or an array accepted by
                ``Image.fromarray``. Non-RGB inputs (grayscale, palette,
                RGBA, ...) are converted to RGB before preprocessing.

        Returns:
            dict with keys:
                ``predicted_class``: label with the highest probability.
                ``confidence``: that probability as a Python ``float``.
                ``class_probabilities``: ``{label: float probability}``.
                ``segmentation_mask``: float32 ``numpy`` array resized to the
                    input image's (height, width).
        """
        self.model.eval()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Normalize() above is configured for exactly three channels, so any
        # other mode would fail with a channel-count mismatch.
        if image.mode != "RGB":
            image = image.convert("RGB")

        original_size = image.size  # PIL order: (width, height)
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # The model returns both the class scores and a mask.
            outputs, seg_mask = self.model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        probs = probabilities[0].cpu().numpy()
        best_idx = int(np.argmax(probs))
        pred_class = self.class_names[best_idx]
        # Plain float keeps the result consistent with class_probabilities
        # and safe to serialize.
        confidence = float(probs[best_idx])

        # First image of the batch, first channel.
        # NOTE(review): assumes the mask is laid out (batch, channel, H, W)
        # with values already in [0, 1] - confirm against the model.
        seg_mask = seg_mask[0, 0].cpu().numpy().astype(np.float32)

        # Resize the mask back to the original image size; NEAREST avoids
        # inventing interpolated values.
        seg_mask = Image.fromarray(seg_mask)
        seg_mask = seg_mask.resize(original_size, Image.NEAREST)
        seg_mask = np.array(seg_mask)

        return {
            "predicted_class": pred_class,
            "confidence": confidence,
            "class_probabilities": {
                class_name: float(prob)
                for class_name, prob in zip(self.class_names, probs)
            },
            "segmentation_mask": seg_mask,
        }


def interface(classifier):
    """Build the Gradio demo around ``classifier``.

    Args:
        classifier: Object exposing ``predict(image)`` that returns a dict
            with ``predicted_class``, ``confidence``, ``class_probabilities``
            and ``segmentation_mask`` entries.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    # Mask values below this are treated as background in the overlay.
    mask_threshold = 0.2

    def process_image(image):
        """Return ``[report text, masked image, mask preview]`` for one upload."""
        # Gradio passes None when the user submits without an image; surface
        # a readable message instead of an opaque failure inside predict().
        if image is None:
            raise gr.Error("Please upload an image before submitting.")

        results = classifier.predict(image)

        image_np = np.array(image) if isinstance(image, Image.Image) else image
        mask = results["segmentation_mask"]

        # Black out every pixel whose mask value is below the threshold.
        overlay = image_np.copy()
        overlay[mask < mask_threshold] = 0

        sorted_probs = sorted(
            results["class_probabilities"].items(), key=lambda x: x[1], reverse=True
        )

        report_lines = [
            f"Predicted Class: {results['predicted_class']}",
            f"Confidence: {results['confidence']*100:.2f}%",
            "",
            "Class Probabilities:",
        ]
        report_lines.extend(
            f"{class_name}: {prob*100:.2f}%" for class_name, prob in sorted_probs
        )
        output_str = "\n".join(report_lines) + "\n"

        # Clip before scaling: values outside [0, 1] would otherwise wrap
        # around when cast to uint8.
        mask_viz = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

        return [output_str, overlay, mask_viz]

    # Only offer example files that are actually present, so one missing file
    # cannot break the interface.
    example_rows = [
        [path]
        for path in ("example1.jpg", "example2.jpg", "example3.jpg")
        if os.path.exists(path)
    ]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description="""
        Upload an image of waste to classify it into different categories.
        The model will predict the type of waste, show confidence scores for each category,
        and display the segmented object along with its mask.
        """,
        examples=example_rows or None,
        analytics_enabled=False,
        theme="default",
    )

    return demo


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Labels in the order used to index the model's class output.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

# The weights file sits next to this script, so resolve it relative to
# __file__ rather than the current working directory.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_weights_path = os.path.join(_script_dir, "3q7y4e.safetensors")

# Build the network on the target device, then load the saved weights into it.
best_model = ResNet101UNet(num_classes=len(class_names)).to(device)
load_model(best_model, _weights_path)

classifier = WasteClassifier(best_model, class_names, device)

# Build the UI and start serving it.
demo = interface(classifier)
demo.launch()