StephanST commited on
Commit
e9160b2
·
verified ·
1 Parent(s): 902b1c2

Upload 2 files

Browse files

boilerplate inference code

run_sliced_inference.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import sys
3
+ from sahi.models.yolov8 import Yolov8DetectionModel
4
+ from sahi.predict import get_sliced_prediction
5
+ import supervision as sv
6
+ import numpy as np
7
+
8
+ # Check the number of command-line arguments
9
+ if len(sys.argv) != 8:
10
+ print("Usage: python yolov8_video_inference.py <model_path> <input_video_path> <output_video_path> <slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>")
11
+ sys.exit(1)
12
+
13
+ # Get command-line arguments
14
+ model_path = sys.argv[1]
15
+ input_video_path = sys.argv[2]
16
+ output_video_path = sys.argv[3]
17
+ slice_height = int(sys.argv[4])
18
+ slice_width = int(sys.argv[5])
19
+ overlap_height_ratio = float(sys.argv[6])
20
+ overlap_width_ratio = float(sys.argv[7])
21
+
22
+ # Load YOLOv8 model with SAHI
23
+ detection_model = Yolov8DetectionModel(
24
+ model_path=model_path,
25
+ confidence_threshold=0.1,
26
+ device="cuda" # or "cpu"
27
+ )
28
+
29
+ # Open input video
30
+ cap = cv2.VideoCapture(input_video_path)
31
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
32
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
33
+ fps = cap.get(cv2.CAP_PROP_FPS)
34
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
35
+
36
+ # Set up output video writer
37
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
38
+
39
+ # Create bounding box and label annotators
40
+ #box_annotator = sv.BoundingBoxAnnotator(thickness=1)
41
+ box_annotator = sv.BoxCornerAnnotator(thickness=2)
42
+ label_annotator = sv.LabelAnnotator(text_scale=0.5, text_thickness=2)
43
+
44
+ # Process each frame
45
+ frame_count = 0
46
+ while cap.isOpened():
47
+ ret, frame = cap.read()
48
+ if not ret:
49
+ break
50
+
51
+ # Perform sliced inference on the current frame using SAHI
52
+ result = get_sliced_prediction(
53
+ image=frame,
54
+ detection_model=detection_model,
55
+ slice_height=slice_height,
56
+ slice_width=slice_width,
57
+ overlap_height_ratio=overlap_height_ratio,
58
+ overlap_width_ratio=overlap_width_ratio
59
+ )
60
+
61
+ # Extract data from SAHI result
62
+ object_predictions = result.object_prediction_list
63
+
64
+ # Initialize lists to hold the data
65
+ xyxy = []
66
+ confidences = []
67
+ class_ids = []
68
+ class_names = []
69
+
70
+ # Loop over the object predictions and extract data
71
+ for pred in object_predictions:
72
+ bbox = pred.bbox.to_xyxy() # Convert bbox to [x1, y1, x2, y2]
73
+ xyxy.append(bbox)
74
+ confidences.append(pred.score.value)
75
+ class_ids.append(pred.category.id)
76
+ class_names.append(pred.category.name)
77
+
78
+ # Check if there are any detections
79
+ if xyxy:
80
+ # Convert lists to numpy arrays
81
+ xyxy = np.array(xyxy, dtype=np.float32)
82
+ confidences = np.array(confidences, dtype=np.float32)
83
+ class_ids = np.array(class_ids, dtype=int)
84
+
85
+ # Create sv.Detections object
86
+ detections = sv.Detections(
87
+ xyxy=xyxy,
88
+ confidence=confidences,
89
+ class_id=class_ids
90
+ )
91
+
92
+ # Prepare labels for label annotator
93
+ labels = [
94
+ f"{class_name} {confidence:.2f}"
95
+ for class_name, confidence in zip(class_names, confidences)
96
+ ]
97
+
98
+ # Annotate frame with detection results
99
+ annotated_frame = frame.copy()
100
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
101
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
102
+ else:
103
+ # If no detections, use the original frame
104
+ annotated_frame = frame.copy()
105
+
106
+ # Write the annotated frame to the output video
107
+ out.write(annotated_frame)
108
+
109
+ frame_count += 1
110
+ print(f"Processed frame {frame_count}", end='\r')
111
+
112
+ # Release resources
113
+ cap.release()
114
+ out.release()
115
+ print("\nInference complete. Video saved at", output_video_path)
116
+
run_sliced_inference_with_tracker.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import sys
3
+ from sahi.models.yolov8 import Yolov8DetectionModel
4
+ from sahi.predict import get_sliced_prediction
5
+ import supervision as sv
6
+ import numpy as np
7
+
8
+ # Check the number of command-line arguments
9
+ if len(sys.argv) != 8:
10
+ print("Usage: python yolov8_video_inference.py <model_path> <input_video_path> <output_video_path> <slice_height> <slice_width> <overlap_height_ratio> <overlap_width_ratio>")
11
+ sys.exit(1)
12
+
13
+ # Get command-line arguments
14
+ model_path = sys.argv[1]
15
+ input_video_path = sys.argv[2]
16
+ output_video_path = sys.argv[3]
17
+ slice_height = int(sys.argv[4])
18
+ slice_width = int(sys.argv[5])
19
+ overlap_height_ratio = float(sys.argv[6])
20
+ overlap_width_ratio = float(sys.argv[7])
21
+
22
+ # Load YOLOv8 model with SAHI
23
+ detection_model = Yolov8DetectionModel(
24
+ model_path=model_path,
25
+ confidence_threshold=0.25,
26
+ device="cuda" # or "cpu"
27
+ )
28
+
29
+ # Get video info
30
+ video_info = sv.VideoInfo.from_video_path(video_path=input_video_path)
31
+
32
+ # Open input video
33
+ cap = cv2.VideoCapture(input_video_path)
34
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
+ fps = cap.get(cv2.CAP_PROP_FPS)
37
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
38
+
39
+ # Set up output video writer
40
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
41
+
42
+ # Initialize tracker and smoother
43
+ tracker = sv.ByteTrack(frame_rate=video_info.fps)
44
+ smoother = sv.DetectionsSmoother()
45
+
46
+ # Create bounding box and label annotators
47
+ box_annotator = sv.BoxCornerAnnotator(thickness=2)
48
+ label_annotator = sv.LabelAnnotator(
49
+ text_scale=0.5,
50
+ text_thickness=1,
51
+ text_padding=1
52
+ )
53
+
54
+ # Process each frame
55
+ frame_count = 0
56
+ class_id_to_name = {} # Initialize once to store class_id to name mapping
57
+
58
+ while cap.isOpened():
59
+ ret, frame = cap.read()
60
+ if not ret:
61
+ break
62
+
63
+ # Perform sliced inference on the current frame using SAHI
64
+ result = get_sliced_prediction(
65
+ image=frame,
66
+ detection_model=detection_model,
67
+ slice_height=slice_height,
68
+ slice_width=slice_width,
69
+ overlap_height_ratio=overlap_height_ratio,
70
+ overlap_width_ratio=overlap_width_ratio
71
+ )
72
+
73
+ # Extract data from SAHI result
74
+ object_predictions = result.object_prediction_list
75
+
76
+ # Initialize lists to hold the data
77
+ xyxy = []
78
+ confidences = []
79
+ class_ids = []
80
+ # Build or update class_id to name mapping
81
+ for pred in object_predictions:
82
+ if pred.category.id not in class_id_to_name:
83
+ class_id_to_name[pred.category.id] = pred.category.name
84
+
85
+ # Loop over the object predictions and extract data
86
+ for pred in object_predictions:
87
+ bbox = pred.bbox.to_xyxy() # Convert bbox to [x1, y1, x2, y2]
88
+ xyxy.append(bbox)
89
+ confidences.append(pred.score.value)
90
+ class_ids.append(pred.category.id)
91
+
92
+ # Check if there are any detections
93
+ if xyxy:
94
+ # Convert lists to numpy arrays
95
+ xyxy = np.array(xyxy, dtype=np.float32)
96
+ confidences = np.array(confidences, dtype=np.float32)
97
+ class_ids = np.array(class_ids, dtype=int)
98
+
99
+ # Create sv.Detections object
100
+ detections = sv.Detections(
101
+ xyxy=xyxy,
102
+ confidence=confidences,
103
+ class_id=class_ids
104
+ )
105
+
106
+ # Update tracker with detections
107
+ detections = tracker.update_with_detections(detections)
108
+
109
+ # Update smoother with detections
110
+ detections = smoother.update_with_detections(detections)
111
+
112
+ # Prepare labels for label annotator
113
+ # Include tracker ID in labels if available
114
+ labels = []
115
+ for i in range(len(detections.xyxy)):
116
+ class_id = detections.class_id[i]
117
+ confidence = detections.confidence[i]
118
+ class_name = class_id_to_name.get(class_id, 'Unknown')
119
+ label = f"{class_name} {confidence:.2f}"
120
+
121
+ # Add tracker ID if available
122
+ if hasattr(detections, 'tracker_id') and detections.tracker_id is not None:
123
+ tracker_id = detections.tracker_id[i]
124
+ label = f"ID {tracker_id} {label}"
125
+
126
+ labels.append(label)
127
+
128
+ # Annotate frame with detection results
129
+ annotated_frame = frame.copy()
130
+ annotated_frame = box_annotator.annotate(
131
+ scene=annotated_frame,
132
+ detections=detections
133
+ )
134
+ annotated_frame = label_annotator.annotate(
135
+ scene=annotated_frame,
136
+ detections=detections,
137
+ labels=labels
138
+ )
139
+ else:
140
+ # If no detections, use the original frame
141
+ annotated_frame = frame.copy()
142
+
143
+ # Write the annotated frame to the output video
144
+ out.write(annotated_frame)
145
+
146
+ frame_count += 1
147
+ print(f"Processed frame {frame_count}", end='\r')
148
+
149
+ # Release resources
150
+ cap.release()
151
+ out.release()
152
+ print("\nInference complete. Video saved at", output_video_path)