rollback
- .gitignore +2 -0
- app.py +10 -30
.gitignore ADDED
@@ -0,0 +1,2 @@
+*.egg-info/
+__pycache__/
app.py CHANGED
@@ -71,25 +71,17 @@ examples = [
 OBJ_ID = 0


-
-
-
-
-model_cfg = "edgetam.yaml"
-predictor = build_sam2_video_predictor(
-    model_cfg, sam2_checkpoint, device="cuda"
-)
-print("predictor loaded")
-
-# use bfloat16 for the entire demo
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
+sam2_checkpoint = "checkpoints/edgetam.pt"
+model_cfg = "edgetam.yaml"
+predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
+print("predictor loaded")

-
-
+# use bfloat16 for the entire demo
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True


 def get_video_fps(video_path):
@@ -106,10 +98,7 @@ def get_video_fps(video_path):
     return fps


-@spaces.GPU
 def reset(session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"] is not None:
@@ -127,10 +116,7 @@ def reset(session_state):
     )


-@spaces.GPU
 def clear_points(session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"]["tracking_has_started"]:
@@ -145,8 +131,6 @@ def clear_points(session_state):

 @spaces.GPU
 def preprocess_video_in(video_path, session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
@@ -210,8 +194,6 @@ def segment_with_points(
     session_state,
     evt: gr.SelectData,
 ):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     session_state["input_points"].append(evt.index)
     print(f"TRACKING INPUT POINT: {session_state['input_points']}")

@@ -285,8 +267,6 @@ def propagate_to_all(
     video_in,
     session_state,
 ):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     if (
         len(session_state["input_points"]) == 0
         or video_in is None