|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
import opencv_transforms.functional as FF |
|
from torchvision import datasets |
|
from PIL import Image |
|
|
|
def color_cluster(img, nclusters=9):
    """
    Apply K-means clustering to the input image and return its dominant colors.

    Args:
        img: Numpy array of shape (H, W, C). Presumably uint8 in [0, 255]
            (the cluster centers are converted to uint8) -- TODO confirm with callers.
        nclusters: Number of clusters, i.e. number of palette entries (default = 9).

    Returns:
        color_palette: list of ``nclusters`` uint8 numpy arrays, each with the
            same shape as the input image and filled with one cluster-center color.
            e.g. If the input image has shape (256, 256, 3) and nclusters is 4,
            the returned color_palette is [color1, color2, color3, color4] and
            each element is a (256, 256, 3) numpy array.

    Note:
        K-means clustering is computationally intensive, so the image is
        downscaled to x0.25 before the dominant colors are extracted.
        Centers are seeded with cv2.KMEANS_PP_CENTERS (randomized), so the
        palette and its ordering may differ between calls on the same image.
    """
    img_size = img.shape
    # Channel count of the input; each pixel becomes one K-means sample of this length.
    nchannels = img_size[2] if img.ndim == 3 else 1

    small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
    # cv2.kmeans expects a float32 (num_samples, num_features) matrix: one row per pixel.
    sample = np.float32(small_img.reshape((-1, nchannels)))

    # Stop after 10 iterations or once the centers move by less than 1.0.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    flags = cv2.KMEANS_PP_CENTERS
    # 10 attempts with different initial labellings; only the centers are used.
    _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10, flags)

    # Round (rather than truncate) the float centers to the nearest valid uint8 color.
    centers = np.clip(np.rint(centers), 0, 255).astype(np.uint8)

    # One solid-color image per cluster, same shape as the input image.
    return [np.full(img_size, center, dtype=np.uint8) for center in centers]
|
|
|
class PairImageFolder(datasets.ImageFolder):
    """
    A data loader for paired images arranged in the ImageFolder layout: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    This class is meant for paired images in the form of [sketch, color_image].
    Only the leftmost ``crop_width`` columns of each loaded image are used.
    NOTE(review): the crop keeps the LEFT part of the pair, and that part is
    treated as the color image. Confirm which half of the pair lives there.

    Args:
        root (string): Root directory path.
        transform (callable): A function/transform applied to the cropped image
            as a numpy array (e.g. an ``opencv_transforms`` transform). It must
            return a numpy image usable by ``color_cluster`` and ``FF.to_tensor``.
        sketch_net: The network that converts the (normalized, batched) color
            image tensor into a sketch image. It is called on ``self.device``,
            so it is presumably already moved there -- TODO confirm.
        ncluster: Number of clusters when extracting the color palette.
        crop_width (int): Number of leftmost columns kept from each loaded
            image (default = 512).

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
        device (str): 'cuda' when available, otherwise 'cpu'.

    Getitem:
        img_edge: Edge (sketch) image, grayscale replicated to 3 channels.
        img: Color image tensor, normalized with mean = std = 0.5 per channel.
        color_palette: List of ``ncluster`` extracted color-palette tensors,
            normalized the same way as ``img``.

    NOTE(review): ``sketch_net`` runs on CUDA inside ``__getitem__``. With
    ``DataLoader(num_workers > 0)`` and the default 'fork' start method this is
    likely to fail -- confirm the loader configuration.
    """

    def __init__(self, root, transform, sketch_net, ncluster, crop_width=512):
        super().__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        self.crop_width = crop_width
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        # The class label is not part of the returned sample.
        path, _ = self.imgs[index]
        img = np.asarray(self.loader(path))
        # Keep only the leftmost crop_width columns (one half of the pair).
        img = img[:, :self.crop_width, :]
        img = self.transform(img)
        # Extract the palette from the transformed image before tensor conversion.
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)

        with torch.no_grad():
            # Add a batch dim for the network, drop only that dim afterwards, and
            # move channels last for the OpenCV-based functional transforms.
            img_edge = self.sketch_net(img.unsqueeze(0).to(self.device))
            img_edge = img_edge.squeeze(0).permute(1, 2, 0).cpu().numpy()
        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
        img_edge = FF.to_tensor(img_edge)

        color_palette = [self.make_tensor(color) for color in color_palette]

        return img_edge, img, color_palette

    def make_tensor(self, img):
        """Convert a numpy image with FF.to_tensor and normalize each of its 3
        channels with mean = 0.5, std = 0.5."""
        img = FF.to_tensor(img)
        return FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
|
|
class GetImageFolder(datasets.ImageFolder):
    """
    A data loader for single color images arranged in the ImageFolder layout: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Each whole image is used as the color image. Its sketch is produced on the
    fly by ``sketch_net``.

    Args:
        root (string): Root directory path.
        transform (callable): A function/transform applied to the loaded image
            as a numpy array (e.g. an ``opencv_transforms`` transform). It must
            return a numpy image usable by ``color_cluster`` and ``FF.to_tensor``.
        sketch_net: The network that converts the (normalized, batched) color
            image tensor into a sketch image. It is called on ``self.device``,
            so it is presumably already moved there -- TODO confirm.
        ncluster: Number of clusters when extracting the color palette.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
        device (str): 'cuda' when available, otherwise 'cpu'.

    Getitem:
        img_edge: Edge (sketch) image, grayscale replicated to 3 channels.
        img: Color image tensor, normalized with mean = std = 0.5 per channel.
        color_palette: List of ``ncluster`` extracted color-palette tensors,
            normalized the same way as ``img``.

    NOTE(review): ``sketch_net`` runs on CUDA inside ``__getitem__``. With
    ``DataLoader(num_workers > 0)`` and the default 'fork' start method this is
    likely to fail -- confirm the loader configuration.
    """

    def __init__(self, root, transform, sketch_net, ncluster):
        super().__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        # The class label is not part of the returned sample.
        path, _ = self.imgs[index]
        img = np.asarray(self.loader(path))
        img = self.transform(img)
        # Extract the palette from the transformed image before tensor conversion.
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)

        with torch.no_grad():
            # Add a batch dim for the network, drop only that dim afterwards, and
            # move channels last for the OpenCV-based functional transforms.
            img_edge = self.sketch_net(img.unsqueeze(0).to(self.device))
            img_edge = img_edge.squeeze(0).permute(1, 2, 0).cpu().numpy()
        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
        img_edge = FF.to_tensor(img_edge)

        color_palette = [self.make_tensor(color) for color in color_palette]

        return img_edge, img, color_palette

    def make_tensor(self, img):
        """Convert a numpy image with FF.to_tensor and normalize each of its 3
        channels with mean = 0.5, std = 0.5."""
        img = FF.to_tensor(img)
        return FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))