# Uploaded by rajatsingh0702 — commit 7234ee2 ("files added"), 6.67 kB.
import cv2
import numpy as np
import torch
import torchvision
import opencv_transforms.functional as FF
from torchvision import datasets
from PIL import Image
def color_cluster(img, nclusters=9, scale=0.25, attempts=10):
    """
    Extract ``nclusters`` dominant colors from an image with K-means clustering.

    Args:
        img: Numpy array of shape (H, W, 3).
        nclusters: Number of clusters, i.e. palette entries (default = 9).
        scale: Resize factor applied to both axes before clustering, to keep
            K-means cheap (default = 0.25).
        attempts: Number of K-means restarts passed to ``cv2.kmeans``; OpenCV
            keeps the run with the best compactness (default = 10).

    Returns:
        color_palette: list of ``nclusters`` uint8 numpy arrays, each with the
        same shape as ``img`` and filled with a single cluster-center color.
        E.g. if the input has shape (256, 256, 3) and nclusters is 4, the result
        is [color1, color2, color3, color4] and each entry is a (256, 256, 3)
        array.

    Note:
        K-means clustering is computationally intensive, so the dominant colors
        are extracted from a copy of the image resized by ``scale``.
        ``cv2.KMEANS_PP_CENTERS`` seeds the centers randomly, so the palette can
        differ between calls on the same image.
    """
    small_img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
    # One row per pixel, one column per channel; cv2.kmeans requires float32 samples.
    samples = np.float32(small_img.reshape((-1, 3)))
    # Stop after 10 iterations or once the centers move by less than 1.0.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, _, centers = cv2.kmeans(samples, nclusters, None, criteria, attempts, cv2.KMEANS_PP_CENTERS)
    # Round (instead of truncating) the float centers and clamp to the uint8 range.
    centers = np.clip(np.rint(centers), 0, 255).astype(np.uint8)
    # Paint each center color over a canvas with the full-resolution input shape.
    return [np.full(img.shape, center, dtype=np.uint8) for center in centers]
class PairImageFolder(datasets.ImageFolder):
    """
    Dataset of side-by-side image pairs yielding (edge image, color image, palette).

    Files are discovered with the usual ``ImageFolder`` layout: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Each file is expected to contain a pair of images placed side by side. Only
    the left ``crop_width`` columns are kept. NOTE(review): these columns are
    assumed to hold the color image of the pair; confirm against the dataset.
    The edge image is regenerated from that crop with ``sketch_net`` rather
    than read from disk.

    Args:
        root (string): Root directory path.
        transform (callable or None): Transform applied to the cropped image.
            The image is converted with ``np.asarray`` before the transform
            runs, so the transform receives a numpy (H, W, C) array, not a PIL
            image. ``None`` skips the transform.
        sketch_net: Network that converts a normalized color image batch of
            shape (1, C, H, W) into a sketch/edge image batch.
        ncluster: Number of clusters when extracting the color palette.
        crop_width (int): Number of leading columns kept from each pair image
            (default = 512).

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
        device (str): 'cuda' when available, else 'cpu'; the sketch network
            input is moved here. NOTE(review): assumes ``sketch_net`` already
            lives on this device.

    Getitem:
        img_edge: Edge image tensor, converted to 3-channel grayscale.
        img: Color image tensor, normalized with mean/std 0.5 per channel.
        color_palette: List of ``ncluster`` tensors, each a full-size image of
            one dominant color, normalized like ``img``.

    Note:
        ``__getitem__`` runs ``sketch_net`` on ``self.device``. NOTE(review):
        with CUDA and a multi-worker ``DataLoader`` this probably requires the
        'spawn' start method; confirm with the caller.
    """

    def __init__(self, root, transform, sketch_net, ncluster, crop_width=512):
        super(PairImageFolder, self).__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        self.crop_width = crop_width
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        path, _ = self.imgs[index]  # the class label is not part of the sample
        img = np.asarray(self.loader(path))
        # Keep only the left part of the side-by-side pair.
        img = img[:, :self.crop_width, :]
        if self.transform is not None:
            img = self.transform(img)
        # The palette is computed on the numpy (H, W, C) image, before tensor conversion.
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)
        img_edge = self._extract_edge(img)
        color_palette = [self.make_tensor(color) for color in color_palette]
        return img_edge, img, color_palette

    def _extract_edge(self, img):
        """Run ``sketch_net`` on one normalized (C, H, W) tensor and return a 3-channel gray edge tensor."""
        with torch.no_grad():
            edge = self.sketch_net(img.unsqueeze(0).to(self.device))
            # Drop only the batch dim. A bare ``squeeze()`` would also drop a
            # single output channel (or a size-1 spatial dim) and break ``permute``.
            edge = edge.squeeze(0).permute(1, 2, 0).cpu().numpy()
        if edge.shape[2] == 1:
            # Already single-channel: replicate to 3 channels instead of converting RGB -> gray.
            edge = np.repeat(edge, 3, axis=2)
        else:
            edge = FF.to_grayscale(edge, num_output_channels=3)
        # NOTE(review): FF.to_grayscale appears to return a read-only broadcast
        # view for 3 output channels; copy into a regular contiguous array first.
        return FF.to_tensor(np.ascontiguousarray(edge))

    def make_tensor(self, img):
        """Convert a numpy (H, W, C) image with FF.to_tensor, then normalize each channel with mean/std 0.5."""
        img = FF.to_tensor(img)
        return FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
