Spaces:
Running
on
Zero
Running
on
Zero
Update sam2_mask.py
Browse files- sam2_mask.py +3 -3
sam2_mask.py
CHANGED
@@ -111,6 +111,7 @@ def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_s
|
|
111 |
mask = feather_mask(mask, feather_size)
|
112 |
return mask
|
113 |
|
|
|
114 |
def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
|
115 |
image = Image.open(input_image)
|
116 |
image = np.array(image.convert("RGB"))
|
@@ -123,7 +124,7 @@ def sam_process(input_image, checkpoint, tracking_points, trackings_input_label,
|
|
123 |
# sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
|
124 |
# Use CPU for both model and computations
|
125 |
# sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
|
126 |
-
predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="
|
127 |
|
128 |
# predictor = SAM2ImagePredictor(sam2_model)
|
129 |
predictor.set_image(image)
|
@@ -152,8 +153,7 @@ with gr.Blocks() as demo:
|
|
152 |
tracking_points = gr.State([])
|
153 |
trackings_input_label = gr.State([])
|
154 |
with gr.Column():
|
155 |
-
gr.Markdown("# SAM2 Image Predictor
|
156 |
-
gr.Markdown("This version runs entirely on CPU")
|
157 |
with gr.Row():
|
158 |
with gr.Column():
|
159 |
input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
|
|
|
111 |
mask = feather_mask(mask, feather_size)
|
112 |
return mask
|
113 |
|
114 |
+
@spaces.GPU()
|
115 |
def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
|
116 |
image = Image.open(input_image)
|
117 |
image = np.array(image.convert("RGB"))
|
|
|
124 |
# sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
|
125 |
# Use CPU for both model and computations
|
126 |
# sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
|
127 |
+
predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="cuda")
|
128 |
|
129 |
# predictor = SAM2ImagePredictor(sam2_model)
|
130 |
predictor.set_image(image)
|
|
|
153 |
tracking_points = gr.State([])
|
154 |
trackings_input_label = gr.State([])
|
155 |
with gr.Column():
|
156 |
+
gr.Markdown("# SAM2 Image Predictor / Masking Assistant")
|
|
|
157 |
with gr.Row():
|
158 |
with gr.Column():
|
159 |
input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
|