import cv2 |
import numpy as np |
import torch |
import torchvision |
import opencv_transforms.functional as FF |
from torchvision import datasets |
from PIL import Image |
def color_cluster(img, nclusters=9):
    """
    Extract the dominant colors of an image with K-means clustering.

    Args:
        img: Numpy array of shape (H, W, C). The pixels are flattened into
            rows of three values before clustering, so three color channels
            are expected.
        nclusters: Number of clusters, i.e. number of dominant colors
            (default = 9).

    Returns:
        color_palette: list of ``nclusters`` uint8 numpy arrays, each with the
            same shape as the input image and filled with a single cluster
            center color.
            e.g. for an input of shape (256, 256, 3) and nclusters = 4 the
            result is [color1, color2, color3, color4], every entry being a
            (256, 256, 3) array.

    Note:
        K-means clustering is computationally intensive, so the cluster
        centers are estimated on a copy of the image downscaled by a factor
        of 0.25 along both axes.
    """
    # Cluster on a quarter-scale copy to keep K-means cheap.
    shrunk = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
    pixels = np.float32(shrunk.reshape((-1, 3)))

    # Termination rule: at most 10 iterations, or an accuracy (epsilon) of 1.0.
    stop_rule = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    attempts = 10
    _compactness, _labels, centroids = cv2.kmeans(
        pixels, nclusters, None, stop_rule, attempts, cv2.KMEANS_PP_CENTERS
    )
    # Float centers are converted to 8-bit color values.
    centroids = np.uint8(centroids)

    def _solid(color):
        # Full-size canvas painted uniformly with one cluster color.
        swatch = np.zeros(img.shape, dtype='uint8')
        swatch[:, :, :] = color
        return swatch

    return [_solid(centroids[k]) for k in range(nclusters)]
class PairImageFolder(datasets.ImageFolder):
    """
    Image-folder dataset that yields (edge image, color image, color palette).

    The images are expected to be arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    This class is intended for paired images in the form of
    [sketch, color_image]: only the first 512 pixel columns of every loaded
    image are kept before any further processing.

    Args:
        root (string): Root directory path.
        transform (callable or None): A function/transform applied to the
            cropped image. It receives the image as a numpy array (not a PIL
            image) and its result is fed to ``color_cluster`` and to the
            sketch network. ``None`` skips the transform.
        sketch_net: The network used to convert the color image to a sketch
            image. It is called on a single-image batch moved to
            ``self.device``; the network itself is not moved by this class
            (presumably the caller places it on the same device).
        ncluster: Number of clusters when extracting the color palette.

    Attributes:
        classes (list): List of the class names (inherited from ImageFolder).
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
        device (str): 'cuda' when CUDA is available, otherwise 'cpu'.

    Getitem:
        img_edge: Edge image tensor (3-channel grayscale of the network output)
        img: Color image tensor, normalized with mean 0.5 / std 0.5 per channel
        color_palette: List of extracted color palette tensors, normalized
            the same way as ``img``
    """

    def __init__(self, root, transform, sketch_net, ncluster):
        super(PairImageFolder, self).__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        """Return ``(img_edge, img, color_palette)`` for the sample at ``index``.

        The class label stored alongside the path is not used.
        """
        path, _ = self.imgs[index]
        img = np.asarray(self.loader(path))
        # Keep only the first 512 columns of the paired image.
        img = img[:, 0:512, :]
        # ImageFolder allows transform=None; do not call it in that case.
        if self.transform is not None:
            img = self.transform(img)

        # The palette is extracted from the transformed array, before it is
        # converted to a normalized tensor.
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)

        with torch.no_grad():
            batch = img.unsqueeze(0).to(self.device)
            # squeeze(0) removes only the batch axis. A bare squeeze() would
            # also drop any other size-1 axis and break the permute below.
            img_edge = self.sketch_net(batch).squeeze(0).permute(1, 2, 0).cpu().numpy()
        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
        img_edge = FF.to_tensor(img_edge)

        color_palette = [self.make_tensor(color) for color in color_palette]
        return img_edge, img, color_palette

    def make_tensor(self, img):
        """Convert an image array to a tensor normalized with mean 0.5 and std 0.5 per channel."""
        img = FF.to_tensor(img)
        img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return img
class GetImageFolder(datasets.ImageFolder):
    """
    Image-folder dataset that yields (edge image, color image, color palette).

    The images are expected to be arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    The whole loaded image is used as the color image (no cropping).

    Args:
        root (string): Root directory path.
        transform (callable or None): A function/transform applied to the
            loaded image. It receives the image as a numpy array (not a PIL
            image) and its result is fed to ``color_cluster`` and to the
            sketch network. ``None`` skips the transform.
        sketch_net: The network used to convert the color image to a sketch
            image. It is called on a single-image batch moved to
            ``self.device``; the network itself is not moved by this class
            (presumably the caller places it on the same device).
        ncluster: Number of clusters when extracting the color palette.

    Attributes:
        classes (list): List of the class names (inherited from ImageFolder).
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
        device (str): 'cuda' when CUDA is available, otherwise 'cpu'.

    Getitem:
        img_edge: Edge image tensor (3-channel grayscale of the network output)
        img: Color image tensor, normalized with mean 0.5 / std 0.5 per channel
        color_palette: List of extracted color palette tensors, normalized
            the same way as ``img``
    """

    def __init__(self, root, transform, sketch_net, ncluster):
        super(GetImageFolder, self).__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        """Return ``(img_edge, img, color_palette)`` for the sample at ``index``.

        The class label stored alongside the path is not used.
        """
        path, _ = self.imgs[index]
        img = np.asarray(self.loader(path))
        # ImageFolder allows transform=None; do not call it in that case.
        if self.transform is not None:
            img = self.transform(img)

        # The palette is extracted from the transformed array, before it is
        # converted to a normalized tensor.
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)

        with torch.no_grad():
            batch = img.unsqueeze(0).to(self.device)
            # squeeze(0) removes only the batch axis. A bare squeeze() would
            # also drop any other size-1 axis and break the permute below.
            img_edge = self.sketch_net(batch).squeeze(0).permute(1, 2, 0).cpu().numpy()
        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
        img_edge = FF.to_tensor(img_edge)

        color_palette = [self.make_tensor(color) for color in color_palette]
        return img_edge, img, color_palette

    def make_tensor(self, img):
        """Convert an image array to a tensor normalized with mean 0.5 and std 0.5 per channel."""
        img = FF.to_tensor(img)
        img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return img