import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont

from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection

# Note: the model can be loaded from the Hugging Face Hub or from a local directory
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector"
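# For example, a hypothetical local path (assuming the model was previously saved
# with model.save_pretrained()) could look like:
# model_save_path = "models/detr_finetuned_trashify_box_detector"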

# Load the model and preprocessor
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Get the id2label dictionary from the model
id2label = model.config.id2label

# Set up a colour dictionary for plotting boxes with different colours
color_dict = {   
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
    "not_hand": "red",
}

# Create helper functions to check whether items from one list appear in another
def any_in_list(list_a, list_b):
    "Returns True if any item from list_a is in list_b, otherwise False."
    return any(item in list_b for item in list_a)

def all_in_list(list_a, list_b):
    "Returns True if all items from list_a are in list_b, otherwise False."
    return all(item in list_b for item in list_a)
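
# Example usage as a quick sanity check:
#   any_in_list(["trash"], ["bin", "trash"]) -> True
#   all_in_list(["trash", "hand"], ["bin", "trash"]) -> False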

def predict_on_image(image, conf_threshold):
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # height, width 

        results = image_processor.post_process_object_detection(outputs,
                                                                threshold=conf_threshold,
                                                                target_sizes=target_sizes)[0]
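
    # results is a dict with "scores", "labels" and "boxes" tensors; boxes come back
    # in (xmin, ymin, xmax, ymax) format scaled to the original image size.
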
    # Move all tensors in results to the CPU (they may still be on the GPU after inference)
    results = {key: value.cpu() for key, value in results.items()}

    # Draw the prediction results onto the PIL image (which is then displayed)
    draw = ImageDraw.Draw(image)

    # Get a font from ImageFont
    font = ImageFont.load_default(size=20)
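    # Note: the size argument to load_default() requires Pillow >= 10.1; on older
    # versions, fall back to ImageFont.load_default() without the size argument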

    # Get class names as text for print out
    class_name_text_labels = []

    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        # Create coordinates
        x, y, x2, y2 = box.tolist()

        # Get label_name
        label_name = id2label[label.item()]
        targ_color = color_dict[label_name]
        class_name_text_labels.append(label_name)

        # Draw the rectangle
        draw.rectangle(xy=(x, y, x2, y2), 
                       outline=targ_color,
                       width=3)
        
        # Create a text string to display
        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"

        # Draw the text on the image
        draw.text(xy=(x, y),
                  text=text_string_to_show,
                  fill="white",
                  font=font)
    
    # Remove the drawing object now that annotation is complete
    del draw

    # Setup blank string to print out
    return_string = ""

    # Setup list of target items to discover
    target_items = ["trash", "bin", "hand"]

    # If nothing was detected, or none of trash, bin or hand are present, return a notification
    if (len(class_name_text_labels) == 0) or not (any_in_list(list_a=target_items, list_b=class_name_text_labels)):
        return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."
        return image, return_string

    # If some target items are missing, report which ones
    elif not all_in_list(list_a=target_items, list_b=class_name_text_labels):
        missing_items = [item for item in target_items if item not in class_name_text_labels]
        return_string = f"Detected the following items: {class_name_text_labels}. But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
        
    # If all three target items (trash, bin, hand) are detected, award +1
    elif all_in_list(list_a=target_items, list_b=class_name_text_labels):
        return_string = f"+1! Found the following items: {class_name_text_labels}, thank you for cleaning up the area!"

    print(return_string)
    
    return image, return_string
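
# Example of calling the prediction function directly (using one of the bundled example images):
# test_image = Image.open("examples/trashify_example_1.jpeg")
# annotated_image, message = predict_on_image(image=test_image, conf_threshold=0.25)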

# Create the interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output")
    ],
    title="🚮 Trashify Object Detection Demo",
    description="Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.",
    # Examples come in the form of a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
    examples=[
        ["examples/trashify_example_1.jpeg", 0.25],
        ["examples/trashify_example_2.jpeg", 0.25],
        ["examples/trashify_example_3.jpeg", 0.25],
    ],
    cache_examples=True
)

# Launch the demo
demo.launch()
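
# To share a temporary public link (e.g. for testing on another device),
# the demo could instead be launched with:
# demo.launch(share=True)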