|
import os |
|
from skimage import io, transform |
|
from skimage.filters import gaussian |
|
import torch |
|
import torchvision |
|
from torch.autograd import Variable |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
|
|
|
|
import numpy as np |
|
from PIL import Image |
|
import glob |
|
|
|
from data_loader import RescaleT |
|
from data_loader import ToTensor |
|
from data_loader import ToTensorLab |
|
from data_loader import SalObjDataset |
|
|
|
from model import U2NET |
|
from model import U2NETP |
|
|
|
import argparse |
|
|
|
|
|
def normPRED(d):
    """Min-max normalize a prediction tensor into the range [0, 1].

    Args:
        d: Tensor of predicted values (any shape, floating point).

    Returns:
        A tensor with the same shape as ``d`` holding ``(d - min) / (max - min)``.
        If ``d`` is constant (``max == min``) a zero tensor is returned instead
        of the NaNs that a 0/0 division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    value_range = ma - mi

    if value_range == 0:
        # Constant map: avoid 0/0 -> NaN, which would corrupt the saved image.
        return torch.zeros_like(d)

    return (d - mi) / value_range
|
|
|
def save_output(image_name, pred, d_dir, sigma=2, alpha=0.5):
    """Blend a blurred copy of the input image with the predicted map and save it as PNG.

    Args:
        image_name: Path of the original input image (read with ``io.imread``).
        pred: Prediction tensor for that image. It is squeezed before use, so it
            is expected to reduce to a 2-D map (assumed (1, H', W') — TODO confirm).
        d_dir: Directory the composite is written to.
        sigma: Standard deviation of the Gaussian blur applied to the input image.
        alpha: Weight of the blurred image in the blend; the map gets ``1 - alpha``.

    The output file is ``<d_dir>/<stem>_sigma_<sigma>_alpha_<alpha>_composite.png``,
    where ``<stem>`` is the input file name without its last extension.
    """
    # detach() instead of the legacy ``.data`` accessor.
    predict_np = pred.squeeze().detach().cpu().numpy()

    image = io.imread(image_name)

    # Resize the map back to the input resolution and scale it to [0, 255].
    pd = transform.resize(predict_np, image.shape[0:2], order=2)
    pd = pd / (np.amax(pd) + 1e-8) * 255
    # Only add a channel axis when the image has one. For a 2-D grayscale
    # image, an (H, W, 1) map does not broadcast correctly against (H, W).
    if image.ndim == 3:
        pd = pd[:, :, np.newaxis]

    print(image.shape)
    print(pd.shape)

    # NOTE(review): whether a 3-D colour image is blurred per channel or across
    # channels depends on the installed scikit-image version — confirm.
    image = gaussian(image, sigma=sigma, preserve_range=True)

    im_comp = image * alpha + pd * (1 - alpha)

    print(im_comp.shape)

    # splitext drops only the last extension (same as re-joining all but the
    # last dot-separated part) but does not crash on names without a dot.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    out_name = stem + '_sigma_' + str(sigma) + '_alpha_' + str(alpha) + '_composite.png'

    # Convert explicitly to uint8: handing a float64 array to imsave leaves the
    # conversion to the I/O backend, which may rescale or reject values > 1.
    out = np.clip(np.rint(im_comp), 0, 255).astype(np.uint8)
    io.imsave(os.path.join(d_dir, out_name), out)
|
|
|
def main():
    """Run U2NET portrait inference on every image in the input folder and save composites.

    Command-line options:
        -s SIGMA  Gaussian blur sigma forwarded to ``save_output`` (default 2.0).
        -a ALPHA  Blend weight of the blurred image forwarded to ``save_output``
                  (default 0.5).
    """
    parser = argparse.ArgumentParser(description="image and portrait composite")
    # type=float validates the values at parse time. The defaults mirror
    # save_output's, so omitting a flag no longer crashes later with float(None).
    parser.add_argument('-s', action='store', dest='sigma', type=float, default=2.0)
    parser.add_argument('-a', action='store', dest='alpha', type=float, default=0.5)
    args = parser.parse_args()
    print(args.sigma)
    print(args.alpha)
    print("--------------------")

    image_dir = './test_data/test_portrait_images/your_portrait_im'
    prediction_dir = './test_data/test_portrait_images/your_portrait_results'
    # makedirs also creates missing parent directories; exist_ok makes reruns safe.
    os.makedirs(prediction_dir, exist_ok=True)

    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'

    img_name_list = glob.glob(image_dir + '/*')
    print("Number of images: ", len(img_name_list))

    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(512),
                                                                      ToTensorLab(flag=0)]))
    # batch_size=1 and shuffle=False so batch i can be paired with img_name_list[i]
    # below (assumes SalObjDataset yields items in list order — confirm).
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    print("...load U2NET---173.6 MB")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = U2NET(3, 1)
    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(model_dir, map_location=device))
    net.to(device)
    net.eval()

    # Pure inference: skip recording an autograd graph.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].float().to(device)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # Invert channel 0 of the first output, then min-max normalize it.
            pred = 1.0 - d1[:, 0, :, :]
            pred = normPRED(pred)

            save_output(img_name_list[i_test], pred, prediction_dir,
                        sigma=args.sigma, alpha=args.alpha)

            del d1, d2, d3, d4, d5, d6, d7


if __name__ == "__main__":
    main()
|
|