Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,26 +3,50 @@ import torch
|
|
3 |
import gradio as gr
|
4 |
from PIL import Image
|
5 |
|
6 |
-
# Step 1:
|
7 |
-
|
8 |
-
|
9 |
|
10 |
-
#
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
def detect_weapons(image):
|
14 |
results = model(image)
|
15 |
-
|
|
|
|
|
|
|
|
|
16 |
# Filter detections by confidence threshold (0.5 or higher)
|
17 |
confidence_threshold = 0.5
|
18 |
filtered_results = results.pandas().xyxy[0][results.pandas().xyxy[0]['confidence'] >= confidence_threshold]
|
19 |
-
|
20 |
# Get the detected classes with high confidence
|
21 |
detected_classes = filtered_results['name'].unique()
|
22 |
|
|
|
|
|
|
|
23 |
# Check if any of the detected objects are weapons
|
24 |
detected_threats = [weapon for weapon in weapon_classes if weapon in detected_classes]
|
25 |
-
|
26 |
# Determine threat message based on weapons detected
|
27 |
if detected_threats:
|
28 |
threat_message = "Threat detected: Be careful"
|
@@ -35,7 +59,7 @@ def detect_weapons(image):
|
|
35 |
# Render the image with bounding boxes
|
36 |
return f"{threat_message}\nDetected objects: {detected_objects}", Image.fromarray(results.render()[0])
|
37 |
|
38 |
-
# Step
|
39 |
def inference(image):
|
40 |
threat, detected_image = detect_weapons(image)
|
41 |
return threat, detected_image
|
@@ -48,8 +72,8 @@ iface = gr.Interface(
|
|
48 |
gr.Image(label="Detected Image"),
|
49 |
],
|
50 |
title="Weapon Detection AI",
|
51 |
-
description="Upload an image to detect weapons like
|
52 |
)
|
53 |
|
54 |
-
# Step
|
55 |
iface.launch()
|
|
|
3 |
import gradio as gr
|
4 |
from PIL import Image
|
5 |
|
6 |
+
# Step 1: Search for best.pt in the training directory
|
7 |
+
base_path = "yolov5/runs/train/"
|
8 |
+
best_path = None
|
9 |
|
10 |
+
# Search through the directory structure to find best.pt
|
11 |
+
for root, dirs, files in os.walk(base_path):
|
12 |
+
if "best.pt" in files:
|
13 |
+
best_path = os.path.join(root, "best.pt")
|
14 |
+
break
|
15 |
+
|
16 |
+
# Step 2: If best.pt is not found, use pre-trained weights
|
17 |
+
if best_path is None:
|
18 |
+
print("Trained weights (best.pt) not found.")
|
19 |
+
print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
|
20 |
+
model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # Load pre-trained weights
|
21 |
+
else:
|
22 |
+
print(f"Model weights found at: {best_path}")
|
23 |
+
model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
|
24 |
+
|
25 |
+
# Step 3: Define weapon classes to detect
|
26 |
+
weapon_classes = ['bomb', 'gun', 'pistol', 'Automatic', 'Rifle', 'Bazooka',
|
27 |
+
'Handgun', 'Knife', 'Grenade Launcher', 'Shotgun', 'SMG',
|
28 |
+
'Sniper', 'Sword'] # Adjust based on your dataset
|
29 |
|
30 |
def detect_weapons(image):
|
31 |
results = model(image)
|
32 |
+
|
33 |
+
# Print available model class names to check for class mismatches
|
34 |
+
model_classes = results.names # This should give the list of class labels used by the model
|
35 |
+
print("Model class names:", model_classes)
|
36 |
+
|
37 |
# Filter detections by confidence threshold (0.5 or higher)
|
38 |
confidence_threshold = 0.5
|
39 |
filtered_results = results.pandas().xyxy[0][results.pandas().xyxy[0]['confidence'] >= confidence_threshold]
|
40 |
+
|
41 |
# Get the detected classes with high confidence
|
42 |
detected_classes = filtered_results['name'].unique()
|
43 |
|
44 |
+
# Print detected classes for debugging
|
45 |
+
print("Detected classes:", detected_classes)
|
46 |
+
|
47 |
# Check if any of the detected objects are weapons
|
48 |
detected_threats = [weapon for weapon in weapon_classes if weapon in detected_classes]
|
49 |
+
|
50 |
# Determine threat message based on weapons detected
|
51 |
if detected_threats:
|
52 |
threat_message = "Threat detected: Be careful"
|
|
|
59 |
# Render the image with bounding boxes
|
60 |
return f"{threat_message}\nDetected objects: {detected_objects}", Image.fromarray(results.render()[0])
|
61 |
|
62 |
+
# Step 4: Gradio Interface
|
63 |
def inference(image):
|
64 |
threat, detected_image = detect_weapons(image)
|
65 |
return threat, detected_image
|
|
|
72 |
gr.Image(label="Detected Image"),
|
73 |
],
|
74 |
title="Weapon Detection AI",
|
75 |
+
description="Upload an image to detect weapons like bombs, guns, and pistols."
|
76 |
)
|
77 |
|
78 |
+
# Step 5: Launch Gradio App
|
79 |
iface.launch()
|