|
import os |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import pickle |
|
from google_drive_downloader import GoogleDriveDownloader as gdd |
|
|
|
import torch |
|
import torchvision |
|
from torchvision import transforms |
|
from torchvision.models.segmentation import deeplabv3_resnet101 |
|
from torchvision.models.segmentation.deeplabv3 import DeepLabHead |
|
|
|
|
|
def deeplabv3_remove_bg(img):
    """Black out the background of an image with a pretrained DeepLabV3 model.

    The image is segmented with torchvision's pretrained
    ``deeplabv3_resnet101``. Every pixel whose predicted class index is 0 is
    treated as background and set to zero; all other pixels are kept as-is.
    (Per the torchvision model documentation, index 0 is the background
    class of the pretrained weights.)

    Args:
        img: Anything ``np.array(img, dtype=np.uint8)`` accepts, e.g. a PIL
            image or an ndarray. Expected layout is H x W x 3. A 2-D
            (single-channel) input is replicated to three channels and a
            4-channel input has its last channel dropped (assumed to be
            alpha -- confirm against callers).

    Returns:
        A tuple ``(img_masked, index_masked)``:
            img_masked: ``PIL.Image`` with background pixels set to 0.
            index_masked: ``(rows, cols)`` tuple of index arrays, as returned
                by ``np.where``, locating the background pixels.

    Raises:
        ValueError: If the input cannot be interpreted as an H x W x 3 image.
    """
    img = np.array(img, dtype=np.uint8)

    # Bring the input to H x W x 3 up front; the normalisation statistics
    # below are defined for exactly three channels.
    if img.ndim == 2:
        img = np.dstack([img, img, img])
    elif img.ndim == 3 and img.shape[2] == 4:
        img = np.ascontiguousarray(img[:, :, :3])
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(
            "expected an image of shape (H, W, 3), got shape %s" % (img.shape,)
        )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Keep the pretrained classification head. Swapping in a freshly
    # initialised single-channel head would discard the trained weights and
    # make argmax over the channel axis always return 0 (an empty mask).
    model = deeplabv3_resnet101(pretrained=True)
    model = model.to(device)
    model.eval()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_batch)['out'][0]
    # Drop the local reference right away so the weights can be reclaimed
    # before the NumPy post-processing.
    del model

    # Per-pixel class index -> binary mask (1 = foreground, 0 = background).
    labels = logits.argmax(0)
    mask = labels.ne(0).byte().cpu().numpy()

    # uint8 * {0, 1} cannot overflow, so plain broadcasting is exact here.
    img_masked = Image.fromarray(img * mask[:, :, np.newaxis])
    index_masked = np.where(mask == 0)
    return img_masked, index_masked
|
|
|
|