Spaces: Running on Zero

remove CUDA init in main process

app.py CHANGED
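On ZeroGPU Spaces, the main (import-time) process must not initialize CUDA: a GPU is only attached while a function decorated with `@spaces.GPU` runs. Building the predictor with `device="cuda"` at module level therefore breaks startup, which is what this commit fixes: the model build moves into a GPU-scoped function, and each GPU handler re-pins the predictor to `"cuda"` before using it. A minimal sketch of that pattern (`build_model` and `run` are placeholder names, not code from this Space):

```python
import spaces
import torch


@spaces.GPU  # CUDA is only initialized inside GPU-scoped calls
def build_model():
    # Stand-in for the real checkpoint load (build_sam2_video_predictor below).
    return torch.nn.Linear(8, 8).to("cuda")


model = build_model()  # mirrors `predictor = build_predictor()` in the diff below


@spaces.GPU
def run(x):
    model.to("cuda")  # re-pin weights; each call may get a freshly attached device
    return model(x.to("cuda")).cpu()
```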
```diff
@@ -69,21 +69,27 @@ examples = [
 ]
 
 OBJ_ID = 0
-sam2_checkpoint = "checkpoints/edgetam.pt"
-model_cfg = "edgetam.yaml"
-predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
-print("PREDICTOR LOADED")
-predictor.to("cuda")
-
-# use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
 
 
 @spaces.GPU
+def build_predictor():
+    sam2_checkpoint = "checkpoints/edgetam.pt"
+    model_cfg = "edgetam.yaml"
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
+    print("predictor loaded")
+
+    # use bfloat16 for the entire notebook
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    if torch.cuda.get_device_properties(0).major >= 8:
+        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    return predictor
+
+
+predictor = build_predictor()
+
+
 def get_video_fps(video_path):
     # Open the video file
     cap = cv2.VideoCapture(video_path)
```
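The `@spaces.GPU` line above is unchanged context: it used to decorate `get_video_fps` and now decorates the new `build_predictor`, so `get_video_fps`, which only uses OpenCV, no longer requests a GPU.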
```diff
@@ -100,6 +106,7 @@ def get_video_fps(video_path):
 
 @spaces.GPU
 def reset(session_state):
+    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"] is not None:
```
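The same one-line addition repeats in the handlers below: every `@spaces.GPU` entry point now begins with `predictor.to("cuda")`, presumably because ZeroGPU attaches a device per call, so a handler cannot assume the weights are already resident on the GPU.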
```diff
@@ -119,6 +126,7 @@ def reset(session_state):
 
 @spaces.GPU
 def clear_points(session_state):
+    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"]["tracking_has_started"]:
```
```diff
@@ -133,6 +141,7 @@ def clear_points(session_state):
 
 @spaces.GPU
 def preprocess_video_in(video_path, session_state):
+    predictor.to("cuda")
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
```
```diff
@@ -196,6 +205,7 @@ def segment_with_points(
     session_state,
     evt: gr.SelectData,
 ):
+    predictor.to("cuda")
     session_state["input_points"].append(evt.index)
     print(f"TRACKING INPUT POINT: {session_state['input_points']}")
 
```
```diff
@@ -249,7 +259,6 @@ def segment_with_points(
     return selected_point_map, first_frame_output, session_state
 
 
-@spaces.GPU
 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
```
```diff
@@ -270,6 +279,7 @@ def propagate_to_all(
     video_in,
     session_state,
 ):
+    predictor.to("cuda")
     if (
         len(session_state["input_points"]) == 0
         or video_in is None
```
```diff
@@ -325,7 +335,6 @@ def propagate_to_all(
     )
 
 
-@spaces.GPU
 def update_ui():
     return gr.update(visible=True)
 
```
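The two decorator removals, on `show_mask` and `update_ui`, fit the same cleanup: both are CPU-only helpers (NumPy mask compositing and a Gradio visibility toggle), so scheduling GPU time for them would be wasted.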
```diff
@@ -478,5 +487,6 @@ with gr.Blocks() as demo:
         queue=False,
     )
 
+
 demo.queue()
 demo.launch()
```