Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,55 @@
|
|
1 |
-
|
2 |
import os
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
from PIL import Image
|
6 |
|
7 |
-
# Load YOLOv5 model
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
if best_path is None:
|
16 |
-
print("Trained weights (best.pt) not found. Using pre-trained weights.")
|
17 |
-
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
|
18 |
-
else:
|
19 |
-
print(f"Model weights found at: {best_path}")
|
20 |
-
model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
|
21 |
-
|
22 |
-
# Detection function
|
23 |
def detect_weapons(image):
|
24 |
results = model(image)
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
# Gradio interface
|
30 |
iface = gr.Interface(
|
31 |
-
fn=
|
32 |
inputs=gr.Image(type="numpy", label="Upload Image"),
|
33 |
outputs=[
|
34 |
gr.Textbox(label="Threat Detection"),
|
35 |
gr.Image(label="Detected Image"),
|
36 |
],
|
37 |
title="Weapon Detection AI",
|
38 |
-
description="Upload an image to detect weapons like
|
39 |
)
|
40 |
|
41 |
-
|
|
|
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from PIL import Image
|
5 |
|
6 |
+
# Step 1: Load the custom YOLOv5 model
|
7 |
+
# Adjust the model path based on your deployment setup (Hugging Face model hub or local path)
|
8 |
+
model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt') # If you use Hugging Face model hub, use the correct identifier
|
9 |
+
|
10 |
+
# Step 2: Define weapon classes to detect
|
11 |
+
weapon_classes = ['pistol', 'gun', 'rifle', 'shotgun', 'handgun']
|
12 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def detect_weapons(image):
|
14 |
results = model(image)
|
15 |
+
|
16 |
+
# Filter detections by confidence threshold (0.5 or higher)
|
17 |
+
confidence_threshold = 0.5
|
18 |
+
filtered_results = results.pandas().xyxy[0][results.pandas().xyxy[0]['confidence'] >= confidence_threshold]
|
19 |
+
|
20 |
+
# Get the detected classes with high confidence
|
21 |
+
detected_classes = filtered_results['name'].unique()
|
22 |
+
|
23 |
+
# Check if any of the detected objects are weapons
|
24 |
+
detected_threats = [weapon for weapon in weapon_classes if weapon in detected_classes]
|
25 |
+
|
26 |
+
# Determine threat message based on weapons detected
|
27 |
+
if detected_threats:
|
28 |
+
threat_message = "Threat detected: Be careful"
|
29 |
+
else:
|
30 |
+
threat_message = "No threat detected. But all other features are good."
|
31 |
+
|
32 |
+
# Create a string with the detected objects' names
|
33 |
+
detected_objects = ', '.join(detected_classes)
|
34 |
+
|
35 |
+
# Render the image with bounding boxes
|
36 |
+
return f"{threat_message}\nDetected objects: {detected_objects}", Image.fromarray(results.render()[0])
|
37 |
+
|
38 |
+
# Step 3: Gradio Interface
|
39 |
+
def inference(image):
|
40 |
+
threat, detected_image = detect_weapons(image)
|
41 |
+
return threat, detected_image
|
42 |
|
|
|
43 |
iface = gr.Interface(
|
44 |
+
fn=inference,
|
45 |
inputs=gr.Image(type="numpy", label="Upload Image"),
|
46 |
outputs=[
|
47 |
gr.Textbox(label="Threat Detection"),
|
48 |
gr.Image(label="Detected Image"),
|
49 |
],
|
50 |
title="Weapon Detection AI",
|
51 |
+
description="Upload an image to detect weapons like pistols, guns, and rifles."
|
52 |
)
|
53 |
|
54 |
+
# Step 4: Launch Gradio App
|
55 |
+
iface.launch()
|