# clip-image / meta_segment_anything.py
# T.Masuda
# clip-image
# d65ec94
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import torch
import numpy as np
from PIL import Image
class SegmentAnything:
    """Wrapper around a Segment Anything (SAM) model plus mask-image helpers.

    The model is loaded once in the constructor and reused by ``predict``
    and ``generate``.  ``makeMaskImage`` and ``makeNewImage`` are static
    helpers that do not touch the model.
    """

    # Defaults used when the constructor is called without arguments.
    DEFAULT_CHECKPOINT = 'checkpoint/sam_vit_h_4b8939.pth'
    DEFAULT_MODEL_TYPE = 'vit_h'

    def __init__(self, sam_checkpoint=DEFAULT_CHECKPOINT,
                 model_type=DEFAULT_MODEL_TYPE):
        """Load the SAM model and move it to CUDA when a GPU is available.

        Args:
            sam_checkpoint: Path of the checkpoint file handed to the model
                factory.
            model_type: Key into ``sam_model_registry`` selecting the factory.
        """
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        if torch.cuda.is_available():
            sam.to(device='cuda')
        self.sam = sam

    def predict(self, image, point_coords, point_labels, box=None):
        """Run prompted segmentation on ``image``.

        Args:
            image: Anything ``np.array(..., dtype=np.uint8)`` accepts
                (e.g. a PIL image).
            point_coords: Forwarded unchanged to ``SamPredictor.predict``.
            point_labels: Forwarded unchanged to ``SamPredictor.predict``.
            box: Optional box prompt, forwarded unchanged.

        Returns:
            Whatever ``SamPredictor.predict`` returns, unmodified.
        """
        # A fresh predictor per call, so no image state leaks between calls.
        predictor = SamPredictor(self.sam)
        predictor.set_image(np.array(image, dtype=np.uint8))
        return predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
        )

    def generate(self, image):
        """Run automatic (prompt-free) mask generation on ``image``.

        Returns:
            Whatever ``SamAutomaticMaskGenerator.generate`` returns,
            unmodified.
        """
        mask_generator = SamAutomaticMaskGenerator(self.sam)
        return mask_generator.generate(np.array(image, dtype=np.uint8))

    @staticmethod
    def _paint_mask(mask, rgba):
        """Return ``(height, width, 4)`` uint8 pixels for ``mask`` in ``rgba``.

        ``mask`` is read as ``mask[x, y]``, so the output has
        ``width == mask.shape[0]`` and ``height == mask.shape[1]``.  Truthy
        cells receive ``rgba``; every other cell stays ``(0, 0, 0, 0)``.

        Raises:
            ValueError: If ``mask`` is not 2-dimensional.
        """
        cells = np.asarray(mask)
        if cells.ndim != 2:
            raise ValueError(
                'mask must be 2-dimensional, got shape {}'.format(cells.shape)
            )
        # Transpose so the array is laid out row-major as (y, x), which is
        # the order raw image bytes are expected in.
        lit = cells.T.astype(bool)
        pixels = np.zeros(lit.shape + (4,), dtype=np.uint8)
        pixels[lit] = rgba
        return pixels

    @staticmethod
    def makeMaskImage(mask, color):
        """Render a 2-D mask as an RGBA image.

        Pixel ``(x, y)`` is ``color`` where ``mask[x, y]`` is truthy and fully
        transparent elsewhere.  The image size equals ``mask.shape``, i.e. the
        first mask axis is the image's x axis.

        NOTE(review): if callers pass a (height, width) array straight from
        SAM, the result is transposed -- confirm whether callers transpose
        first before changing this convention.

        Args:
            mask: 2-D array-like of truthy/falsy cells.
            color: Any colour value accepted by ``Image.putpixel`` on an RGBA
                image.

        Returns:
            A new RGBA ``PIL.Image.Image``.

        Raises:
            ValueError: If ``mask`` is not 2-dimensional.
        """
        # Let Pillow normalise the colour exactly the way putpixel would, so
        # every colour form the per-pixel implementation accepted still works.
        probe = Image.new('RGBA', (1, 1))
        probe.putpixel((0, 0), color)
        pixels = SegmentAnything._paint_mask(mask, probe.getpixel((0, 0)))
        height, width = pixels.shape[:2]
        # One bulk copy instead of a Python-level putpixel per pixel.
        return Image.frombytes('RGBA', (width, height), pixels.tobytes())

    @staticmethod
    def makeNewImage(image, maskImage):
        """Cut ``image`` out along the non-transparent area of ``maskImage``.

        Args:
            image: Source PIL image.
            maskImage: PIL image with an alpha band (e.g. the output of
                ``makeMaskImage``).  It is not modified.

        Returns:
            A new RGBA image of ``image.size`` holding the pixels of ``image``
            wherever ``maskImage`` has alpha > 0 and transparent black
            elsewhere.
        """
        cutout = Image.new('RGBA', image.size)
        # Binarise the alpha band (any alpha > 0 becomes fully opaque) with a
        # 256-entry lookup table rather than visiting pixels one by one.
        stencil = maskImage.getchannel('A').point(
            lambda alpha: 255 if alpha > 0 else 0
        )
        cutout.paste(image, (0, 0), stencil)
        return cutout