SamiKhokhar commited on
Commit
b0c391e
·
verified ·
1 Parent(s): 00ace93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -11
app.py CHANGED
@@ -3,26 +3,50 @@ import torch
3
  import gradio as gr
4
  from PIL import Image
5
 
6
- # Step 1: Load the custom YOLOv5 model
7
- # Make sure the path to the model is correct, either locally or from Hugging Face model hub
8
- model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt') # Adjust the path if needed
9
 
10
- # Step 2: Define weapon classes to detect
11
- weapon_classes = ['pistol', 'gun', 'rifle', 'shotgun', 'handgun']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def detect_weapons(image):
14
  results = model(image)
15
-
 
 
 
 
16
  # Filter detections by confidence threshold (0.5 or higher)
17
  confidence_threshold = 0.5
18
  filtered_results = results.pandas().xyxy[0][results.pandas().xyxy[0]['confidence'] >= confidence_threshold]
19
-
20
  # Get the detected classes with high confidence
21
  detected_classes = filtered_results['name'].unique()
22
 
 
 
 
23
  # Check if any of the detected objects are weapons
24
  detected_threats = [weapon for weapon in weapon_classes if weapon in detected_classes]
25
-
26
  # Determine threat message based on weapons detected
27
  if detected_threats:
28
  threat_message = "Threat detected: Be careful"
@@ -35,7 +59,7 @@ def detect_weapons(image):
35
  # Render the image with bounding boxes
36
  return f"{threat_message}\nDetected objects: {detected_objects}", Image.fromarray(results.render()[0])
37
 
38
- # Step 3: Gradio Interface
39
  def inference(image):
40
  threat, detected_image = detect_weapons(image)
41
  return threat, detected_image
@@ -48,8 +72,8 @@ iface = gr.Interface(
48
  gr.Image(label="Detected Image"),
49
  ],
50
  title="Weapon Detection AI",
51
- description="Upload an image to detect weapons like pistols, guns, and rifles."
52
  )
53
 
54
- # Step 4: Launch Gradio App
55
  iface.launch()
 
3
  import gradio as gr
4
  from PIL import Image
5
 
6
+ # Step 1: Search for best.pt in the training directory
7
+ base_path = "yolov5/runs/train/"
8
+ best_path = None
9
 
10
+ # Search through the directory structure to find best.pt
11
+ for root, dirs, files in os.walk(base_path):
12
+ if "best.pt" in files:
13
+ best_path = os.path.join(root, "best.pt")
14
+ break
15
+
16
+ # Step 2: If best.pt is not found, use pre-trained weights
17
+ if best_path is None:
18
+ print("Trained weights (best.pt) not found.")
19
+ print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
20
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # Load pre-trained weights
21
+ else:
22
+ print(f"Model weights found at: {best_path}")
23
+ model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
24
+
25
+ # Step 3: Define weapon classes to detect
26
+ weapon_classes = ['bomb', 'gun', 'pistol', 'Automatic', 'Rifle', 'Bazooka',
27
+ 'Handgun', 'Knife', 'Grenade Launcher', 'Shotgun', 'SMG',
28
+ 'Sniper', 'Sword'] # Adjust based on your dataset
29
 
30
  def detect_weapons(image):
31
  results = model(image)
32
+
33
+ # Print available model class names to check for class mismatches
34
+ model_classes = results.names # This should give the list of class labels used by the model
35
+ print("Model class names:", model_classes)
36
+
37
  # Filter detections by confidence threshold (0.5 or higher)
38
  confidence_threshold = 0.5
39
  filtered_results = results.pandas().xyxy[0][results.pandas().xyxy[0]['confidence'] >= confidence_threshold]
40
+
41
  # Get the detected classes with high confidence
42
  detected_classes = filtered_results['name'].unique()
43
 
44
+ # Print detected classes for debugging
45
+ print("Detected classes:", detected_classes)
46
+
47
  # Check if any of the detected objects are weapons
48
  detected_threats = [weapon for weapon in weapon_classes if weapon in detected_classes]
49
+
50
  # Determine threat message based on weapons detected
51
  if detected_threats:
52
  threat_message = "Threat detected: Be careful"
 
59
  # Render the image with bounding boxes
60
  return f"{threat_message}\nDetected objects: {detected_objects}", Image.fromarray(results.render()[0])
61
 
62
+ # Step 4: Gradio Interface
63
  def inference(image):
64
  threat, detected_image = detect_weapons(image)
65
  return threat, detected_image
 
72
  gr.Image(label="Detected Image"),
73
  ],
74
  title="Weapon Detection AI",
75
+ description="Upload an image to detect weapons like bombs, guns, and pistols."
76
  )
77
 
78
+ # Step 5: Launch Gradio App
79
  iface.launch()