Upload ComplexUNet for CIFAR-10 inpainting
Browse files- README.md +94 -0
- inpainting_transformer_weights.pth +3 -0
- model.py +53 -0
README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
---
|
| 3 |
+
license: mit
|
| 4 |
+
language: en
|
| 5 |
+
library_name: pytorch
|
| 6 |
+
tags:
|
| 7 |
+
- image-inpainting
|
| 8 |
+
- computer-vision
|
| 9 |
+
- pytorch
|
| 10 |
+
- unet
|
| 11 |
+
- cifar-10
|
| 12 |
+
datasets:
|
| 13 |
+
- cifar10
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# U-Net for Image Inpainting on CIFAR-10
|
| 17 |
+
|
| 18 |
+
This repository contains a PyTorch implementation of a deep U-Net with Residual Blocks, trained to perform image inpainting on the CIFAR-10 dataset. The model takes an image with a masked (blacked-out) region and reconstructs the missing part.
|
| 19 |
+
|
| 20 |
+
## Model Description
|
| 21 |
+
|
| 22 |
+
The model is a `ComplexUNet` architecture, a variant of the standard U-Net. It features:
|
| 23 |
+
- **Deeper Architecture**: 4 downsampling and 4 upsampling stages.
|
| 24 |
+
- **Residual Blocks**: Each stage uses residual blocks instead of simple convolutional layers.
|
| 25 |
+
- **Increased Width**: The model was trained with `base_channels=96`.
|
| 26 |
+
- **Total Parameters**: 73,148,259
|
| 27 |
+
|
| 28 |
+
## How to Use
|
| 29 |
+
|
| 30 |
+
First, install the required libraries:
|
| 31 |
+
```bash
|
| 32 |
+
pip install torch torchvision numpy Pillow
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Then, you can load the model and perform inpainting on an image tensor.
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
import torch
|
| 39 |
+
from torchvision import transforms as T
|
| 40 |
+
from PIL import Image
|
| 41 |
+
from model import ComplexUNet # Import the class from model.py
|
| 42 |
+
|
| 43 |
+
# --- Setup ---
|
| 44 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
+
# Download the .pth file from the 'Files and versions' tab of this repo
|
| 46 |
+
MODEL_PATH = "inpainting_model_larger.pth"
|
| 47 |
+
|
| 48 |
+
# --- Load Model ---
|
| 49 |
+
model = ComplexUNet(base_channels=96)
|
| 50 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
|
| 51 |
+
model.to(DEVICE)
|
| 52 |
+
model.eval()
|
| 53 |
+
|
| 54 |
+
# --- Load and Preprocess Image ---
|
| 55 |
+
# image = Image.open("your_image.png").convert("RGB")
|
| 56 |
+
# For demonstration, let's create a dummy tensor
|
| 57 |
+
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])
|
| 58 |
+
# image_tensor = transform(image)
|
| 59 |
+
image_tensor = torch.rand(3, 32, 32)
|
| 60 |
+
|
| 61 |
+
# --- Create a Mask ---
|
| 62 |
+
masked_tensor = image_tensor.clone()
|
| 63 |
+
masked_tensor[:, 8:24, 8:24] = 0 # Example mask in the center
|
| 64 |
+
|
| 65 |
+
# --- Perform Inpainting ---
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
input_tensor = masked_tensor.unsqueeze(0).to(DEVICE)
|
| 68 |
+
reconstructed_tensor = model(input_tensor).squeeze(0).cpu()
|
| 69 |
+
|
| 70 |
+
# 'reconstructed_tensor' now holds the inpainted image.
|
| 71 |
+
from torchvision.transforms.functional import to_pil_image
|
| 72 |
+
reconstructed_image = to_pil_image(reconstructed_tensor)
|
| 73 |
+
reconstructed_image.save("reconstructed_image.png")
|
| 74 |
+
print("Saved reconstructed_image.png")
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Training Data
|
| 78 |
+
|
| 79 |
+
The model was trained on the **CIFAR-10** dataset.
|
| 80 |
+
- **Preprocessing**: Images were used at their original **32x32 pixels** resolution.
|
| 81 |
+
- **Augmentation**: For each training image, a random rectangular mask was applied.
|
| 82 |
+
|
| 83 |
+
## Training Procedure
|
| 84 |
+
|
| 85 |
+
- **Framework**: PyTorch
|
| 86 |
+
- **Optimizer**: Adam
|
| 87 |
+
- **Learning Rate**: 0.001
|
| 88 |
+
- **Epochs**: 50
|
| 89 |
+
- **Batch Size**: 128
|
| 90 |
+
- **Loss Function**: Mean Squared Error (MSE)
|
| 91 |
+
|
| 92 |
+
## Evaluation
|
| 93 |
+
|
| 94 |
+
Evaluation metrics were not saved by the training script. To get PSNR and SSIM, please run the `evaluate_model` function from the training script.
|
inpainting_transformer_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e8f1afbf67c0f3de4261cb41d4083d238bbfd51d54f9ed9e67b7d3faa3183d2d
|
| 3 |
+
size 7928446
|
model.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class ResidualBlock(nn.Module):
|
| 6 |
+
def __init__(self, in_channels, out_channels):
|
| 7 |
+
super(ResidualBlock, self).__init__()
|
| 8 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
|
| 9 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
| 10 |
+
self.relu = nn.ReLU(inplace=True)
|
| 11 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
|
| 12 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
| 13 |
+
if in_channels != out_channels:
|
| 14 |
+
self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels))
|
| 15 |
+
else:
|
| 16 |
+
self.shortcut = nn.Identity()
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
residual = self.shortcut(x)
|
| 19 |
+
out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
|
| 20 |
+
out = self.conv2(out); out = self.bn2(out)
|
| 21 |
+
out += residual
|
| 22 |
+
out = self.relu(out)
|
| 23 |
+
return out
|
| 24 |
+
|
| 25 |
+
class ComplexUNet(nn.Module):
|
| 26 |
+
def __init__(self, base_channels=96): # Default to the trained architecture
|
| 27 |
+
super(ComplexUNet, self).__init__()
|
| 28 |
+
c = base_channels
|
| 29 |
+
self.pool = nn.MaxPool2d(2, 2)
|
| 30 |
+
self.enc1 = ResidualBlock(3, c)
|
| 31 |
+
self.enc2 = ResidualBlock(c, c*2)
|
| 32 |
+
self.enc3 = ResidualBlock(c*2, c*4)
|
| 33 |
+
self.enc4 = ResidualBlock(c*4, c*8)
|
| 34 |
+
self.bottleneck = ResidualBlock(c*8, c*16)
|
| 35 |
+
self.upconv1 = nn.ConvTranspose2d(c*16, c*8, kernel_size=2, stride=2)
|
| 36 |
+
self.upconv2 = nn.ConvTranspose2d(c*8, c*4, kernel_size=2, stride=2)
|
| 37 |
+
self.upconv3 = nn.ConvTranspose2d(c*4, c*2, kernel_size=2, stride=2)
|
| 38 |
+
self.upconv4 = nn.ConvTranspose2d(c*2, c, kernel_size=2, stride=2)
|
| 39 |
+
self.dec_conv1 = ResidualBlock(c*16, c*8)
|
| 40 |
+
self.dec_conv2 = ResidualBlock(c*8, c*4)
|
| 41 |
+
self.dec_conv3 = ResidualBlock(c*4, c*2)
|
| 42 |
+
self.dec_conv4 = ResidualBlock(c*2, c)
|
| 43 |
+
self.final_conv = nn.Conv2d(c, 3, kernel_size=1)
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
e1 = self.enc1(x); p1 = self.pool(e1); e2 = self.enc2(p1); p2 = self.pool(e2)
|
| 46 |
+
e3 = self.enc3(p2); p3 = self.pool(e3); e4 = self.enc4(p3); p4 = self.pool(e4)
|
| 47 |
+
b = self.bottleneck(p4)
|
| 48 |
+
d1 = self.upconv1(b); d1 = torch.cat([d1, e4], dim=1); d1 = self.dec_conv1(d1)
|
| 49 |
+
d2 = self.upconv2(d1); d2 = torch.cat([d2, e3], dim=1); d2 = self.dec_conv2(d2)
|
| 50 |
+
d3 = self.upconv3(d2); d3 = torch.cat([d3, e2], dim=1); d3 = self.dec_conv3(d3)
|
| 51 |
+
d4 = self.upconv4(d3); d4 = torch.cat([d4, e1], dim=1); d4 = self.dec_conv4(d4)
|
| 52 |
+
out = self.final_conv(d4)
|
| 53 |
+
return torch.sigmoid(out)
|