Update visualizer.py
Browse files- visualizer.py +15 -13
visualizer.py
CHANGED
@@ -104,7 +104,7 @@ class Visualizer:
|
|
104 |
Handles drawing tracked points and their trajectories on video frames.
|
105 |
|
106 |
Args:
|
107 |
-
|
108 |
padding: Padding to add around video frames in pixels
|
109 |
fps: Frames per second for output video
|
110 |
colormap: Color scheme for tracks ('rainbow' or 'spring')
|
@@ -115,20 +115,21 @@ class Visualizer:
|
|
115 |
|
116 |
def __init__(
|
117 |
self,
|
118 |
-
|
119 |
-
|
|
|
120 |
fps: int = 10,
|
121 |
colormap: str = "rainbow",
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
):
|
126 |
-
self.
|
127 |
-
self.padding =
|
128 |
self.fps = fps
|
129 |
-
self.line_width =
|
130 |
-
self.initial_frame_repeat =
|
131 |
-
self.track_history_length =
|
132 |
|
133 |
# Set up colormap for track visualization
|
134 |
self.colormap = colormap
|
@@ -143,6 +144,7 @@ class Visualizer:
|
|
143 |
self,
|
144 |
video: torch.Tensor, # (B,T,C,H,W)
|
145 |
tracks: torch.Tensor, # (B,T,N,2)
|
|
|
146 |
visibility: torch.Tensor = None, # (B,T,N,1) bool
|
147 |
segmentation: torch.Tensor = None, # (B,1,H,W)
|
148 |
filename: str = "video",
|
@@ -209,7 +211,7 @@ class Visualizer:
|
|
209 |
video: Video tensor of shape (B,T,C,H,W)
|
210 |
filename: Output filename without extension
|
211 |
"""
|
212 |
-
os.makedirs(self.
|
213 |
|
214 |
# Extract frames from video tensor
|
215 |
frames = [
|
@@ -217,7 +219,7 @@ class Visualizer:
|
|
217 |
for frame in video.unbind(1)
|
218 |
]
|
219 |
|
220 |
-
output_path = os.path.join(self.
|
221 |
|
222 |
try:
|
223 |
with imageio.get_writer(output_path, fps=self.fps, quality=8) as writer:
|
|
|
104 |
Handles drawing tracked points and their trajectories on video frames.
|
105 |
|
106 |
Args:
|
107 |
+
save_dir: Directory to save output visualizations
|
108 |
padding: Padding to add around video frames in pixels
|
109 |
fps: Frames per second for output video
|
110 |
colormap: Color scheme for tracks ('rainbow' or 'spring')
|
|
|
115 |
|
116 |
def __init__(
|
117 |
self,
|
118 |
+
save_dir: str = "./results",
|
119 |
+
grayscale: bool = False,
|
120 |
+
pad_value: int = 0,
|
121 |
fps: int = 10,
|
122 |
colormap: str = "rainbow",
|
123 |
+
linewidth: int = 2,
|
124 |
+
show_first_frame: int = 10,
|
125 |
+
tracks_leave_trace: int = 0,
|
126 |
):
|
127 |
+
self.save_dir = save_dir
|
128 |
+
self.padding = pad_value
|
129 |
self.fps = fps
|
130 |
+
self.line_width = linewidth
|
131 |
+
self.initial_frame_repeat = show_first_frame
|
132 |
+
self.track_history_length = tracks_leave_trace
|
133 |
|
134 |
# Set up colormap for track visualization
|
135 |
self.colormap = colormap
|
|
|
144 |
self,
|
145 |
video: torch.Tensor, # (B,T,C,H,W)
|
146 |
tracks: torch.Tensor, # (B,T,N,2)
|
147 |
+
|
148 |
visibility: torch.Tensor = None, # (B,T,N,1) bool
|
149 |
segmentation: torch.Tensor = None, # (B,1,H,W)
|
150 |
filename: str = "video",
|
|
|
211 |
video: Video tensor of shape (B,T,C,H,W)
|
212 |
filename: Output filename without extension
|
213 |
"""
|
214 |
+
os.makedirs(self.save_dir, exist_ok=True)
|
215 |
|
216 |
# Extract frames from video tensor
|
217 |
frames = [
|
|
|
219 |
for frame in video.unbind(1)
|
220 |
]
|
221 |
|
222 |
+
output_path = os.path.join(self.save_dir, f"{filename}.mp4")
|
223 |
|
224 |
try:
|
225 |
with imageio.get_writer(output_path, fps=self.fps, quality=8) as writer:
|