chongzhou committed
Commit 18ff31a · Parent(s): 9096363

remove CUDA init in main process
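
Context: on Hugging Face ZeroGPU Spaces, the main process has no GPU attached, and initializing CUDA there raises an error; CUDA may only be touched inside functions decorated with @spaces.GPU, which run in a GPU worker. A minimal sketch of the pattern this commit moves to — build_predictor and handler below are illustrative stand-ins (torch.nn.Linear in place of the real build_sam2_video_predictor), not the app's actual code:

    import spaces  # Hugging Face ZeroGPU helper package
    import torch


    @spaces.GPU
    def build_predictor():
        # CUDA is first touched here, inside the GPU worker, never at
        # import time in the CPU-only main process.
        predictor = torch.nn.Linear(8, 8)  # stand-in for the real model builder
        return predictor.to("cuda")


    predictor = build_predictor()


    @spaces.GPU
    def handler(x):
        # Each GPU entry point re-pins the model to the device, mirroring
        # the predictor.to("cuda") calls added throughout this commit.
        predictor.to("cuda")
        return predictor(x.to("cuda"))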

Files changed (1): app.py (+24 −14)
app.py CHANGED
@@ -69,21 +69,27 @@ examples = [
 ]
 
 OBJ_ID = 0
-sam2_checkpoint = "checkpoints/edgetam.pt"
-model_cfg = "edgetam.yaml"
-predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
-print("PREDICTOR LOADED")
-predictor.to("cuda")
-
-# use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
 
 
 @spaces.GPU
+def build_predictor():
+    sam2_checkpoint = "checkpoints/edgetam.pt"
+    model_cfg = "edgetam.yaml"
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
+    print("predictor loaded")
+
+    # use bfloat16 for the entire notebook
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    if torch.cuda.get_device_properties(0).major >= 8:
+        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    return predictor
+
+
+predictor = build_predictor()
+
+
 def get_video_fps(video_path):
     # Open the video file
     cap = cv2.VideoCapture(video_path)
@@ -100,6 +106,7 @@ def get_video_fps(video_path):
 
 @spaces.GPU
 def reset(session_state):
+    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"] is not None:
@@ -119,6 +126,7 @@ def reset(session_state):
 
 @spaces.GPU
 def clear_points(session_state):
+    predictor.to("cuda")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"]["tracking_has_started"]:
@@ -133,6 +141,7 @@ def clear_points(session_state):
 
 @spaces.GPU
 def preprocess_video_in(video_path, session_state):
+    predictor.to("cuda")
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
@@ -196,6 +205,7 @@ def segment_with_points(
     session_state,
     evt: gr.SelectData,
 ):
+    predictor.to("cuda")
    session_state["input_points"].append(evt.index)
     print(f"TRACKING INPUT POINT: {session_state['input_points']}")
 
@@ -249,7 +259,6 @@ def segment_with_points(
     return selected_point_map, first_frame_output, session_state
 
 
-@spaces.GPU
 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -270,6 +279,7 @@ def propagate_to_all(
     video_in,
     session_state,
 ):
+    predictor.to("cuda")
     if (
         len(session_state["input_points"]) == 0
         or video_in is None
@@ -325,7 +335,6 @@ def propagate_to_all(
     )
 
 
-@spaces.GPU
 def update_ui():
     return gr.update(visible=True)
 
@@ -478,5 +487,6 @@ with gr.Blocks() as demo:
        queue=False,
     )
 
+
 demo.queue()
 demo.launch()
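
Net effect: the module-level CUDA setup (model build, bfloat16 autocast, TF32 flags) moves into build_predictor(), which runs under @spaces.GPU; each handler that actually uses the model (reset, clear_points, preprocess_video_in, segment_with_points, propagate_to_all) gains a predictor.to("cuda") so the weights are on the device inside the GPU worker; and show_mask and update_ui, which do no CUDA work, drop the decorator.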
 