LPX55 commited on
Commit
112487d
·
1 Parent(s): a745151

mem split fix

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. sam2_mask.py +4 -2
requirements.txt CHANGED
@@ -12,5 +12,5 @@ safetensors
12
  matplotlib
13
  torchvision
14
  pydantic==2.10.6
15
- git+https://github.com/facebookresearch/sam2.git
16
  gradio_image_prompter
 
12
  matplotlib
13
  torchvision
14
  pydantic==2.10.6
15
+ sam2
16
  gradio_image_prompter
sam2_mask.py CHANGED
@@ -1,5 +1,5 @@
1
  # K-I-S-S
2
-
3
  import gradio as gr
4
  from gradio_image_prompter import ImagePrompter
5
  from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -10,10 +10,12 @@ from PIL import Image as PILImage
10
  # Initialize SAM2 predictor
11
  MODEL = "facebook/sam2.1-hiera-large"
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
14
 
 
15
  def predict_masks(image, points):
16
  """Predict a single mask from the image based on selected points."""
 
 
17
  image_np = np.array(image)
18
  points_list = [[point["x"], point["y"]] for point in points]
19
  input_labels = [1] * len(points_list)
 
1
  # K-I-S-S
2
+ import spaces
3
  import gradio as gr
4
  from gradio_image_prompter import ImagePrompter
5
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
10
  # Initialize SAM2 predictor
11
  MODEL = "facebook/sam2.1-hiera-large"
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
13
 
14
+ @spaces.GPU()
15
  def predict_masks(image, points):
16
  """Predict a single mask from the image based on selected points."""
17
+ global PREDICTOR
18
+ PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
19
  image_np = np.array(image)
20
  points_list = [[point["x"], point["y"]] for point in points]
21
  input_labels = [1] * len(points_list)