class GetImageFolder(datasets.ImageFolder):
    """
    Dataset of single color images yielding (edge image, color image, palette).

    Files are discovered with the usual ``ImageFolder`` layout: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    The whole image is used; no cropping is performed. The edge image is
    generated from it with ``sketch_net``.

    Args:
        root (string): Root directory path.
        transform (callable or None): Transform applied to the image. The image
            is converted with ``np.asarray`` before the transform runs, so the
            transform receives a numpy (H, W, C) array, not a PIL image.
            ``None`` skips the transform.
        sketch_net: Network that converts a normalized color image batch of
            shape (1, C, H, W) into a sketch/edge image batch.
        ncluster: Number of clusters when extracting the color palette.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
        device (str): 'cuda' when available, else 'cpu'; the sketch network
            input is moved here. NOTE(review): assumes ``sketch_net`` already
            lives on this device.

    Getitem:
        img_edge: Edge image tensor, converted to 3-channel grayscale.
        img: Color image tensor, normalized with mean/std 0.5 per channel.
        color_palette: List of ``ncluster`` tensors, each a full-size image of
            one dominant color, normalized like ``img``.

    Note:
        ``__getitem__`` runs ``sketch_net`` on ``self.device``. NOTE(review):
        with CUDA and a multi-worker ``DataLoader`` this probably requires the
        'spawn' start method; confirm with the caller.
    """

    def __init__(self, root, transform, sketch_net, ncluster):
        super(GetImageFolder, self).__init__(root, transform)
        self.ncluster = ncluster
        self.sketch_net = sketch_net
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __getitem__(self, index):
        path, _ = self.imgs[index]  # the class label is not part of the sample
        img = np.asarray(self.loader(path))
        if self.transform is not None:
            img = self.transform(img)
        # The palette is computed on the numpy (H, W, C) image, before tensor conversion.
        color_palette = color_cluster(img, nclusters=self.ncluster)
        img = self.make_tensor(img)
        img_edge = self._extract_edge(img)
        color_palette = [self.make_tensor(color) for color in color_palette]
        return img_edge, img, color_palette

    def _extract_edge(self, img):
        """Run ``sketch_net`` on one normalized (C, H, W) tensor and return a 3-channel gray edge tensor."""
        with torch.no_grad():
            edge = self.sketch_net(img.unsqueeze(0).to(self.device))
            # Drop only the batch dim. A bare ``squeeze()`` would also drop a
            # single output channel (or a size-1 spatial dim) and break ``permute``.
            edge = edge.squeeze(0).permute(1, 2, 0).cpu().numpy()
        if edge.shape[2] == 1:
            # Already single-channel: replicate to 3 channels instead of converting RGB -> gray.
            edge = np.repeat(edge, 3, axis=2)
        else:
            edge = FF.to_grayscale(edge, num_output_channels=3)
        # NOTE(review): FF.to_grayscale appears to return a read-only broadcast
        # view for 3 output channels; copy into a regular contiguous array first.
        return FF.to_tensor(np.ascontiguousarray(edge))

    def make_tensor(self, img):
        """Convert a numpy (H, W, C) image with FF.to_tensor, then normalize each channel with mean/std 0.5."""
        img = FF.to_tensor(img)
        return FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))