---
license: apache-2.0
---

# GAN for Comic Faces Paired Generation

## Model Overview

This model implements a **Generative Adversarial Network (GAN)** with a **UNet generator** and a **PatchGAN discriminator** for paired image-to-image translation on a synthetic dataset of comic faces. Given the first image of a pair, the generator produces the corresponding target image (e.g., photo-to-cartoon or cartoon-to-photo transformations). The discriminator judges whether an (input, output) pair looks real.

- **Dataset:** [Comic Faces Paired Synthetic Dataset](https://www.kaggle.com/datasets/defileroff/comic-faces-paired-synthetic)
- **Batch Size:** 32
- **Input Shape:** (3, 256, 256) (RGB images)
- **Output Shape:** (3, 256, 256) (RGB images)
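
## Model Architecture

### Generator: **UNet**

The generator uses a **UNet architecture**, a common choice for image-to-image translation. It has an encoder-decoder structure with skip connections. The feature map from each encoder stage is concatenated onto the decoder stage of the same resolution. This gives the decoder access to fine spatial detail that downsampling would otherwise discard when it reconstructs the full-resolution output. The architecture consists of the following parts: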
- **Encoder Path (Contracting Path):** Four **DoubleConv** blocks with 64, 128, 256, and 512 output channels extract increasingly abstract features. A 2×2 **MaxPool2d** (stride 2) precedes each block after the first and halves the spatial dimensions.
- **Bottleneck:** A fifth DoubleConv block with 1024 channels processes the lowest-resolution feature map. For a 256×256 input, this map is 16×16.
- **Decoder Path (Expanding Path):** Each of the four stages does three things. It applies a 2× bilinear **Upsample**, concatenates the encoder feature map of matching resolution along the channel dimension (the skip connection), and refines the result with a **DoubleConv** block.
- **Final Convolution:** A **1×1 convolution** maps the 64 decoder channels to `out_channels` (3 for RGB output). No output activation such as `tanh` is applied.
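
The shape bookkeeping described above can be checked without PyTorch. The sketch below traces `(channels, height, width)` through each stage for a 256×256 input, using only the layer widths from the model code. It assumes that every DoubleConv preserves height and width, which the skip concatenations require. `unet_shapes` is a hypothetical helper and is not part of the model.

```python
# Hypothetical helper: traces (channels, height, width) through the UNet.
# Assumption: every DoubleConv keeps H and W unchanged (3x3 convs, padding=1).

def unet_shapes(out_channels=3, size=256):
    enc_channels = [64, 128, 256, 512, 1024]  # down_conv1 .. down_conv5
    shapes = {}
    skips = []
    h = size
    for i, ch in enumerate(enc_channels, start=1):
        if i > 1:
            h //= 2  # MaxPool2d(kernel_size=2, stride=2) before each later block
        shapes[f"down_conv{i}"] = (ch, h, h)
        skips.append((ch, h))

    ch = enc_channels[-1]
    # Decoder: upsample x2, concatenate the matching encoder map, then DoubleConv
    for i, (skip_ch, skip_h) in enumerate(reversed(skips[:-1]), start=1):
        h *= 2  # nn.Upsample(scale_factor=2)
        assert h == skip_h, "skip connection sizes must match (size % 16 == 0)"
        shapes[f"up_conv{i}_in"] = (ch + skip_ch, h, h)
        ch = skip_ch  # each up_conv outputs the skip's channel count
        shapes[f"up_conv{i}"] = (ch, h, h)

    shapes["final_conv"] = (out_channels, h, h)
    return shapes


for name, shape in unet_shapes().items():
    print(f"{name:12s} {shape}")
```

The trace puts the bottleneck at (1024, 16, 16). The decoder inputs shrink from 1536 to 192 channels as upsampled features meet their skips. Input sizes that are not divisible by 16 trip the assertion, which mirrors the size-mismatch error `torch.cat` would raise in the real model.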
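
### Discriminator: **PatchGANDiscriminator**

The discriminator uses a **PatchGAN** architecture. Instead of producing one score for the whole image, it classifies overlapping patches of the image as real or fake. Its input is the **input image and output image pair**, concatenated along the channel dimension: 3 channels for the input image plus 3 channels for the real or generated output, for 6 channels in total.

Five 4×4 **Conv2d** layers reduce the spatial dimensions. The first three use stride 2 and the last two use stride 1. The hidden layers use **LeakyReLU** activations (slope 0.2), and every convolution except the first and last is followed by **InstanceNorm2d**. The final convolution outputs a single-channel grid of raw scores (logits), one per patch. Because no sigmoid is applied, a logit-based loss such as `BCEWithLogitsLoss` would be the natural pairing.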
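
In a PatchGAN, the "patch" is the receptive field of one output score. The sketch below derives both the output grid size and that receptive field from the `(kernel_size, stride, padding)` values of the five convolutions, using the standard Conv2d output-size formula. The helper names are illustrative and are not part of the model.

```python
# (kernel_size, stride, padding) of the five Conv2d layers in PatchGANDiscriminator
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def conv_out(size, k, s, p):
    # Conv2d output size with dilation=1: floor((size + 2p - k) / s) + 1
    return (size + 2 * p - k) // s + 1


def patch_grid(size, layers=LAYERS):
    # Spatial size of the discriminator's output grid for a size x size input
    for k, s, p in layers:
        size = conv_out(size, k, s, p)
    return size


def receptive_field(layers=LAYERS):
    # Walk back from one output unit: rf_in = (rf_out - 1) * stride + kernel
    rf = 1
    for k, s, _ in reversed(layers):
        rf = (rf - 1) * s + k
    return rf


print(patch_grid(256), receptive_field())  # 30 70
```

For 256×256 inputs, the discriminator therefore emits a 30×30 grid of logits. Each logit judges a 70×70 region of the 6-channel input pair. This matches the 70×70 PatchGAN configuration used in pix2pix.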

---

### Model Code (UNet Generator and PatchGAN Discriminator)

```python
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


# NOTE: DoubleConv is used below but is not defined in this card. This is the
# common UNet block (two 3x3 conv -> BatchNorm -> ReLU stages); padding=1 keeps
# H and W unchanged, which the skip concatenations rely on. It is an ASSUMED
# definition and may not match the layers in the released checkpoint.
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module, PyTorchModelHubMixin):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path (encoder); down_conv5 is the 1024-channel bottleneck
        self.down_conv1 = DoubleConv(in_channels, 64)
        self.down_conv2 = DoubleConv(64, 128)
        self.down_conv3 = DoubleConv(128, 256)
        self.down_conv4 = DoubleConv(256, 512)
        self.down_conv5 = DoubleConv(512, 1024)

        # Downsampling: halves H and W
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Upsampling with nn.Upsample: doubles H and W (no learned parameters)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        # Decoder (expanding path); input channels = upsampled + skip connection
        self.up_conv1 = DoubleConv(1024 + 512, 512)
        self.up_conv2 = DoubleConv(512 + 256, 256)
        self.up_conv3 = DoubleConv(256 + 128, 128)
        self.up_conv4 = DoubleConv(128 + 64, 64)

        # Final 1x1 convolution to get the desired number of output channels
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections
        x1 = self.down_conv1(x)                 # 64 ch,   H x W
        x2 = self.down_conv2(self.maxpool(x1))  # 128 ch,  H/2 x W/2
        x3 = self.down_conv3(self.maxpool(x2))  # 256 ch,  H/4 x W/4
        x4 = self.down_conv4(self.maxpool(x3))  # 512 ch,  H/8 x W/8
        x5 = self.down_conv5(self.maxpool(x4))  # 1024 ch, H/16 x W/16 (bottleneck)

        # Decoder: upsample, concatenate the matching skip along channels, refine
        x = self.upsample(x5)
        x = torch.cat([x4, x], dim=1)
        x = self.up_conv1(x)

        x = self.upsample(x)
        x = torch.cat([x3, x], dim=1)
        x = self.up_conv2(x)

        x = self.upsample(x)
        x = torch.cat([x2, x], dim=1)
        x = self.up_conv3(x)

        x = self.upsample(x)
        x = torch.cat([x1, x], dim=1)
        x = self.up_conv4(x)

        return self.final_conv(x)


class PatchGANDiscriminator(nn.Module, PyTorchModelHubMixin):
    def __init__(self, in_channels=6):
        # in_channels=6: input image (3) concatenated with real/generated output (3)
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),  # no norm here
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # One raw logit per patch (no sigmoid): (N, 1, 30, 30) for 256x256 inputs
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.layers(x)
```
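
The card does not state the training objective. As an illustration only, the sketch below assumes the standard pix2pix recipe:

- The discriminator sees the channel-wise concatenation of the input with either the real or the generated output.
- The discriminator is trained with `BCEWithLogitsLoss` on its logit grid.
- The generator's loss adds an L1 reconstruction term weighted by 100.

`train_step`, the loss weight, and the Adam settings (lr 2e-4, betas (0.5, 0.999)) are all assumptions. Tiny stand-in convolutions replace `UNet` and `PatchGANDiscriminator` so that the snippet runs on its own.

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # the discriminator outputs raw logits
l1 = nn.L1Loss()


def train_step(gen, disc, opt_g, opt_d, src, tgt, l1_weight=100.0):
    """One pix2pix-style update (assumed objective). src, tgt: (N, 3, H, W) pairs."""
    fake = gen(src)

    # Discriminator: (src, tgt) pairs -> 1, (src, fake) pairs -> 0, on every patch
    opt_d.zero_grad()
    real_logits = disc(torch.cat([src, tgt], dim=1))            # 3 + 3 = 6 channels
    fake_logits = disc(torch.cat([src, fake.detach()], dim=1))  # no gradient into gen
    loss_d = 0.5 * (bce(real_logits, torch.ones_like(real_logits))
                    + bce(fake_logits, torch.zeros_like(fake_logits)))
    loss_d.backward()
    opt_d.step()

    # Generator: make every patch look real and stay close to the target (L1)
    opt_g.zero_grad()
    fake_logits = disc(torch.cat([src, fake], dim=1))
    loss_g = bce(fake_logits, torch.ones_like(fake_logits)) + l1_weight * l1(fake, tgt)
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()


# Usage with tiny hypothetical stand-ins so the snippet runs on its own
torch.manual_seed(0)
gen = nn.Conv2d(3, 3, kernel_size=3, padding=1)             # stands in for UNet(3, 3)
disc = nn.Conv2d(6, 1, kernel_size=4, stride=2, padding=1)  # stands in for PatchGANDiscriminator()
opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
src, tgt = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
loss_d, loss_g = train_step(gen, disc, opt_g, opt_d, src, tgt)
print(f"loss_d={loss_d:.3f} loss_g={loss_g:.3f}")
```

Detaching `fake` in the discriminator update stops the discriminator loss from sending gradients into the generator. Because the discriminator scores every patch, the adversarial term pushes the generator toward locally realistic texture. The L1 term keeps the overall output aligned with the paired target.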