chongzhou committed on
Commit
dd43162
·
1 Parent(s): 18ff31a

put predictor in session state

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -72,22 +72,24 @@ OBJ_ID = 0
72
 
73
 
74
  @spaces.GPU
75
- def build_predictor():
76
- sam2_checkpoint = "checkpoints/edgetam.pt"
77
- model_cfg = "edgetam.yaml"
78
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
79
- print("predictor loaded")
80
-
81
- # use bfloat16 for the entire notebook
82
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
83
- if torch.cuda.get_device_properties(0).major >= 8:
84
- # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
85
- torch.backends.cuda.matmul.allow_tf32 = True
86
- torch.backends.cudnn.allow_tf32 = True
87
- return predictor
88
 
 
 
 
 
 
 
89
 
90
- predictor = build_predictor()
 
91
 
92
 
93
  def get_video_fps(video_path):
@@ -106,6 +108,7 @@ def get_video_fps(video_path):
106
 
107
  @spaces.GPU
108
  def reset(session_state):
 
109
  predictor.to("cuda")
110
  session_state["input_points"] = []
111
  session_state["input_labels"] = []
@@ -126,6 +129,7 @@ def reset(session_state):
126
 
127
  @spaces.GPU
128
  def clear_points(session_state):
 
129
  predictor.to("cuda")
130
  session_state["input_points"] = []
131
  session_state["input_labels"] = []
@@ -141,6 +145,7 @@ def clear_points(session_state):
141
 
142
  @spaces.GPU
143
  def preprocess_video_in(video_path, session_state):
 
144
  predictor.to("cuda")
145
  if video_path is None:
146
  return (
@@ -205,6 +210,7 @@ def segment_with_points(
205
  session_state,
206
  evt: gr.SelectData,
207
  ):
 
208
  predictor.to("cuda")
209
  session_state["input_points"].append(evt.index)
210
  print(f"TRACKING INPUT POINT: {session_state['input_points']}")
@@ -279,6 +285,7 @@ def propagate_to_all(
279
  video_in,
280
  session_state,
281
  ):
 
282
  predictor.to("cuda")
283
  if (
284
  len(session_state["input_points"]) == 0
 
72
 
73
 
74
  @spaces.GPU
75
+ def get_predictor(session_state):
76
+ if "predictor" not in session_state:
77
+ sam2_checkpoint = "checkpoints/edgetam.pt"
78
+ model_cfg = "edgetam.yaml"
79
+ predictor = build_sam2_video_predictor(
80
+ model_cfg, sam2_checkpoint, device="cuda"
81
+ )
82
+ print("predictor loaded")
 
 
 
 
 
83
 
84
+ # use bfloat16 for the entire demo
85
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
86
+ if torch.cuda.get_device_properties(0).major >= 8:
87
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
88
+ torch.backends.cuda.matmul.allow_tf32 = True
89
+ torch.backends.cudnn.allow_tf32 = True
90
 
91
+ session_state["predictor"] = predictor
92
+ return session_state["predictor"]
93
 
94
 
95
  def get_video_fps(video_path):
 
108
 
109
  @spaces.GPU
110
  def reset(session_state):
111
+ predictor = get_predictor(session_state)
112
  predictor.to("cuda")
113
  session_state["input_points"] = []
114
  session_state["input_labels"] = []
 
129
 
130
  @spaces.GPU
131
  def clear_points(session_state):
132
+ predictor = get_predictor(session_state)
133
  predictor.to("cuda")
134
  session_state["input_points"] = []
135
  session_state["input_labels"] = []
 
145
 
146
  @spaces.GPU
147
  def preprocess_video_in(video_path, session_state):
148
+ predictor = get_predictor(session_state)
149
  predictor.to("cuda")
150
  if video_path is None:
151
  return (
 
210
  session_state,
211
  evt: gr.SelectData,
212
  ):
213
+ predictor = get_predictor(session_state)
214
  predictor.to("cuda")
215
  session_state["input_points"].append(evt.index)
216
  print(f"TRACKING INPUT POINT: {session_state['input_points']}")
 
285
  video_in,
286
  session_state,
287
  ):
288
+ predictor = get_predictor(session_state)
289
  predictor.to("cuda")
290
  if (
291
  len(session_state["input_points"]) == 0