chongzhou committed
Commit 6e60611 · 1 Parent(s): 9bc4638
Files changed (2):
  1. .gitignore +2 -0
  2. app.py +10 -30
.gitignore ADDED
@@ -0,0 +1,2 @@
+*.egg-info/
+__pycache__/
app.py CHANGED
@@ -71,25 +71,17 @@ examples = [
 OBJ_ID = 0
 
 
-@spaces.GPU
-def get_predictor(session_state):
-    if "predictor" not in session_state:
-        sam2_checkpoint = "checkpoints/edgetam.pt"
-        model_cfg = "edgetam.yaml"
-        predictor = build_sam2_video_predictor(
-            model_cfg, sam2_checkpoint, device="cuda"
-        )
-        print("predictor loaded")
-
-        # use bfloat16 for the entire demo
-        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-        if torch.cuda.get_device_properties(0).major >= 8:
-            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-            torch.backends.cuda.matmul.allow_tf32 = True
-            torch.backends.cudnn.allow_tf32 = True
+sam2_checkpoint = "checkpoints/edgetam.pt"
+model_cfg = "edgetam.yaml"
+predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
+print("predictor loaded")
 
-        session_state["predictor"] = predictor
-    return session_state["predictor"]
+# use bfloat16 for the entire demo
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 
 
 def get_video_fps(video_path):
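
Note: this hunk replaces the lazy, per-session `get_predictor` helper with eager module-level initialization, so the EdgeTAM predictor is built once at import time and shared by every Gradio session. A minimal sketch of the resulting global setup; the `get_device` fallback guard is an assumption added here for CPU-only machines, not part of this commit:

    import torch

    # Hypothetical helper (not in this commit): pick the device once at
    # import time so module-level init also works without CUDA.
    def get_device() -> str:
        return "cuda" if torch.cuda.is_available() else "cpu"

    device = get_device()
    if device == "cuda":
        # Enter a process-wide autocast context: subsequent CUDA ops run in bfloat16.
        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
        if torch.cuda.get_device_properties(0).major >= 8:
            # Compute capability 8.x (Ampere) and newer support TF32 matmuls.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True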
@@ -106,10 +98,7 @@ def get_video_fps(video_path):
     return fps
 
 
-@spaces.GPU
 def reset(session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"] is not None:
@@ -127,10 +116,7 @@ def reset(session_state):
     )
 
 
-@spaces.GPU
 def clear_points(session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"]["tracking_has_started"]:
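
Note: `reset` and `clear_points` lose their `@spaces.GPU` decorators because they now only mutate Python-side session state and never touch the model. On ZeroGPU Spaces the decorator attaches a GPU only for the duration of the wrapped call, so it belongs only on handlers that actually run inference. A sketch of the pattern; `run_inference` is a hypothetical example, not a function in app.py:

    import spaces  # Hugging Face ZeroGPU helper, available on Spaces

    @spaces.GPU  # a GPU is attached only while this call runs
    def run_inference(frame):
        # GPU-bound work goes here; pure session-state bookkeeping
        # (like reset/clear_points above) stays undecorated.
        ...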
@@ -145,8 +131,6 @@ def clear_points(session_state):
 
 @spaces.GPU
 def preprocess_video_in(video_path, session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
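
Note: `preprocess_video_in` keeps its `@spaces.GPU` decorator since it is the handler that actually touches the model, building the per-video tracking state. Roughly the step this hunk leaves in place, assuming the SAM 2 video-predictor API that EdgeTAM inherits; the exact call in app.py may differ:

    # Per session, inside preprocess_video_in (sketch):
    session_state["inference_state"] = predictor.init_state(video_path=video_path)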
@@ -210,8 +194,6 @@ def segment_with_points(
     session_state,
     evt: gr.SelectData,
 ):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     session_state["input_points"].append(evt.index)
     print(f"TRACKING INPUT POINT: {session_state['input_points']}")
 
@@ -285,8 +267,6 @@ def propagate_to_all(
     video_in,
     session_state,
 ):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
     if (
         len(session_state["input_points"]) == 0
         or video_in is None
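
Note: with `get_predictor` gone, each touched handler now resolves `predictor` as a module-level global shared by all sessions, while per-user data stays in `session_state`. Roughly the shape the handlers take after this commit (illustrative sketch, not verbatim app.py):

    def propagate_to_all(video_in, session_state):
        # `predictor` is the single module-level instance built at import
        # time; only `session_state` is per user.
        if len(session_state["input_points"]) == 0 or video_in is None:
            return
        # ... run mask propagation with the shared predictor ...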
 