SamiKhokhar commited on
Commit
bbe4eba
·
verified ·
1 Parent(s): d8b410a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -27
app.py CHANGED
@@ -1,56 +1,81 @@
1
  import os
2
  import torch
 
 
3
  import gradio as gr
4
  from PIL import Image
5
 
6
- # Step 1: Search for best.pt in the training directory
7
  base_path = "yolov5/runs/train/"
8
  best_path = None
9
-
10
- # Search through the directory structure to find best.pt
11
  for root, dirs, files in os.walk(base_path):
12
  if "best.pt" in files:
13
  best_path = os.path.join(root, "best.pt")
14
  break
15
 
16
- # Step 2: If best.pt is not found, use pre-trained weights
17
  if best_path is None:
18
- print("Trained weights (best.pt) not found.")
19
- print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
20
-
21
- # Fallback to a pre-trained model if best.pt is not available
22
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # Load pre-trained weights
23
-
24
  else:
25
  print(f"Model weights found at: {best_path}")
26
- # Load YOLOv5 model with the correct path
27
  model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
28
 
29
- # Step 3: Detection Function
 
 
 
30
  def detect_weapons(image):
31
  results = model(image)
32
- detected_classes = results.pandas().xyxy[0]['name'].unique()
33
-
34
- # Check for threats
35
- threat_message = "Threat detected: Be careful" if len(detected_classes) > 0 else "No threat detected"
36
  return threat_message, Image.fromarray(results.render()[0])
37
 
38
- # Step 4: Gradio Interface
39
- def inference(image):
40
- threat, detected_image = detect_weapons(image)
41
- return threat, detected_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
 
43
  iface = gr.Interface(
44
- fn=inference,
45
  inputs=gr.Image(type="numpy", label="Upload Image"),
46
- outputs=[
47
- gr.Textbox(label="Threat Detection"),
48
- gr.Image(label="Detected Image")
49
  ],
50
  title="Weapon Detection AI",
51
- description="Upload an image to detect weapons like bombs, guns, and pistols."
52
  )
53
 
54
- # Step 5: Launch Gradio App
55
- iface.launch()
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
1
  import os
2
  import torch
3
+ import cv2
4
+ import numpy as np
5
  import gradio as gr
6
  from PIL import Image
7
 
8
+ # Load YOLOv5 model
9
  base_path = "yolov5/runs/train/"
10
  best_path = None
 
 
11
  for root, dirs, files in os.walk(base_path):
12
  if "best.pt" in files:
13
  best_path = os.path.join(root, "best.pt")
14
  break
15
 
 
16
  if best_path is None:
17
+ print("Trained weights (best.pt) not found. Using pre-trained weights.")
18
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
 
 
 
 
19
  else:
20
  print(f"Model weights found at: {best_path}")
 
21
  model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
22
 
23
+ # Define weapon-related classes
24
+ WEAPON_CLASSES = ["pistol", "rifle", "knife", "bomb", "gun", "weapon"]
25
+
26
+ # Detection function
27
  def detect_weapons(image):
28
  results = model(image)
29
+ detected_classes = results.pandas().xyxy[0]['name']
30
+ weapons = [cls for cls in detected_classes if cls in WEAPON_CLASSES]
31
+ threat_message = "Threat detected: Be careful" if weapons else "No threat detected"
 
32
  return threat_message, Image.fromarray(results.render()[0])
33
 
34
+ # Real-time detection from webcam
35
+ def detect_from_camera():
36
+ cap = cv2.VideoCapture(0) # Open webcam (use 0 for default camera)
37
+ while cap.isOpened():
38
+ ret, frame = cap.read()
39
+ if not ret:
40
+ break
41
+ # Convert frame (OpenCV format) to PIL image
42
+ image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
43
+ results = model(image)
44
+ detected_classes = results.pandas().xyxy[0]['name']
45
+ weapons = [cls for cls in detected_classes if cls in WEAPON_CLASSES]
46
+ annotated_frame = np.array(results.render()[0])
47
+
48
+ # Display live feed with annotations
49
+ cv2.imshow("Weapon Detection", cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
50
+ if cv2.waitKey(1) & 0xFF == ord('q'): # Press 'q' to quit
51
+ break
52
+ cap.release()
53
+ cv2.destroyAllWindows()
54
 
55
+ # Gradio interface for image uploads
56
  iface = gr.Interface(
57
+ fn=detect_weapons,
58
  inputs=gr.Image(type="numpy", label="Upload Image"),
59
+ outputs=[
60
+ gr.Textbox(label="Threat Detection"),
61
+ gr.Image(label="Detected Image"),
62
  ],
63
  title="Weapon Detection AI",
64
+ description="Upload an image to detect weapons like bombs, guns, and pistols. For real-time detection, use the live camera option."
65
  )
66
 
67
+ # Main entry point
68
+ def main():
69
+ print("Select mode:")
70
+ print("1. Image Upload (Gradio Interface)")
71
+ print("2. Live Camera Detection")
72
+ choice = input("Enter your choice (1 or 2): ").strip()
73
+ if choice == "1":
74
+ iface.launch()
75
+ elif choice == "2":
76
+ detect_from_camera()
77
+ else:
78
+ print("Invalid choice. Please restart and select 1 or 2.")
79
 
80
+ if __name__ == "__main__":
81
+ main()