Harry2687 commited on
Commit
35b51f6
·
1 Parent(s): af5e3f0

Added train.py

Browse files
Files changed (4) hide show
  1. .gitignore +4 -1
  2. modules/model.py +68 -0
  3. main.py → predict.py +3 -71
  4. train.py +146 -0
.gitignore CHANGED
@@ -1,2 +1,5 @@
1
  model_parameters.pt
2
- .DS_Store
 
 
 
 
1
  model_parameters.pt
2
+ .DS_Store
3
+ modules/__pycache__
4
+ /celeba
5
+ /.vscode
modules/model.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+ def conv_block(in_channels, out_channels, pool=False):
5
+ layers = [
6
+ nn.Conv2d(
7
+ in_channels,
8
+ out_channels,
9
+ kernel_size=3,
10
+ padding=1
11
+ ),
12
+ nn.BatchNorm2d(out_channels),
13
+ nn.ReLU()
14
+ ]
15
+ if pool:
16
+ layers.append(
17
+ nn.MaxPool2d(4)
18
+ )
19
+ return nn.Sequential(*layers)
20
+
21
+ class resnetModel_128(nn.Module):
22
+ def __init__(self):
23
+ super().__init__()
24
+ self.model_name = 'resnetModel_128'
25
+
26
+ self.conv_1 = conv_block(1, 64)
27
+ self.res_1 = nn.Sequential(
28
+ conv_block(64, 64),
29
+ conv_block(64, 64)
30
+ )
31
+ self.conv_2 = conv_block(64, 256, pool=True)
32
+ self.res_2 = nn.Sequential(
33
+ conv_block(256, 256),
34
+ conv_block(256, 256)
35
+ )
36
+ self.conv_3 = conv_block(256, 512, pool=True)
37
+ self.res_3 = nn.Sequential(
38
+ conv_block(512, 512),
39
+ conv_block(512, 512)
40
+ )
41
+ self.conv_4 = conv_block(512, 1024, pool=True)
42
+ self.res_4 = nn.Sequential(
43
+ conv_block(1024, 1024),
44
+ conv_block(1024, 1024)
45
+ )
46
+ self.classifier = nn.Sequential(
47
+ nn.Flatten(),
48
+ nn.Linear(2*2*1024, 2048),
49
+ nn.Dropout(0.5),
50
+ nn.ReLU(),
51
+ nn.Linear(2048, 1024),
52
+ nn.Dropout(0.5),
53
+ nn.ReLU(),
54
+ nn.Linear(1024, 2)
55
+ )
56
+
57
+ def forward(self, x):
58
+ x = self.conv_1(x)
59
+ x = self.res_1(x) + x
60
+ x = self.conv_2(x)
61
+ x = self.res_2(x) + x
62
+ x = self.conv_3(x)
63
+ x = self.res_3(x) + x
64
+ x = self.conv_4(x)
65
+ x = self.res_4(x) + x
66
+ x = self.classifier(x)
67
+ x = F.softmax(x, dim=1)
68
+ return x
main.py → predict.py RENAMED
@@ -1,12 +1,11 @@
1
  import os
2
  import gdown
3
  import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
  import torchvision.datasets as datasets
7
  import torchvision.transforms as transforms
8
  from torch.utils.data import DataLoader
9
  import time
 
10
 
11
  # Download model if not available
12
  modelsave_name = 'model_parameters.pt'
@@ -28,76 +27,9 @@ else:
28
  torch.set_default_device(device)
29
 
30
  print(f'\nDevice: {device_name}')
