File size: 2,515 Bytes
ac831c4
a1e8b35
 
 
 
5f4e276
 
 
a1e8b35
5f4e276
 
ff9f2a3
a364f88
a1e8b35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8215f62
a1e8b35
 
 
 
 
 
 
 
 
 
7f01c6a
 
a1e8b35
 
 
 
 
 
5f4e276
a1e8b35
5f4e276
a1e8b35
 
5f4e276
a1e8b35
 
5f4e276
 
a1e8b35
 
5f4e276
a1e8b35
 
5f4e276
a1e8b35
5f4e276
a1e8b35
5f4e276
 
 
 
a1e8b35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Standard library
import os
from collections import Counter

# Third-party
import cv2
import gradio as gr
import numpy as np
import supervision as sv
from dotenv import load_dotenv
from inference import get_roboflow_model

# Load environment variables from .env file
load_dotenv()
# NOTE(review): os.getenv returns None when a variable is missing, so an
# incomplete .env yields model_id "None/None" below and a confusing
# Roboflow error rather than a clear "missing configuration" message.
api_key = os.getenv("ROBOFLOW_API_KEY")
model_id = os.getenv("ROBOFLOW_PROJECT")
model_version = os.getenv("ROBOFLOW_MODEL_VERSION")

# Initialize the Roboflow model
# Roboflow identifies a model as "<project>/<version>"; this call may hit
# the network to fetch weights/metadata depending on the inference setup.
model = get_roboflow_model(model_id=f"{model_id}/{model_version}", api_key=api_key)

# SAHI slicer callback: run the Roboflow model on one image tile and
# convert the raw inference result into a supervision Detections object.
def callback(image_slice: np.ndarray) -> sv.Detections:
    inference_result = model.infer(image_slice)[0]
    detections = sv.Detections.from_inference(inference_result)
    return detections

# Object detection function
def detect_objects_with_sahi(image):
    """Run sliced (SAHI) inference on an uploaded image.

    Parameters:
        image: PIL image from the Gradio input component.

    Returns:
        tuple: (annotated image as a NumPy array, per-class count summary text).
    """
    # Convert Gradio PIL image to NumPy array
    image_np = np.array(image)

    # Run inference with SAHI Slicer: the image is cut into overlapping
    # tiles, `callback` runs the model on each tile, and the slicer merges
    # the per-tile detections back into full-image coordinates.
    slicer = sv.InferenceSlicer(callback=callback, overlap_wh=(50, 50), overlap_ratio_wh=None)
    sliced_detections = slicer(image=image_np)

    # Annotate a copy of the image: boxes first, then labels on top.
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()
    annotated_image = box_annotator.annotate(scene=image_np.copy(), detections=sliced_detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=sliced_detections)

    # Count objects by class. Prefer human-readable class names from
    # detections.data when the model supplies them; the previous version
    # keyed counts on raw numeric class_id values, so the summary printed
    # IDs (e.g. "2: 14") instead of class names.
    class_names = sliced_detections.data.get("class_name") if sliced_detections.data else None
    if class_names is not None and len(class_names) == len(sliced_detections):
        labels = [str(name) for name in class_names]
    else:
        labels = [str(class_id) for class_id in sliced_detections.class_id]
    class_counts = Counter(labels)

    # Create summary text (same format as before: one "name: count" line
    # per class, then a blank line and the grand total).
    total_objects = sum(class_counts.values())
    lines = ["Detected Objects:"]
    lines.extend(f"{class_name}: {count}" for class_name, count in class_counts.items())
    lines.append(f"\nTotal Objects: {total_objects}")
    result_text = "\n".join(lines)

    # Return the annotated image and summary text
    return annotated_image, result_text

# Create Gradio interface: two-column layout with the upload + trigger
# button on the left and the annotated result + count summary on the right.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand detect_objects_with_sahi a PIL image.
            input_image = gr.Image(type="pil", label="Upload Image")
            detect_button = gr.Button("Detect Objects")
        with gr.Column():
            output_image = gr.Image(label="Annotated Image")
            output_text = gr.Textbox(label="Object Count Summary", lines=10)

    # Link button to detection function: one input (the image), two outputs
    # (annotated image, summary text) matching the function's return tuple.
    detect_button.click(
        fn=detect_objects_with_sahi,
        inputs=input_image,
        outputs=[output_image, output_text]
    )

# Launch Gradio app (blocks; serves the UI on the default local port).
app.launch()