sebastiansarasti committed on
Commit
e3fc5bb
·
verified ·
1 Parent(s): 3959c38

Update README.md

Files changed (1)
  1. README.md +118 -3
README.md CHANGED
@@ -1,3 +1,118 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ ---
+ license: mit
+ ---
+
+ # GAN for Comic Faces Paired Generation
+
+ ## Model Overview
+
+ This model is a conditional **Generative Adversarial Network (GAN)** with a **UNet generator** and a **PatchGAN discriminator**, trained on a synthetic dataset of paired comic faces. Given a source image, the generator produces the corresponding target image (e.g., photo-to-cartoon or cartoon-to-photo).
+
+ - **Dataset:** [Comic Faces Paired Synthetic Dataset](https://www.kaggle.com/datasets/defileroff/comic-faces-paired-synthetic)
+ - **Batch Size:** 32
+ - **Input Shape:** (3, 256, 256) (RGB images)
+ - **Output Shape:** (3, 256, 256)
+
+ ## Model Architecture
+
+ ### Generator: **UNet**
+
+ The generator uses a **UNet architecture**, an encoder-decoder with skip connections that is well suited to image-to-image translation. The skip connections pass high-resolution encoder features directly to the decoder, which helps preserve fine detail in the output. The architecture has four parts:
+
+ - **Encoder Path (Contracting Path):**
+ The encoder stacks five **DoubleConv** blocks (64, 128, 256, 512, and 1024 channels) to extract features. A **MaxPool2d** layer halves the spatial dimensions between consecutive blocks.
+
+ - **Bottleneck:**
+ The deepest block (1024 feature channels) operates on the lowest-resolution feature map, 16×16 for a 256×256 input.
+
+ - **Decoder Path (Expanding Path):**
+ The decoder uses **Upsample** layers (bilinear, scale factor 2) to increase the spatial dimensions and **DoubleConv** blocks to refine the result. Skip connections concatenate the matching encoder features with each upsampled feature map.
+
+ - **Final Convolution:**
+ A **1x1 convolution** maps the 64 decoder channels to the output image channels.
+
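The encoder and decoder described above rely on a `DoubleConv` block that this README does not define. A minimal sketch follows, assuming the common UNet form (two 3×3 convolutions with padding 1, each followed by batch normalization and ReLU); the author's actual block may differ. The helper `bottleneck_and_first_decoder_step` is hypothetical and only traces shapes through the bottleneck and the first decoder step, to show where the `1024 + 512` input channels come from.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Hypothetical stand-in: (Conv3x3 -> BatchNorm -> ReLU) x 2, size-preserving."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


def bottleneck_and_first_decoder_step(x4):
    """Trace shapes from the last encoder feature map through one decoder step."""
    maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
    upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
    down_conv5 = DoubleConv(512, 1024).eval()
    up_conv1 = DoubleConv(1024 + 512, 512).eval()

    with torch.no_grad():
        x5 = down_conv5(maxpool(x4))         # halve H and W, 512 -> 1024 channels
        up = upsample(x5)                    # back to the spatial size of x4
        merged = torch.cat([x4, up], dim=1)  # skip connection: 512 + 1024 channels
        out = up_conv1(merged)               # reduce to 512 channels
    return x5.shape, merged.shape, out.shape


# A small 8x8 feature map keeps the example fast; with a 256x256 input,
# the corresponding encoder feature map would be 32x32.
shapes = bottleneck_and_first_decoder_step(torch.randn(1, 512, 8, 8))
print(shapes)
```

Because the skip connection concatenates along the channel axis, the encoder and decoder feature maps must have identical height and width, which is why the sketch uses size-preserving convolutions.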
+ ### Discriminator: **PatchGANDiscriminator**
+
+ The discriminator uses a **PatchGAN** architecture, which classifies patches of the image as real or fake rather than the image as a whole. Its input is the **input image and output image pair** concatenated along the channel axis (3 channels for the input image + 3 channels for the real or generated output). It reduces the spatial dimensions with strided **Conv2d** layers followed by **LeakyReLU** activations, and applies **InstanceNorm2d** after every convolution except the first and last. The final output is a one-channel map of raw scores (logits), one per patch, indicating whether that patch looks real or fake; the code applies no sigmoid.
+
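To make the patch-level output concrete, the arithmetic below applies the standard convolution size formula, floor((n + 2p - k) / s) + 1, to the five `Conv2d` layers of `PatchGANDiscriminator`. It is only a shape calculation for the stated 256×256 input size; the helper name `conv_out` is illustrative.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of one spatial dimension after a Conv2d layer."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride, padding) for each Conv2d in PatchGANDiscriminator
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

size = 256
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [256, 128, 64, 32, 31, 30]
```

Each 256×256 image pair therefore yields a 30×30 grid of scores, one per overlapping patch.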
+ ---
+
+ ### Generator (UNet) and Discriminator (PatchGAN) Code:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ # NOTE: DoubleConv is referenced below but is not defined in this README.
+ class UNet(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, in_channels, out_channels):
+         super(UNet, self).__init__()
+
+         # Contracting Path (Encoder)
+         self.down_conv1 = DoubleConv(in_channels, 64)
+         self.down_conv2 = DoubleConv(64, 128)
+         self.down_conv3 = DoubleConv(128, 256)
+         self.down_conv4 = DoubleConv(256, 512)
+         self.down_conv5 = DoubleConv(512, 1024)
+
+         # Downsampling
+         self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+         # Upsampling layers using nn.Upsample
+         self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+
+         # Decoder (Expanding Path); input channels = upsampled + skip features
+         self.up_conv1 = DoubleConv(1024 + 512, 512)
+         self.up_conv2 = DoubleConv(512 + 256, 256)
+         self.up_conv3 = DoubleConv(256 + 128, 128)
+         self.up_conv4 = DoubleConv(128 + 64, 64)
+
+         # Final 1x1 convolution to get desired number of output channels
+         self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         # Encoder: keep each stage's output for the skip connections
+         x1 = self.down_conv1(x)
+         x2 = self.down_conv2(self.maxpool(x1))
+         x3 = self.down_conv3(self.maxpool(x2))
+         x4 = self.down_conv4(self.maxpool(x3))
+         x5 = self.down_conv5(self.maxpool(x4))
+
+         # Decoder: upsample, concatenate the matching encoder features, refine
+         x = self.upsample(x5)
+         x = torch.cat([x4, x], dim=1)
+         x = self.up_conv1(x)
+
+         x = self.upsample(x)
+         x = torch.cat([x3, x], dim=1)
+         x = self.up_conv2(x)
+
+         x = self.upsample(x)
+         x = torch.cat([x2, x], dim=1)
+         x = self.up_conv3(x)
+
+         x = self.upsample(x)
+         x = torch.cat([x1, x], dim=1)
+         x = self.up_conv4(x)
+
+         return self.final_conv(x)
+
+
+ class PatchGANDiscriminator(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, in_channels=6):
+         super().__init__()
+
+         # Three stride-2 convolutions, then two stride-1 convolutions;
+         # the last one produces a single-channel map of per-patch scores.
+         self.layers = nn.Sequential(
+             nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+             nn.InstanceNorm2d(128),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+             nn.InstanceNorm2d(256),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
+             nn.InstanceNorm2d(512),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
+         )
+
+     def forward(self, x):
+         return self.layers(x)
+ ```
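As a usage sketch, the snippet below shows how a source image and a generated image might be paired for the discriminator: the two are concatenated along the channel axis to form the 6-channel input. The layer stack is copied from `PatchGANDiscriminator` in the listing above; the generator output is replaced by a random tensor so the example stays self-contained. The `BCEWithLogitsLoss` call is an assumption, since the README does not state which training loss was used.

```python
import torch
import torch.nn as nn

# Same layer stack as PatchGANDiscriminator, without the Hub mixin.
discriminator = nn.Sequential(
    nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
    nn.InstanceNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
    nn.InstanceNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
    nn.InstanceNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
)

source = torch.randn(2, 3, 256, 256)  # input images
fake = torch.randn(2, 3, 256, 256)    # placeholder for generator(source)

pair = torch.cat([source, fake], dim=1)  # (2, 6, 256, 256)
with torch.no_grad():
    scores = discriminator(pair)         # one raw score per patch

# Assumed loss: BCE on logits, with an all-zeros target for generated pairs.
loss = nn.BCEWithLogitsLoss()(scores, torch.zeros_like(scores))
print(pair.shape, scores.shape, loss.item())
```

Since the last layer is a plain `Conv2d`, the scores are unbounded; a loss that expects probabilities would need a sigmoid applied first.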