Spaces:
Sleeping
Sleeping
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator | |
import torch | |
import numpy as np | |
from PIL import Image | |
class SegmentAnything:
    """Wrapper around a Segment Anything (SAM) model.

    The model is loaded once in ``__init__``. It exposes prompt-based
    segmentation (``predict``) and automatic whole-image mask generation
    (``generate``).
    """

    def __init__(self, sam_checkpoint='checkpoint/sam_vit_h_4b8939.pth',
                 model_type='vit_h', device=None):
        """Build the SAM model from a checkpoint.

        Args:
            sam_checkpoint: Path to the model weights. The default is the
                ViT-H checkpoint path that used to be hard-coded here.
            model_type: Key into ``sam_model_registry``. It must match the
                architecture stored in ``sam_checkpoint``.
            device: Device to move the model to. ``None`` (the default) keeps
                the original behaviour: ``'cuda'`` when
                ``torch.cuda.is_available()``. Otherwise the model is left as
                the registry built it.
        """
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        if device is None and torch.cuda.is_available():
            device = 'cuda'
        if device is not None:
            sam.to(device=device)
        self.sam = sam

    @staticmethod
    def _to_rgb_array(image):
        """Return ``image`` as a uint8 array, converting non-RGB PIL images to RGB first.

        SAM requires 3-channel HWC input. Without this conversion, RGBA, L or P
        PIL images would yield arrays with the wrong channel count and fail
        inside SAM. Non-PIL inputs (e.g. numpy arrays) pass through
        unchanged, apart from the uint8 conversion.
        """
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')
        return np.array(image, dtype=np.uint8)

    def predict(self, image, point_coords, point_labels, box=None,
                multimask_output=True):
        """Segment ``image`` from point and/or box prompts.

        Args:
            image: PIL image or HWC array of the picture to segment.
            point_coords: Point prompts, forwarded to ``SamPredictor.predict``.
            point_labels: Labels for ``point_coords``, forwarded as-is.
            box: Optional box prompt, forwarded as-is.
            multimask_output: Forwarded to ``SamPredictor.predict``. It
                defaults to ``True``, which is SAM's own default and what was
                previously used implicitly.

        Returns:
            The unmodified result of ``SamPredictor.predict``. Per SAM's
            documentation this is ``(masks, scores, low_res_logits)``.
        """
        predictor = SamPredictor(self.sam)
        # set_image recomputes the image embedding on every call, so repeated
        # prompts on the same image pay that cost each time.
        predictor.set_image(self._to_rgb_array(image))
        return predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            multimask_output=multimask_output,
        )

    def generate(self, image, **generator_kwargs):
        """Generate masks for every object in ``image`` without prompts.

        Args:
            image: PIL image or HWC array of the picture to segment.
            **generator_kwargs: Optional settings forwarded to
                ``SamAutomaticMaskGenerator``. Omitting them keeps the
                previous default configuration.

        Returns:
            The unmodified result of ``SamAutomaticMaskGenerator.generate``.
            Per SAM's documentation this is a list of per-mask dicts.
        """
        mask_generator = SamAutomaticMaskGenerator(self.sam, **generator_kwargs)
        return mask_generator.generate(self._to_rgb_array(image))
def makeMaskImage(mask, color):
    """Render a 2-D mask as an RGBA image filled with ``color`` where the mask is truthy.

    Args:
        mask: 2-D array-like indexed as ``mask[x, y]``. ``mask.shape`` is used
            directly as the PIL ``(width, height)`` size.
        color: Fill colour for masked pixels, in any form PIL accepts for an
            RGBA fill (e.g. an ``(r, g, b, a)`` tuple). Unmasked pixels stay
            fully transparent ``(0, 0, 0, 0)``.

    Returns:
        A new ``RGBA`` PIL image of size ``mask.shape``.

    NOTE(review): numpy/SAM masks are usually laid out as ``(H, W)``. Because
    ``shape`` is treated as ``(width, height)`` here, such a mask comes out
    transposed. It also will not match the source image size for non-square
    images. Presumably callers pass ``mask.T`` or square images — confirm.
    """
    width, height = mask.shape
    image = Image.new('RGBA', (width, height))
    if width == 0 or height == 0:
        return image
    # Vectorised replacement for a per-pixel putpixel loop. PIL arrays are
    # indexed [y, x], so mask[x, y] maps to alpha[y, x], hence the transpose.
    alpha = np.where(np.asarray(mask).T.astype(bool), 255, 0).astype(np.uint8)
    solid = Image.new('RGBA', (width, height), color)
    # An all-or-nothing 'L' mask (0/255) copies `color` exactly onto masked
    # pixels and leaves the rest transparent.
    image.paste(solid, (0, 0), Image.fromarray(np.ascontiguousarray(alpha)))
    return image
def makeNewImage(image, maskImage):
    """Cut ``image`` out through the non-transparent pixels of ``maskImage``.

    Args:
        image: Source PIL image, pasted at the origin.
        maskImage: ``RGBA`` image whose alpha channel selects which pixels of
            ``image`` to keep. Any alpha > 0 keeps the pixel. It is not
            modified.

    Returns:
        A new ``RGBA`` image of ``image.size``. It contains ``image``'s pixels
        where the mask alpha is non-zero and is fully transparent elsewhere.

    Raises:
        ValueError: from ``Image.paste`` when ``maskImage.size`` differs from
            ``image.size``.
    """
    newImage = Image.new('RGBA', image.size)
    # Binarise the alpha channel in one LUT pass instead of a per-pixel loop.
    # Partial transparency becomes fully opaque, so masked pixels are copied
    # unblended.
    alpha = maskImage.getchannel('A').point(lambda a: 255 if a > 0 else 0)
    newImage.paste(image, (0, 0), alpha)
    return newImage