GAN for Comic Faces Paired Generation

Model Overview

This model implements a Generative Adversarial Network (GAN) with a UNet generator and a PatchGAN discriminator. It works on a synthetic dataset of paired comic-face images. Given the first image of a pair, the generator produces the corresponding target image (e.g., photo-to-cartoon or cartoon-to-photo translation). The discriminator judges whether an (input, output) pair looks like a real pair from the dataset.

Model Architecture

Generator: UNet

The generator uses a UNet architecture, a common choice for image-to-image translation. It has an encoder-decoder structure with skip connections that pass high-resolution encoder features directly to the decoder, which helps preserve fine spatial detail in the output. The architecture includes the following layers:

  • Encoder Path (Contracting Path):
    Four DoubleConv blocks (64, 128, 256 and 512 channels) extract progressively deeper features. Between blocks, MaxPool2d (2x2, stride 2) halves the spatial dimensions.

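  • Bottleneck:
    The deepest DoubleConv block (down_conv5, 1024 feature channels) processes the feature map after four pooling steps, at 1/16 of the input resolution. Input height and width should therefore be multiples of 16 so that the skip connections align.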

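  • Decoder Path (Expanding Path):
    At each of four stages, the decoder doubles the spatial dimensions with bilinear nn.Upsample. It concatenates the result with the encoder feature map of the same resolution (the skip connection) and refines the combined features with a DoubleConv block. Across the four stages, the channel count goes from 1024 back down to 64.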

  • Final Convolution:
    A 1x1 convolution maps the 64 decoder channels to out_channels. No output activation is applied, so the output values are not constrained to a fixed range.
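The resolution and channel bookkeeping described above can be traced without PyTorch. The sketch below is a minimal illustration. It assumes that each DoubleConv preserves spatial size, which is the usual 3x3, padding=1 design; the card does not give its definition. The function name unet_shape_trace is hypothetical.

```python
# Channel widths of the five encoder DoubleConv blocks, taken from the UNet code.
ENCODER_CHANNELS = [64, 128, 256, 512, 1024]


def unet_shape_trace(height, width, in_channels=3, out_channels=3):
    """Return (stage, channels, height, width) tuples for one UNet forward pass.

    Assumption: every DoubleConv keeps H and W unchanged (3x3 convs, padding=1).
    """
    # Four 2x2 max-pools halve the resolution four times, so both sides must be
    # divisible by 2**4 = 16 for the skip connections to line up exactly.
    if height % 16 or width % 16:
        raise ValueError("height and width must be multiples of 16")

    trace = [("input", in_channels, height, width)]
    h, w = height, width
    for level, ch in enumerate(ENCODER_CHANNELS):
        if level > 0:
            h, w = h // 2, w // 2  # MaxPool2d(kernel_size=2, stride=2)
        trace.append((f"down_conv{level + 1}", ch, h, w))

    # Decoder: mirror the encoder, concatenating each skip before refining.
    ch = ENCODER_CHANNELS[-1]
    for i, skip_ch in enumerate(reversed(ENCODER_CHANNELS[:-1])):
        h, w = h * 2, w * 2  # nn.Upsample(scale_factor=2)
        concat = ch + skip_ch  # channels after torch.cat([skip, x], dim=1)
        ch = skip_ch  # each up_conv maps back to the skip's width
        trace.append((f"up_conv{i + 1} (in={concat})", ch, h, w))

    trace.append(("final_conv", out_channels, h, w))  # 1x1 conv
    return trace
```

For a hypothetical 256x256 RGB input, the bottleneck works on a 16x16 map with 1024 channels. up_conv1 then receives 1024 + 512 = 1536 channels, matching DoubleConv(1024 + 512, 512) in the code.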

Discriminator: PatchGANDiscriminator

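The discriminator uses a PatchGAN architecture, which classifies overlapping patches of the image as real or fake rather than judging the whole image at once. Its input is the conditioning image concatenated along the channel dimension with an output image, which is either the real target or the generator's output. This gives 3 + 3 = 6 input channels. Five 4x4 Conv2d layers reduce the spatial dimensions; the first three use stride 2 and the last two use stride 1. A LeakyReLU (slope 0.2) follows every layer except the last. InstanceNorm2d normalizes the outputs of the second, third and fourth layers. The final layer outputs a single-channel map of unnormalized scores (logits), one per patch. A sigmoid converts each score into the probability that the corresponding patch is real.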
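To make "a patch" concrete, the following sketch uses the layer hyperparameters in PatchGANDiscriminator to compute two quantities. The first is the size of the output score map. The second is the receptive field of each score, meaning the input region that influences it. The helper names are hypothetical. The 256x256 input used below is only an example, since the card does not state the training resolution.

```python
# (kernel_size, stride, padding) of the five Conv2d layers in PatchGANDiscriminator.
PATCHGAN_CONVS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def conv_out(size, kernel, stride, padding):
    """Output length of a Conv2d along one spatial axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def patch_map_size(size, layers=PATCHGAN_CONVS):
    """Side length of the discriminator's score map for a square input."""
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size


def receptive_field(layers=PATCHGAN_CONVS):
    """Side length of the input region that influences one output score."""
    rf = 1
    # Walk backwards from a single output unit: a window of rf units at a
    # layer's output covers (rf - 1) * stride + kernel units at its input.
    for kernel, stride, _ in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf
```

With a 256x256 input, patch_map_size(256) gives a 30x30 grid of scores. receptive_field() gives 70, so each score judges a 70x70 region of the 6-channel input pair.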


Model Code (UNet generator and PatchGAN discriminator):

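import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class DoubleConv(nn.Module):
    # Assumed definition: DoubleConv is used below but not defined in this card.
    # This is the common UNet block (two 3x3 convs, each followed by BatchNorm
    # and ReLU); padding=1 keeps the spatial size unchanged. Loading the
    # published weights requires the original definition, whose layer names
    # may differ from these.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module, PyTorchModelHubMixin):
    def __init__(self, in_channels, out_channels):
        super().__init__()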

        # Contracting Path (Encoder)
        self.down_conv1 = DoubleConv(in_channels, 64)
        self.down_conv2 = DoubleConv(64, 128)
        self.down_conv3 = DoubleConv(128, 256)
        self.down_conv4 = DoubleConv(256, 512)
        self.down_conv5 = DoubleConv(512, 1024)

        # Downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Upsampling layers using nn.Upsample
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        # Decoder (Expanding Path)
        self.up_conv1 = DoubleConv(1024 + 512, 512)
        self.up_conv2 = DoubleConv(512 + 256, 256)
        self.up_conv3 = DoubleConv(256 + 128, 128)
        self.up_conv4 = DoubleConv(128 + 64, 64)

        # Final 1x1 convolution to get desired number of output channels
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

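    def forward(self, x):
        # Encoder: each max-pool halves H and W (input sides must be multiples of 16)
        x1 = self.down_conv1(x)                 # 64 channels,   H x W
        x2 = self.down_conv2(self.maxpool(x1))  # 128 channels,  H/2 x W/2
        x3 = self.down_conv3(self.maxpool(x2))  # 256 channels,  H/4 x W/4
        x4 = self.down_conv4(self.maxpool(x3))  # 512 channels,  H/8 x W/8
        x5 = self.down_conv5(self.maxpool(x4))  # 1024 channels, H/16 x W/16 (bottleneck)

        # Decoder: upsample, concatenate the matching encoder skip, then refine
        x = self.upsample(x5)
        x = torch.cat([x4, x], dim=1)  # 1024 + 512 channels
        x = self.up_conv1(x)

        x = self.upsample(x)
        x = torch.cat([x3, x], dim=1)  # 512 + 256 channels
        x = self.up_conv2(x)

        x = self.upsample(x)
        x = torch.cat([x2, x], dim=1)  # 256 + 128 channels
        x = self.up_conv3(x)

        x = self.upsample(x)
        x = torch.cat([x1, x], dim=1)  # 128 + 64 channels
        x = self.up_conv4(x)

        # 1x1 conv to out_channels; no output activation is applied
        return self.final_conv(x)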


class PatchGANDiscriminator(nn.Module, PyTorchModelHubMixin):
    def __init__(self, in_channels=6):
        super().__init__()

        self.layers = nn.Sequential(
            # First layer: no normalization
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Final layer: one unnormalized score (logit) per patch, no sigmoid
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.layers(x)
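Because the last Conv2d has no activation, the discriminator returns logits rather than probabilities. The card does not say which adversarial loss was used. A common choice for logit outputs, and an assumption here, is torch.nn.BCEWithLogitsLoss, which fuses the sigmoid and the binary cross-entropy for numerical stability. The pure-Python sketch below shows that computation averaged over a flattened patch map. The function names are hypothetical.

```python
import math


def sigmoid(x):
    """Map a raw patch score (logit) to the probability that the patch is real."""
    # Branch on the sign so math.exp never overflows for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logits, target):
    """Mean binary cross-entropy over a flat list of patch logits.

    Uses the numerically stable form max(x, 0) - x * y + log(1 + exp(-|x|)),
    which equals -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    """
    total = 0.0
    for x in logits:
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)
```

In the conditional setup described above, the discriminator would score torch.cat([input_img, target_img], dim=1) against a target of 1. It would score torch.cat([input_img, fake_img], dim=1) against a target of 0, where fake_img is the generator's output. The variable names here are illustrative.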