LPX55 commited on
Commit
79a0d4b
·
verified ·
1 Parent(s): 6c8ab7b

Update sam2_mask.py

Browse files
Files changed (1) hide show
  1. sam2_mask.py +3 -3
sam2_mask.py CHANGED
@@ -111,6 +111,7 @@ def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_s
111
  mask = feather_mask(mask, feather_size)
112
  return mask
113
 
 
114
  def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
115
  image = Image.open(input_image)
116
  image = np.array(image.convert("RGB"))
@@ -123,7 +124,7 @@ def sam_process(input_image, checkpoint, tracking_points, trackings_input_label,
123
  # sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
124
  # Use CPU for both model and computations
125
  # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
126
- predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="cpu")
127
 
128
  # predictor = SAM2ImagePredictor(sam2_model)
129
  predictor.set_image(image)
@@ -152,8 +153,7 @@ with gr.Blocks() as demo:
152
  tracking_points = gr.State([])
153
  trackings_input_label = gr.State([])
154
  with gr.Column():
155
- gr.Markdown("# SAM2 Image Predictor (CPU Version)")
156
- gr.Markdown("This version runs entirely on CPU")
157
  with gr.Row():
158
  with gr.Column():
159
  input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
 
111
  mask = feather_mask(mask, feather_size)
112
  return mask
113
 
114
+ @spaces.GPU()
115
  def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
116
  image = Image.open(input_image)
117
  image = np.array(image.convert("RGB"))
 
124
  # sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
125
  # Use CPU for both model and computations
126
  # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
127
+ predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="cuda")
128
 
129
  # predictor = SAM2ImagePredictor(sam2_model)
130
  predictor.set_image(image)
 
153
  tracking_points = gr.State([])
154
  trackings_input_label = gr.State([])
155
  with gr.Column():
156
+ gr.Markdown("# SAM2 Image Predictor / Masking Assistant")
 
157
  with gr.Row():
158
  with gr.Column():
159
  input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)