Added train.py
- .gitignore +4 -1
- modules/model.py +68 -0
- main.py → predict.py +3 -71
- train.py +146 -0
.gitignore CHANGED
@@ -1,2 +1,5 @@
 model_parameters.pt
-.DS_Store
+.DS_Store
+modules/__pycache__
+/celeba
+/.vscode
modules/model.py ADDED
@@ -0,0 +1,68 @@
import torch.nn as nn
import torch.nn.functional as F

def conv_block(in_channels, out_channels, pool=False):
    layers = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    ]
    if pool:
        layers.append(
            nn.MaxPool2d(4)
        )
    return nn.Sequential(*layers)

class resnetModel_128(nn.Module):
    def __init__(self):
        super().__init__()
        self.model_name = 'resnetModel_128'

        self.conv_1 = conv_block(1, 64)
        self.res_1 = nn.Sequential(
            conv_block(64, 64),
            conv_block(64, 64)
        )
        self.conv_2 = conv_block(64, 256, pool=True)
        self.res_2 = nn.Sequential(
            conv_block(256, 256),
            conv_block(256, 256)
        )
        self.conv_3 = conv_block(256, 512, pool=True)
        self.res_3 = nn.Sequential(
            conv_block(512, 512),
            conv_block(512, 512)
        )
        self.conv_4 = conv_block(512, 1024, pool=True)
        self.res_4 = nn.Sequential(
            conv_block(1024, 1024),
            conv_block(1024, 1024)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2*2*1024, 2048),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(1024, 2)
        )

    def forward(self, x):
        x = self.conv_1(x)
        x = self.res_1(x) + x
        x = self.conv_2(x)
        x = self.res_2(x) + x
        x = self.conv_3(x)
        x = self.res_3(x) + x
        x = self.conv_4(x)
        x = self.res_4(x) + x
        x = self.classifier(x)
        x = F.softmax(x, dim=1)
        return x
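The new module is self-contained, so it can be smoke-tested on its own. A minimal sketch (not part of the commit), assuming a 1-channel 128×128 input, which is what makes the flattened feature map 2*2*1024 after the three MaxPool2d(4) stages:

import torch
from modules.model import resnetModel_128

net = resnetModel_128()
net.eval()
with torch.no_grad():
    x = torch.randn(1, 1, 128, 128)  # (batch, channels, height, width)
    probs = net(x)                   # class probabilities from the final softmax
print(probs.shape)                   # torch.Size([1, 2])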
main.py → predict.py RENAMED
@@ -1,12 +1,11 @@
 import os
 import gdown
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
 import time
+import modules.model as model
 
 # Download model if not available
 modelsave_name = 'model_parameters.pt'
@@ -28,76 +27,9 @@ else:
 torch.set_default_device(device)
 
 print(f'\nDevice: {device_name}')
-
-# Define model
-def conv_block(in_channels, out_channels, pool=False):
-    layers = [
-        nn.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size=3,
-            padding=1
-        ),
-        nn.BatchNorm2d(out_channels),
-        nn.ReLU()
-    ]
-    if pool:
-        layers.append(
-            nn.MaxPool2d(4)
-        )
-    return nn.Sequential(*layers)
-
-class resnetModel_128(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.model_name = 'resnetModel_128'
-
-        self.conv_1 = conv_block(1, 64)
-        self.res_1 = nn.Sequential(
-            conv_block(64, 64),
-            conv_block(64, 64)
-        )
-        self.conv_2 = conv_block(64, 256, pool=True)
-        self.res_2 = nn.Sequential(
-            conv_block(256, 256),
-            conv_block(256, 256)
-        )
-        self.conv_3 = conv_block(256, 512, pool=True)
-        self.res_3 = nn.Sequential(
-            conv_block(512, 512),
-            conv_block(512, 512)
-        )
-        self.conv_4 = conv_block(512, 1024, pool=True)
-        self.res_4 = nn.Sequential(
-            conv_block(1024, 1024),
-            conv_block(1024, 1024)
-        )
-        self.classifier = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(2*2*1024, 2048),
-            nn.Dropout(0.5),
-            nn.ReLU(),
-            nn.Linear(2048, 1024),
-            nn.Dropout(0.5),
-            nn.ReLU(),
-            nn.Linear(1024, 2)
-        )
-
-    def forward(self, x):
-        x = self.conv_1(x)
-        x = self.res_1(x) + x
-        x = self.conv_2(x)
-        x = self.res_2(x) + x
-        x = self.conv_3(x)
-        x = self.res_3(x) + x
-        x = self.conv_4(x)
-        x = self.res_4(x) + x
-        x = self.classifier(x)
-        x = F.softmax(x, dim=1)
-        return x
 
 # Make model and load parameters
-resnet = resnetModel_128()
+resnet = model.resnetModel_128()
 resnet.load_state_dict(torch.load(modelsave_name, map_location=device))
 resnet.eval()
 
@@ -118,7 +50,7 @@ my_dataset = datasets.ImageFolder(
 
 my_dataset_loader = DataLoader(
     my_dataset,
-    batch_size=len(my_dataset),
+    batch_size=min(len(my_dataset), 10),
    generator=torch.Generator(device=device)
 )
 
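Apart from importing the shared model module, the only functional change to the prediction script is the capped batch size. A small illustration (hypothetical data, not part of the commit) of what min(len(my_dataset), 10) does compared with the old batch_size=len(my_dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

dummy = TensorDataset(torch.randn(37, 1, 128, 128))         # stand-in for my_dataset
loader = DataLoader(dummy, batch_size=min(len(dummy), 10))  # at most 10 images per batch
print(len(loader))                                          # 4 batches: 10 + 10 + 10 + 7

Folders with fewer than 10 images still fit in a single batch, while large folders are no longer loaded all at once.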
train.py ADDED
@@ -0,0 +1,146 @@
import os
import gdown
import zipfile
import shutil
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import modules.model as model

# Download model if not available
# if os.path.exists('celeba/') == False:
#     url = 'https://drive.google.com/file/d/1_oL160xwrOiF5x56GddAUtOuXe6bIwpL/view?usp=sharing'
#     output = 'download.zip'
#     gdown.download(url, output, fuzzy=True)

#     with zipfile.ZipFile(output, 'r') as zip_ref:
#         zip_ref.extractall()

#     os.remove(output)
#     shutil.rmtree('__MACOSX')

# Set device
if torch.backends.mps.is_available():
    device = torch.device('mps')
    device_name = 'Apple Silicon GPU'
elif torch.cuda.is_available():
    device = torch.device('cuda')
    device_name = 'CUDA'
else:
    device = torch.device('cpu')
    device_name = 'CPU'

torch.set_default_device(device)

print(f'\nDevice: {device_name}')

# Define dataset, dataloader and transform
imsize = int(128/0.8)
batch_size = 10

fivecrop_transform = transforms.Compose([
    transforms.Resize([imsize, imsize]),
    transforms.Grayscale(1),
    transforms.FiveCrop(int(imsize*0.8)),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    transforms.Normalize(0, 1)
])

train_dataset = datasets.CelebA(
    root='',
    split='all',
    target_type='attr',
    transform=fivecrop_transform,
    download=True,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device)
)

# Male index
factor = 20

# Define model, optimiser and scheduler
torch.manual_seed(2687)
resnet = model.resnetModel_128()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    resnet.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.001
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=1,
    gamma=0.1
)

def mins_to_hours(mins):
    hours = int(mins/60)
    rem_mins = mins % 60
    return hours, rem_mins

epochs = 2
train_losses = []
train_accuracy = []
for i in range(epochs):
    epoch_time = 0

    for j, (X_train, y_train) in enumerate(train_loader):
        batch_start = time.time()

        X_train = X_train.to(device)
        y_train = y_train[:, factor]

        bs, ncrops, c, h, w = X_train.size()
        y_pred_crops = resnet.forward(X_train.view(-1, c, h, w))
        y_pred = y_pred_crops.view(bs, ncrops, -1).mean(1)

        loss = criterion(y_pred, y_train)

        predicted = torch.max(y_pred.data, 1)[1]
        train_batch_accuracy = (predicted == y_train).sum()/len(X_train)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        train_accuracy.append(train_batch_accuracy.item())

        batch_end = time.time()

        batch_time = batch_end - batch_start
        epoch_time += batch_time
        avg_batch_time = epoch_time/(j+1)
        batches_remaining = len(train_loader)-(j+1)
        epoch_mins_remaining = round(batches_remaining*avg_batch_time/60)
        epoch_time_remaining = mins_to_hours(epoch_mins_remaining)

        full_epoch = avg_batch_time*len(train_loader)
        epochs_remaining = epochs-(i+1)
        rem_epoch_mins_remaining = epoch_mins_remaining+round(full_epoch*epochs_remaining/60)
        rem_epoch_time_remaining = mins_to_hours(rem_epoch_mins_remaining)

        if (j+1) % 10 == 0:
            print(f'\nEpoch: {i+1}/{epochs} | Train Batch: {j+1}/{len(train_loader)}')
            print(f'Current epoch: {epoch_time_remaining[0]} hours {epoch_time_remaining[1]} minutes')
            print(f'Remaining epochs: {rem_epoch_time_remaining[0]} hours {rem_epoch_time_remaining[1]} minutes')
            print(f'Train Loss: {loss}')
            print(f'Train Accuracy: {train_batch_accuracy}')

    scheduler.step()

    trained_model_name = resnet.model_name + '_epoch_' + str(i+1) + '.pt'
    torch.save(
        resnet.state_dict(),
        trained_model_name
    )
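train.py writes a separate checkpoint at the end of each epoch, named from resnet.model_name plus the epoch number. A hedged sketch (checkpoint file name assumed from that pattern) of reloading one of these checkpoints for inference, mirroring what predict.py does with model_parameters.pt:

import torch
import modules.model as model

device = torch.device('cpu')
resnet = model.resnetModel_128()
state_dict = torch.load('resnetModel_128_epoch_2.pt', map_location=device)  # assumed checkpoint name
resnet.load_state_dict(state_dict)
resnet.eval()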