Subh775 committed
Commit 745a794 · verified · Parent: 08991cc

Update video_processing.py

Files changed (1):
  video_processing.py  +223 -223
video_processing.py CHANGED
@@ -1,224 +1,224 @@
- # pip install -q rfdetr==1.2.1 supervision==0.26.1
-
- # RF-DETR video processing for threat detection.
- # Inference time depends on frame resolution (e.g., ~50 ms/frame on GPU for 640×640).
-
-
- import numpy as np
- import supervision as sv
- import torch
- import requests
- from PIL import Image
- import os
- import cv2
- from tqdm import tqdm
- import time
-
- from rfdetr import RFDETRNano
-
- THREAT_CLASSES = {
-     1: "Gun",
-     2: "Explosive",
-     3: "Grenade",
-     4: "Knife"
- }
-
- # Enable GPU if available
- if torch.cuda.is_available():
-     print(f"GPU: {torch.cuda.get_device_name(0)}")
-     # print(f"CUDA Version: {torch.version.cuda}")
-     # print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
-
-     # Optimize for batch processing
-     torch.backends.cudnn.benchmark = True
-     torch.backends.cudnn.deterministic = False
- else:
-     print("CUDA not available, using CPU")
-
- # Configuration
- INPUT_VIDEO = "test_video.mp4"
-
- base, ext = os.path.splitext(INPUT_VIDEO)
- OUTPUT_VIDEO = f"{base}_detr{ext}"
-
- THRESHOLD = 0.5
- BATCH_SIZE = 32
-
- # Auto-adjust batch size based on GPU memory
- if torch.cuda.is_available():
-     gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
-
- print(f"Using batch size: {BATCH_SIZE}")
-
- # Download weights
- weights_url = "https://huggingface.co/Subh775/Threat-Detection-RF-DETR/resolve/main/checkpoint_best_total.pth"
- weights_filename = "checkpoint_best_total.pth"
-
- if not os.path.exists(weights_filename):
-     print(f"Downloading weights from {weights_url}")
-     response = requests.get(weights_url, stream=True)
-     response.raise_for_status()
-     with open(weights_filename, 'wb') as f:
-         for chunk in response.iter_content(chunk_size=8192):
-             f.write(chunk)
-     print("Download complete.")
-
- print("Loading model...")
- model = RFDETRNano(resolution=640, pretrain_weights=weights_filename)
- model.optimize_for_inference()
-
- # Setup annotators
- color = sv.ColorPalette.from_hex([
-     "#1E90FF", "#32CD32", "#FF0000", "#FF8C00"
- ])
-
- bbox_annotator = sv.BoxAnnotator(color=color, thickness=3)
- label_annotator = sv.LabelAnnotator(
-     color=color,
-     text_color=sv.Color.BLACK,
-     text_scale=1.0,
-     text_thickness=2,
-     smart_position=True
- )
-
- def process_frame_batch(frames):
-     """Process a batch of frames for better GPU utilization"""
-     batch_results = []
-
-     # Convert all frames to PIL images
-     pil_images = []
-     for frame in frames:
-         rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         pil_image = Image.fromarray(rgb_frame)
-         pil_images.append(pil_image)
-
-     # Process each image in the batch (RF-DETR processes them efficiently)
-     batch_detections = []
-     for pil_image in pil_images:
-         detections = model.predict(pil_image, threshold=THRESHOLD)
-         batch_detections.append(detections)
-
-     # Annotate all images in the batch
-     annotated_frames = []
-     for pil_image, detections in zip(pil_images, batch_detections):
-         # Create labels
-         labels = []
-         for class_id, confidence in zip(detections.class_id, detections.confidence):
-             class_name = THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
-             labels.append(f"{class_name} {confidence:.2f}")
-
-         # Annotate
-         annotated_pil = pil_image.copy()
-         annotated_pil = bbox_annotator.annotate(annotated_pil, detections)
-         annotated_pil = label_annotator.annotate(annotated_pil, detections, labels)
-
-         # Convert back to BGR
-         annotated_frame = cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR)
-         annotated_frames.append(annotated_frame)
-
-     return annotated_frames, batch_detections
-
- # Open video
- cap = cv2.VideoCapture(INPUT_VIDEO)
- if not cap.isOpened():
-     print(f"Error: Could not open video file {INPUT_VIDEO}")
-     exit()
-
- # Get video properties
- fps = int(cap.get(cv2.CAP_PROP_FPS))
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
- print(f"Video: {width}x{height}, {fps} FPS, {total_frames} frames")
- print(f"Processing in batches of {BATCH_SIZE} frames")
-
- # Setup video writer
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))
-
- # Batch processing
- print("Processing video with batch inference...")
- frame_buffer = []
- total_detections = 0
- processed_frames = 0
- processing_times = []
-
- with tqdm(total=total_frames, desc="Batch processing") as pbar:
-     while True:
-         ret, frame = cap.read()
-         if not ret:
-             # Process remaining frames in buffer
-             if frame_buffer:
-                 start_time = time.time()
-                 annotated_frames, batch_detections = process_frame_batch(frame_buffer)
-                 processing_time = time.time() - start_time
-                 processing_times.append(processing_time)
-
-                 # Write remaining frames
-                 for annotated_frame, detections in zip(annotated_frames, batch_detections):
-                     out.write(annotated_frame)
-                     total_detections += len(detections)
-
-                 processed_frames += len(frame_buffer)
-                 pbar.update(len(frame_buffer))
-             break
-
-         # Add frame to buffer
-         frame_buffer.append(frame)
-
-         # Process when buffer is full
-         if len(frame_buffer) >= BATCH_SIZE:
-             start_time = time.time()
-
-             # Process batch
-             annotated_frames, batch_detections = process_frame_batch(frame_buffer)
-
-             processing_time = time.time() - start_time
-             processing_times.append(processing_time)
-
-             # Write frames
-             batch_threats = 0
-             for annotated_frame, detections in zip(annotated_frames, batch_detections):
-                 out.write(annotated_frame)
-                 batch_threats += len(detections)
-                 total_detections += len(detections)
-
-             processed_frames += len(frame_buffer)
-
-             # Update progress
-             batch_fps = len(frame_buffer) / processing_time if processing_time > 0 else 0
-             pbar.set_postfix({
-                 'Batch FPS': f"{batch_fps:.1f}",
-                 'Threats': batch_threats,
-                 'Total': total_detections
-             })
-             pbar.update(len(frame_buffer))
-
-             # Clear buffer
-             frame_buffer = []
-
-             # Clear GPU cache every 10 batches
-             if torch.cuda.is_available() and processed_frames % (BATCH_SIZE * 10) == 0:
-                 torch.cuda.empty_cache()
-
- # Cleanup
- cap.release()
- out.release()
-
- if torch.cuda.is_available():
-     torch.cuda.empty_cache()
-
- # Performance summary
- total_time = sum(processing_times)
- avg_fps = processed_frames / total_time if total_time > 0 else 0
- speedup = avg_fps / fps if fps > 0 else 0
-
- print(f"Output: {OUTPUT_VIDEO}")
- print(f"Stats:")
- print(f" • Processed: {processed_frames} frames")
- print(f" • Detections: {total_detections}")
- print(f" • Batch size: {BATCH_SIZE}")
- print(f" • Average speed: {avg_fps:.1f} FPS")
- print(f" • Speedup: {speedup:.1f}x real-time")
+ # pip install -q rfdetr==1.2.1 supervision==0.26.1
+
+ # RF-DETR video processing for threat detection.
+ # Inference time depends on frame resolution (e.g., ~50 ms/frame on GPU for 640×640).
+
+
+ import numpy as np
+ import supervision as sv
+ import torch
+ import requests
+ from PIL import Image
+ import os
+ import cv2
+ from tqdm import tqdm
+ import time
+
+ from rfdetr import RFDETRNano
+
+ THREAT_CLASSES = {
+     1: "Gun",
+     2: "Explosive",
+     3: "Grenade",
+     4: "Knife"
+ }
+
+ # Enable GPU if available
+ if torch.cuda.is_available():
+     print(f"GPU: {torch.cuda.get_device_name(0)}")
+     # print(f"CUDA Version: {torch.version.cuda}")
+     # print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
+
+     # Optimize for batch processing
+     torch.backends.cudnn.benchmark = True
+     torch.backends.cudnn.deterministic = False
+ else:
+     print("CUDA not available, using CPU")
+
+ # Configuration
+ INPUT_VIDEO = "test_video.mp4"
+
+ base, ext = os.path.splitext(INPUT_VIDEO)
+ OUTPUT_VIDEO = f"{base}_detr{ext}"
+
+ THRESHOLD = 0.5
+ BATCH_SIZE = 32
+
+ # Auto-adjust batch size based on GPU memory
+ if torch.cuda.is_available():
+     gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
+
+ print(f"Using batch size: {BATCH_SIZE}")
+
+ # Download weights
+ weights_url = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
+ weights_filename = "checkpoint_best_total.pth"
+
+ if not os.path.exists(weights_filename):
+     print(f"Downloading weights from {weights_url}")
+     response = requests.get(weights_url, stream=True)
+     response.raise_for_status()
+     with open(weights_filename, 'wb') as f:
+         for chunk in response.iter_content(chunk_size=8192):
+             f.write(chunk)
+     print("Download complete.")
+
+ print("Loading model...")
+ model = RFDETRNano(resolution=640, pretrain_weights=weights_filename)
+ model.optimize_for_inference()
+
+ # Setup annotators
+ color = sv.ColorPalette.from_hex([
+     "#1E90FF", "#32CD32", "#FF0000", "#FF8C00"
+ ])
+
+ bbox_annotator = sv.BoxAnnotator(color=color, thickness=3)
+ label_annotator = sv.LabelAnnotator(
+     color=color,
+     text_color=sv.Color.BLACK,
+     text_scale=1.0,
+     text_thickness=2,
+     smart_position=True
+ )
+
+ def process_frame_batch(frames):
+     """Process a batch of frames for better GPU utilization"""
+     batch_results = []
+
+     # Convert all frames to PIL images
+     pil_images = []
+     for frame in frames:
+         rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         pil_image = Image.fromarray(rgb_frame)
+         pil_images.append(pil_image)
+
+     # Process each image in the batch (RF-DETR processes them efficiently)
+     batch_detections = []
+     for pil_image in pil_images:
+         detections = model.predict(pil_image, threshold=THRESHOLD)
+         batch_detections.append(detections)
+
+     # Annotate all images in the batch
+     annotated_frames = []
+     for pil_image, detections in zip(pil_images, batch_detections):
+         # Create labels
+         labels = []
+         for class_id, confidence in zip(detections.class_id, detections.confidence):
+             class_name = THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
+             labels.append(f"{class_name} {confidence:.2f}")
+
+         # Annotate
+         annotated_pil = pil_image.copy()
+         annotated_pil = bbox_annotator.annotate(annotated_pil, detections)
+         annotated_pil = label_annotator.annotate(annotated_pil, detections, labels)
+
+         # Convert back to BGR
+         annotated_frame = cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR)
+         annotated_frames.append(annotated_frame)
+
+     return annotated_frames, batch_detections
+
+ # Open video
+ cap = cv2.VideoCapture(INPUT_VIDEO)
+ if not cap.isOpened():
+     print(f"Error: Could not open video file {INPUT_VIDEO}")
+     exit()
+
+ # Get video properties
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ print(f"Video: {width}x{height}, {fps} FPS, {total_frames} frames")
+ print(f"Processing in batches of {BATCH_SIZE} frames")
+
+ # Setup video writer
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))
+
+ # Batch processing
+ print("Processing video with batch inference...")
+ frame_buffer = []
+ total_detections = 0
+ processed_frames = 0
+ processing_times = []
+
+ with tqdm(total=total_frames, desc="Batch processing") as pbar:
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             # Process remaining frames in buffer
+             if frame_buffer:
+                 start_time = time.time()
+                 annotated_frames, batch_detections = process_frame_batch(frame_buffer)
+                 processing_time = time.time() - start_time
+                 processing_times.append(processing_time)
+
+                 # Write remaining frames
+                 for annotated_frame, detections in zip(annotated_frames, batch_detections):
+                     out.write(annotated_frame)
+                     total_detections += len(detections)
+
+                 processed_frames += len(frame_buffer)
+                 pbar.update(len(frame_buffer))
+             break
+
+         # Add frame to buffer
+         frame_buffer.append(frame)
+
+         # Process when buffer is full
+         if len(frame_buffer) >= BATCH_SIZE:
+             start_time = time.time()
+
+             # Process batch
+             annotated_frames, batch_detections = process_frame_batch(frame_buffer)
+
+             processing_time = time.time() - start_time
+             processing_times.append(processing_time)
+
+             # Write frames
+             batch_threats = 0
+             for annotated_frame, detections in zip(annotated_frames, batch_detections):
+                 out.write(annotated_frame)
+                 batch_threats += len(detections)
+                 total_detections += len(detections)
+
+             processed_frames += len(frame_buffer)
+
+             # Update progress
+             batch_fps = len(frame_buffer) / processing_time if processing_time > 0 else 0
+             pbar.set_postfix({
+                 'Batch FPS': f"{batch_fps:.1f}",
+                 'Threats': batch_threats,
+                 'Total': total_detections
+             })
+             pbar.update(len(frame_buffer))
+
+             # Clear buffer
+             frame_buffer = []
+
+             # Clear GPU cache every 10 batches
+             if torch.cuda.is_available() and processed_frames % (BATCH_SIZE * 10) == 0:
+                 torch.cuda.empty_cache()
+
+ # Cleanup
+ cap.release()
+ out.release()
+
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+
+ # Performance summary
+ total_time = sum(processing_times)
+ avg_fps = processed_frames / total_time if total_time > 0 else 0
+ speedup = avg_fps / fps if fps > 0 else 0
+
+ print(f"Output: {OUTPUT_VIDEO}")
+ print(f"Stats:")
+ print(f" • Processed: {processed_frames} frames")
+ print(f" • Detections: {total_detections}")
+ print(f" • Batch size: {BATCH_SIZE}")
+ print(f" • Average speed: {avg_fps:.1f} FPS")
+ print(f" • Speedup: {speedup:.1f}x real-time")
  print(f" • Processing time: {total_time:.1f}s")
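
The substantive change in this hunk is the checkpoint URL: the repository slug moves from Threat-Detection-RF-DETR to Threat-Detection-RFDETR. A quick way to confirm the updated link resolves before the script attempts the full download is a HEAD request. The snippet below is a minimal sketch, not part of the committed file; it only assumes the standard requests API and that the Hub redirects resolve URLs to a CDN.

import requests

# URL introduced by this commit (repo slug without the extra hyphen in "RFDETR")
weights_url = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"

# Follow the redirect; raise if the file is missing (404) or not accessible (401/403)
resp = requests.head(weights_url, allow_redirects=True, timeout=30)
resp.raise_for_status()

# Content-Length may be absent depending on the CDN response, so default to 0
size_mb = int(resp.headers.get("Content-Length", 0)) / 1024**2
print(f"Checkpoint reachable, status {resp.status_code}, ~{size_mb:.1f} MB")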