import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class ModelColorization(nn.Module, PyTorchModelHubMixin):
    """Convolutional encoder-decoder mapping a 1-channel image to 3 channels.

    Presumably used to colorize grayscale images (inferred from the class
    name -- confirm against the training code). The encoder halves the
    spatial resolution three times and squeezes the result through a fully
    connected bottleneck of ``latent_dim`` units. The decoder mirrors this
    with three stride-2 transposed convolutions and ends in ``Sigmoid``, so
    every output value lies in ``(0, 1)``.

    Shapes:
        input:  ``(N, 1, image_size, image_size)``
        output: ``(N, 3, image_size, image_size)``

    Args:
        image_size: Side length of the square input image. Must be a positive
            multiple of 8 because of the three 2x down/up-sampling stages.
            Defaults to 128, the size implied by the original hard-coded
            ``64 * 16 * 16`` bottleneck.
        latent_dim: Width of the fully connected bottleneck. Defaults to 3000.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, image_size: int = 128, latent_dim: int = 3000) -> None:
        super().__init__()

        # Three MaxPool2d(2) stages on the way down, three stride-2
        # ConvTranspose2d stages on the way up: total scale factor 8.
        scale = 8
        if image_size <= 0 or image_size % scale != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {scale}, "
                f"got {image_size}"
            )
        self.image_size = image_size
        self.latent_dim = latent_dim

        # Spatial side length and flattened feature count at the bottleneck
        # (64 channels after the last encoder conv).
        side = image_size // scale
        flat_features = 64 * side * side

        # NOTE: layer order is kept identical to the original model so the
        # Sequential indices (and therefore state_dict keys) are unchanged and
        # previously saved checkpoints still load. MaxPool before ReLU gives
        # the same result as ReLU before MaxPool (both are monotone) while
        # applying ReLU to 4x fewer elements.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Flatten(),
            nn.Linear(flat_features, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, (64, side, side)),
            nn.ConvTranspose2d(64, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 256, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 3, kernel_size=2, stride=2),
            # Squash outputs into (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` into the bottleneck and decode it back to 3 channels.

        Args:
            x: Tensor of shape ``(N, 1, image_size, image_size)``.

        Returns:
            Tensor of shape ``(N, 3, image_size, image_size)`` with values
            in ``(0, 1)``.
        """
        return self.decoder(self.encoder(x))