SamiKhokhar commited on
Commit
69ebf82
·
verified ·
1 Parent(s): 79ea84b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -24
app.py CHANGED
@@ -1,41 +1,55 @@
1
-
2
  import os
3
  import torch
4
  import gradio as gr
5
  from PIL import Image
6
 
7
- # Load YOLOv5 model
8
- base_path = "yolov5/runs/train/"
9
- best_path = None
10
- for root, dirs, files in os.walk(base_path):
11
- if "best.pt" in files:
12
- best_path = os.path.join(root, "best.pt")
13
- break
14
-
15
- if best_path is None:
16
- print("Trained weights (best.pt) not found. Using pre-trained weights.")
17
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
18
- else:
19
- print(f"Model weights found at: {best_path}")
20
- model = torch.hub.load('ultralytics/yolov5', 'custom', path=best_path)
21
-
22
- # Detection function
23
  def detect_weapons(image):
24
  results = model(image)
25
- detected_classes = results.pandas().xyxy[0]['name'].unique()
26
- threat_message = "Threat detected: Be careful" if len(detected_classes) > 0 else "No threat detected"
27
- return threat_message, Image.fromarray(results.render()[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Gradio interface
30
  iface = gr.Interface(
31
- fn=detect_weapons,
32
  inputs=gr.Image(type="numpy", label="Upload Image"),
33
  outputs=[
34
  gr.Textbox(label="Threat Detection"),
35
  gr.Image(label="Detected Image"),
36
  ],
37
  title="Weapon Detection AI",
38
- description="Upload an image to detect weapons like bombs, guns, and pistols."
39
  )
40
 
41
- iface.launch()
 
 
 
1
  import os
2
  import torch
3
  import gradio as gr
4
  from PIL import Image
5
 
6
+ # Step 1: Load the custom YOLOv5 model
7
+ # Adjust the model path based on your deployment setup (Hugging Face model hub or local path)
8
+ model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt') # If you use Hugging Face model hub, use the correct identifier
9
+
10
+ # Step 2: Define weapon classes to detect
11
+ weapon_classes = ['pistol', 'gun', 'rifle', 'shotgun', 'handgun']
12
+
 
 
 
 
 
 
 
 
 
13
  def detect_weapons(image):
14
  results = model(image)
15
+
16
+ # Filter detections by confidence threshold (0.5 or higher)
17
+ confidence_threshold = 0.5
18
+ filtered_results = results.pandas().xyxy[0][results.pandas().xyxy[0]['confidence'] >= confidence_threshold]
19
+
20
+ # Get the detected classes with high confidence
21
+ detected_classes = filtered_results['name'].unique()
22
+
23
+ # Check if any of the detected objects are weapons
24
+ detected_threats = [weapon for weapon in weapon_classes if weapon in detected_classes]
25
+
26
+ # Determine threat message based on weapons detected
27
+ if detected_threats:
28
+ threat_message = "Threat detected: Be careful"
29
+ else:
30
+ threat_message = "No threat detected. But all other features are good."
31
+
32
+ # Create a string with the detected objects' names
33
+ detected_objects = ', '.join(detected_classes)
34
+
35
+ # Render the image with bounding boxes
36
+ return f"{threat_message}\nDetected objects: {detected_objects}", Image.fromarray(results.render()[0])
37
+
38
+ # Step 3: Gradio Interface
39
+ def inference(image):
40
+ threat, detected_image = detect_weapons(image)
41
+ return threat, detected_image
42
 
 
43
  iface = gr.Interface(
44
+ fn=inference,
45
  inputs=gr.Image(type="numpy", label="Upload Image"),
46
  outputs=[
47
  gr.Textbox(label="Threat Detection"),
48
  gr.Image(label="Detected Image"),
49
  ],
50
  title="Weapon Detection AI",
51
+ description="Upload an image to detect weapons like pistols, guns, and rifles."
52
  )
53
 
54
+ # Step 4: Launch Gradio App
55
+ iface.launch()