Aaron Vattay committed on
Commit a6a4c31 · 1 Parent(s): c3ce88e

Model release


git push origin main

upscaling

Files changed (4)
  1. .gitattributes +1 -0
  2. AIupscale_run.py +58 -0
  3. AIupscale_train.py +113 -0
  4. upscaling.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+upscaling.pth filter=lfs diff=lfs merge=lfs -text
AIupscale_run.py ADDED
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import coremltools as ct
+
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+import os
+from PIL import Image
+import torchvision.transforms.functional as TF
+
+device = torch.device("mps")
+class UPSC(nn.Module):
+    def __init__(self):
+        super(UPSC, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            # This convolution outputs channels that are scale_factor^2 * number_of_channels.
+            nn.Conv2d(in_channels=32, out_channels=3 * 3 * 3, kernel_size=3, padding=1),
+            # PixelShuffle rearranges channels into spatial dimensions.
+            nn.PixelShuffle(3)
+        )
+    def forward(self, x):
+        return self.model(x)
+
+model = UPSC().to(device)
+model.load_state_dict(torch.load("upscaling.pth", weights_only=True))
+model.eval()
+
+img = Image.open("test.png").convert("RGB")
+
+# Resize it to match what the model expects (e.g. 256x256)
+transform = transforms.Compose([
+    transforms.Resize((256, 256)),  # match training input size
+    transforms.ToTensor()
+])
+
+lr_tensor = transform(img).unsqueeze(0).to(device)
+
+with torch.no_grad():
+    sr_tensor = model(lr_tensor)
+    traced_model = torch.jit.trace(model, lr_tensor)
+
+
+# Remove batch dimension and convert to PIL
+sr_image = TF.to_pil_image(sr_tensor.squeeze(0).clamp(0, 1))
+sr_image.save("upscaled_output_5.jpg")
+
+mlmodel = ct.convert(
+    traced_model,
+    inputs=[ct.ImageType(name="input", shape=lr_tensor.shape)],
+    compute_units=ct.ComputeUnit.ALL  # Use ANE if available
+)
+
+mlmodel.save("upscaling.mlmodel")
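
For orientation (my sketch, not part of the commit): the network is an ESPCN-style sub-pixel upscaler, so its last convolution must emit scale_factor^2 * 3 = 27 channels, which PixelShuffle(3) folds back into space, turning a 256x256 input into a 768x768 output. A quick shape check of just that head:

import torch
import torch.nn as nn

# Upscaling head only: 27 channels in -> 3 channels at 3x the resolution.
head = nn.Sequential(
    nn.Conv2d(32, 3 * 3 * 3, kernel_size=3, padding=1),
    nn.PixelShuffle(3),
)
features = torch.randn(1, 32, 256, 256)  # stand-in for the feature map from the earlier convs
print(head(features).shape)              # torch.Size([1, 3, 768, 768])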
AIupscale_train.py ADDED
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import os
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from torch.utils.data import DataLoader
+
+class UPSC(nn.Module):
+    def __init__(self):
+        super(UPSC, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            # This convolution outputs channels that are scale_factor^2 * number_of_channels.
+            nn.Conv2d(in_channels=32, out_channels=3 * 3 * 3, kernel_size=3, padding=1),
+            # PixelShuffle rearranges channels into spatial dimensions.
+            nn.PixelShuffle(3)
+        )
+    def forward(self, x):
+        return self.model(x)
+
+
+
+class PairedSuperResolutionDataset(Dataset):
+    def __init__(self, lr_dir, hr_dir, lr_size=(64, 64), hr_size=(256, 256)):
+        self.lr_dir = lr_dir
+        self.hr_dir = hr_dir
+        self.lr_files = sorted(os.listdir(lr_dir))
+        self.hr_files = sorted(os.listdir(hr_dir))
+
+        self.transform_lr = transforms.Compose([
+            transforms.Resize(lr_size),
+            transforms.ToTensor()
+        ])
+
+        self.transform_hr = transforms.Compose([
+            transforms.Resize(hr_size),
+            transforms.ToTensor()
+        ])
+
+    def __len__(self):
+        return len(self.lr_files)
+
+    def __getitem__(self, idx):
+        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
+        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])
+
+        lr_img = Image.open(lr_path).convert("RGB")
+        hr_img = Image.open(hr_path).convert("RGB")
+
+        lr_tensor = self.transform_lr(lr_img)
+        hr_tensor = self.transform_hr(hr_img)
+
+        return lr_tensor, hr_tensor
+
+
+lr_dir = '/Users/aaronvattay/Documents/DF2K_train_LR_bicubic/X3'
+hr_dir = '/Users/aaronvattay/Documents/DF2K_train_HR'
+batch_size = 16
+num_epochs = 10
+learning_rate = 1e-4
+
+
+# Create dataset and dataloader
+dataset = PairedSuperResolutionDataset(
+    lr_dir=lr_dir,
+    hr_dir=hr_dir,
+    lr_size=(256, 256),
+    hr_size=(768, 768)
+)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+# Device configuration
+device = torch.device("mps")
+
+# Initialize model, loss, and optimizer
+model = UPSC().to(device)
+criterion = nn.MSELoss()
+optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+# Load the model state if available
+if os.path.exists("upscaling.pth"):
+    model.load_state_dict(torch.load("upscaling.pth", map_location=device, weights_only=True))
+# Set the model to training mode
+model.train()
+if __name__ == "__main__":
+    for epoch in range(num_epochs):
+        epoch_loss = 0.0
+        for lr_imgs, hr_imgs in dataloader:
+            # Move images to device
+            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
+
+            # Forward pass: Model produces the upscaled images
+            outputs = model(lr_imgs)
+            loss = criterion(outputs, hr_imgs)
+
+            # Backpropagation and optimization
+            optimizer.zero_grad()  # Clear gradients for this iteration
+            loss.backward()  # Backpropagate the loss
+            optimizer.step()  # Update weights
+
+            epoch_loss += loss.item()
+
+        avg_loss = epoch_loss / len(dataloader)
+        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")
+
+    # Optionally, save your trained model for later inference
+    torch.save(model.state_dict(), "upscaling.pth")
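
The loop above minimizes plain MSE between the 3x output and the HR target. A common companion metric is PSNR, which follows directly from that loss; a hedged sketch (my addition, not in the commit), valid for tensors scaled to [0, 1]:

import torch

def psnr(mse: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(max_val^2 / MSE), in decibels
    return 10.0 * torch.log10(max_val ** 2 / mse)

print(psnr(torch.tensor(1e-3)))  # 10 * log10(1 / 0.001) = 30 dB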
upscaling.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8b27d159a451b1fac7efc1d1e3b2828dfafeea2695d344249df6a4cbf312f1b
+size 127260
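
Since upscaling.pth is tracked by Git LFS, a clone made without LFS installed checks out only the three-line pointer above rather than the actual checkpoint, and torch.load will then fail on it. A small defensive check before loading (my suggestion, not part of the commit):

import os

# A Git LFS pointer is tiny text; the real checkpoint is about 127 KB (size 127260 above).
if os.path.getsize("upscaling.pth") < 1024:
    raise RuntimeError("upscaling.pth looks like a Git LFS pointer; run 'git lfs pull' first.")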