31
-
32
- # Define model
33
- def conv_block(in_channels, out_channels, pool=False):
34
- layers = [
35
- nn.Conv2d(
36
- in_channels,
37
- out_channels,
38
- kernel_size=3,
39
- padding=1
40
- ),
41
- nn.BatchNorm2d(out_channels),
42
- nn.ReLU()
43
- ]
44
- if pool:
45
- layers.append(
46
- nn.MaxPool2d(4)
47
- )
48
- return nn.Sequential(*layers)
49
-
50
- class resnetModel_128(nn.Module):
51
- def __init__(self):
52
- super().__init__()
53
- self.model_name = 'resnetModel_128'
54
-
55
- self.conv_1 = conv_block(1, 64)
56
- self.res_1 = nn.Sequential(
57
- conv_block(64, 64),
58
- conv_block(64, 64)
59
- )
60
- self.conv_2 = conv_block(64, 256, pool=True)
61
- self.res_2 = nn.Sequential(
62
- conv_block(256, 256),
63
- conv_block(256, 256)
64
- )
65
- self.conv_3 = conv_block(256, 512, pool=True)
66
- self.res_3 = nn.Sequential(
67
- conv_block(512, 512),
68
- conv_block(512, 512)
69
- )
70
- self.conv_4 = conv_block(512, 1024, pool=True)
71
- self.res_4 = nn.Sequential(
72
- conv_block(1024, 1024),
73
- conv_block(1024, 1024)
74
- )
75
- self.classifier = nn.Sequential(
76
- nn.Flatten(),
77
- nn.Linear(2*2*1024, 2048),
78
- nn.Dropout(0.5),
79
- nn.ReLU(),
80
- nn.Linear(2048, 1024),
81
- nn.Dropout(0.5),
82
- nn.ReLU(),
83
- nn.Linear(1024, 2)
84
- )
85
-
86
- def forward(self, x):
87
- x = self.conv_1(x)
88
- x = self.res_1(x) + x
89
- x = self.conv_2(x)
90
- x = self.res_2(x) + x
91
- x = self.conv_3(x)
92
- x = self.res_3(x) + x
93
- x = self.conv_4(x)
94
- x = self.res_4(x) + x
95
- x = self.classifier(x)
96
- x = F.softmax(x, dim=1)
97
- return x
98
 
99
  # Make model and load parameters
100
- resnet = resnetModel_128()
101
  resnet.load_state_dict(torch.load(modelsave_name, map_location=device))
102
  resnet.eval()
103
 
