Spaces:
Running
on
Zero
Running
on
Zero
put predictor in session state
Browse files
app.py
CHANGED
@@ -72,22 +72,24 @@ OBJ_ID = 0
|
|
72 |
|
73 |
|
74 |
@spaces.GPU
|
75 |
-
def
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
if torch.cuda.get_device_properties(0).major >= 8:
|
84 |
-
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
85 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
86 |
-
torch.backends.cudnn.allow_tf32 = True
|
87 |
-
return predictor
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
-
predictor =
|
|
|
91 |
|
92 |
|
93 |
def get_video_fps(video_path):
|
@@ -106,6 +108,7 @@ def get_video_fps(video_path):
|
|
106 |
|
107 |
@spaces.GPU
|
108 |
def reset(session_state):
|
|
|
109 |
predictor.to("cuda")
|
110 |
session_state["input_points"] = []
|
111 |
session_state["input_labels"] = []
|
@@ -126,6 +129,7 @@ def reset(session_state):
|
|
126 |
|
127 |
@spaces.GPU
|
128 |
def clear_points(session_state):
|
|
|
129 |
predictor.to("cuda")
|
130 |
session_state["input_points"] = []
|
131 |
session_state["input_labels"] = []
|
@@ -141,6 +145,7 @@ def clear_points(session_state):
|
|
141 |
|
142 |
@spaces.GPU
|
143 |
def preprocess_video_in(video_path, session_state):
|
|
|
144 |
predictor.to("cuda")
|
145 |
if video_path is None:
|
146 |
return (
|
@@ -205,6 +210,7 @@ def segment_with_points(
|
|
205 |
session_state,
|
206 |
evt: gr.SelectData,
|
207 |
):
|
|
|
208 |
predictor.to("cuda")
|
209 |
session_state["input_points"].append(evt.index)
|
210 |
print(f"TRACKING INPUT POINT: {session_state['input_points']}")
|
@@ -279,6 +285,7 @@ def propagate_to_all(
|
|
279 |
video_in,
|
280 |
session_state,
|
281 |
):
|
|
|
282 |
predictor.to("cuda")
|
283 |
if (
|
284 |
len(session_state["input_points"]) == 0
|
|
|
72 |
|
73 |
|
74 |
@spaces.GPU
|
75 |
+
def get_predictor(session_state):
|
76 |
+
if "predictor" not in session_state:
|
77 |
+
sam2_checkpoint = "checkpoints/edgetam.pt"
|
78 |
+
model_cfg = "edgetam.yaml"
|
79 |
+
predictor = build_sam2_video_predictor(
|
80 |
+
model_cfg, sam2_checkpoint, device="cuda"
|
81 |
+
)
|
82 |
+
print("predictor loaded")
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
+
# use bfloat16 for the entire demo
|
85 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
86 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
87 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
88 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
89 |
+
torch.backends.cudnn.allow_tf32 = True
|
90 |
|
91 |
+
session_state["predictor"] = predictor
|
92 |
+
return session_state["predictor"]
|
93 |
|
94 |
|
95 |
def get_video_fps(video_path):
|
|
|
108 |
|
109 |
@spaces.GPU
|
110 |
def reset(session_state):
|
111 |
+
predictor = get_predictor(session_state)
|
112 |
predictor.to("cuda")
|
113 |
session_state["input_points"] = []
|
114 |
session_state["input_labels"] = []
|
|
|
129 |
|
130 |
@spaces.GPU
|
131 |
def clear_points(session_state):
|
132 |
+
predictor = get_predictor(session_state)
|
133 |
predictor.to("cuda")
|
134 |
session_state["input_points"] = []
|
135 |
session_state["input_labels"] = []
|
|
|
145 |
|
146 |
@spaces.GPU
|
147 |
def preprocess_video_in(video_path, session_state):
|
148 |
+
predictor = get_predictor(session_state)
|
149 |
predictor.to("cuda")
|
150 |
if video_path is None:
|
151 |
return (
|
|
|
210 |
session_state,
|
211 |
evt: gr.SelectData,
|
212 |
):
|
213 |
+
predictor = get_predictor(session_state)
|
214 |
predictor.to("cuda")
|
215 |
session_state["input_points"].append(evt.index)
|
216 |
print(f"TRACKING INPUT POINT: {session_state['input_points']}")
|
|
|
285 |
video_in,
|
286 |
session_state,
|
287 |
):
|
288 |
+
predictor = get_predictor(session_state)
|
289 |
predictor.to("cuda")
|
290 |
if (
|
291 |
len(session_state["input_points"]) == 0
|