@@ -118,7 +50,7 @@ my_dataset = datasets.ImageFolder(
118
 
119
  my_dataset_loader = DataLoader(
120
  my_dataset,
121
- batch_size=len(my_dataset),
122
  generator=torch.Generator(device=device)
123
  )
124
 
 
1
  import os
2
  import gdown
3
  import torch
 
 
4
  import torchvision.datasets as datasets
5
  import torchvision.transforms as transforms
6
  from torch.utils.data import DataLoader
7
  import time
8
+ import modules.model as model
9
 
10
  # Download model if not available
11
  modelsave_name = 'model_parameters.pt'
 
27
  torch.set_default_device(device)
28
 
29
  print(f'\nDevice: {device_name}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # Make model and load parameters
32
+ resnet = model.resnetModel_128()
33
  resnet.load_state_dict(torch.load(modelsave_name, map_location=device))
34
  resnet.eval()
35
 
 
50
 
51
  my_dataset_loader = DataLoader(
52
  my_dataset,
53
+ batch_size=min(len(my_dataset), 10),
54
  generator=torch.Generator(device=device)
55
  )
56
 
train.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gdown
3
+ import zipfile
4
+ import shutil
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.datasets as datasets
8
+ import torchvision.transforms as transforms
9
+ from torch.utils.data import DataLoader
10
+ import time
11
+ import modules.model as model
12
+
13
+ # Download model if not available
14
+ # if os.path.exists('celeba/') == False:
15
+ # url = 'https://drive.google.com/file/d/1_oL160xwrOiF5x56GddAUtOuXe6bIwpL/view?usp=sharing'
16
+ # output = 'download.zip'
17
+ # gdown.download(url, output, fuzzy=True)
18
+
19
+ # with zipfile.ZipFile(output, 'r') as zip_ref:
20
+ # zip_ref.extractall()
21
+
22
+ # os.remove(output)
23
+ # shutil.rmtree('__MACOSX')
24
+
25
+ # Set device
26
+ if torch.backends.mps.is_available():
27
+ device = torch.device('mps')
28
+ device_name = 'Apple Silicon GPU'
29
+ elif torch.cuda.is_available():
30
+ device = torch.device('cuda')
31
+ device_name = 'CUDA'
32
+ else:
33
+ device = torch.device('cpu')
34
+ device_name = 'CPU'
35
+
36
+ torch.set_default_device(device)
37
+
38
+ print(f'\nDevice: {device_name}')
39
+
40
+ # Define dataset, dataloader and transform
41
+ imsize = int(128/0.8)
42
+ batch_size = 10
43
+
44
+ fivecrop_transform = transforms.Compose([
45
+ transforms.Resize([imsize, imsize]),
46
+ transforms.Grayscale(1),
47
+ transforms.FiveCrop(int(imsize*0.8)),
48
+ transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
49
+ transforms.Normalize(0, 1)
50
+ ])
51
+
52
+ train_dataset = datasets.CelebA(
53
+ root='',
54
+ split='all',
55
+ target_type='attr',
56
+ transform=fivecrop_transform,
57
+ download=True,
58
+ )
59
+
60
+ train_loader = DataLoader(
61
+ train_dataset,
62
+ batch_size=batch_size,
63
+ shuffle=True,
64
+ generator=torch.Generator(device=device)
65
+ )
66
+
67
+ # Male index
68
+ factor = 20
69
+
70
+ # Define model, optimiser and scheduler
71
+ torch.manual_seed(2687)
72
+ resnet = model.resnetModel_128()
73
+ criterion = nn.CrossEntropyLoss()
74
+ optimizer = torch.optim.SGD(
75
+ resnet.parameters(),
76
+ lr=0.01,
77
+ momentum=0.9,
78
+ weight_decay=0.001
79
+ )
80
+ scheduler = torch.optim.lr_scheduler.StepLR(
81
+ optimizer=optimizer,
82
+ step_size=1,
83
+ gamma=0.1
84
+ )
85
+
86
+ def mins_to_hours(mins):
87
+ hours = int(mins/60)
88
+ rem_mins = mins % 60
89
+ return hours, rem_mins
90
+
91
+ epochs = 2
92
+ train_losses = []
93
+ train_accuracy = []
94
+ for i in range(epochs):
95
+ epoch_time = 0
96
+
97
+ for j, (X_train, y_train) in enumerate(train_loader):
98
+ batch_start = time.time()
99
+
100
+ X_train = X_train.to(device)
101
+ y_train = y_train[:, factor]
102
+
103
+ bs, ncrops, c, h, w = X_train.size()
104
+ y_pred_crops = resnet.forward(X_train.view(-1, c, h, w))
105
+ y_pred = y_pred_crops.view(bs, ncrops, -1).mean(1)
106
+
107
+ loss = criterion(y_pred, y_train)
108
+
109
+ predicted = torch.max(y_pred.data, 1)[1]
110
+ train_batch_accuracy = (predicted == y_train).sum()/len(X_train)
111
+
112
+ optimizer.zero_grad()
113
+ loss.backward()
114
+ optimizer.step()
115
+
116
+ train_losses.append(loss.item())
117
+ train_accuracy.append(train_batch_accuracy.item())
118
+
119
+ batch_end = time.time()
120
+
121
+ batch_time = batch_end - batch_start
122
+ epoch_time += batch_time
123
+ avg_batch_time = epoch_time/(j+1)
124
+ batches_remaining = len(train_loader)-(j+1)
125
+ epoch_mins_remaining = round(batches_remaining*avg_batch_time/60)
126
+ epoch_time_remaining = mins_to_hours(epoch_mins_remaining)
127
+
128
+ full_epoch = avg_batch_time*len(train_loader)
129
+ epochs_remaining = epochs-(i+1)
130
+ rem_epoch_mins_remaining = epoch_mins_remaining+round(full_epoch*epochs_remaining/60)
131
+ rem_epoch_time_remaining = mins_to_hours(rem_epoch_mins_remaining)
132
+
133
+ if (j+1) % 10 == 0:
134
+ print(f'\nEpoch: {i+1}/{epochs} | Train Batch: {j+1}/{len(train_loader)}')
135
+ print(f'Current epoch: {epoch_time_remaining[0]} hours {epoch_time_remaining[1]} minutes')
136
+ print(f'Remaining epochs: {rem_epoch_time_remaining[0]} hours {rem_epoch_time_remaining[1]} minutes')
137
+ print(f'Train Loss: {loss}')
138
+ print(f'Train Accuracy: {train_batch_accuracy}')
139
+
140
+ scheduler.step()
141
+
142
+ trained_model_name = resnet.model_name + '_epoch_' + str(i+1) + '.pt'
143
+ torch.save(
144
+ resnet.state_dict(),
145
+ trained_model_name
146
+ )