soumickmj commited on
Commit
0a4a687
·
verified ·
1 Parent(s): 7cdc1a3

Upload GPReconResNet

Browse files
Files changed (7) hide show
  1. GPModelConfigs.py +86 -0
  2. GPModels.py +64 -0
  3. GP_ReconResNet.py +270 -0
  4. GP_ShuffleUNet.py +187 -0
  5. GP_UNet.py +189 -0
  6. config.json +25 -0
  7. model.safetensors +3 -0
GPModelConfigs.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class GPUNetConfig(PretrainedConfig):
4
+ model_type = "GPUNet"
5
+ def __init__(
6
+ self,
7
+ in_channels=1,
8
+ n_classes=3,
9
+ depth=3,
10
+ wf=6,
11
+ padding=True,
12
+ batch_norm=False,
13
+ up_mode="sinc",
14
+ dropout=True,
15
+ Relu="Relu",
16
+ out_act="None",
17
+ **kwargs):
18
+ self.in_channels = in_channels
19
+ self.n_classes = n_classes
20
+ self.depth = depth
21
+ self.wf = wf
22
+ self.padding = padding
23
+ self.batch_norm = batch_norm
24
+ self.up_mode = up_mode
25
+ self.dropout = dropout
26
+ self.Relu = Relu
27
+ self.out_act = out_act
28
+ super().__init__(**kwargs)
29
+
30
+ class GPReconResNetConfig(PretrainedConfig):
31
+ model_type = "GPReconResNet"
32
+ def __init__(
33
+ self,
34
+ in_channels=1,
35
+ n_classes=3,
36
+ res_blocks=14,
37
+ starting_nfeatures=64,
38
+ updown_blocks=2,
39
+ is_relu_leaky=True,
40
+ do_batchnorm=False,
41
+ res_drop_prob=0.5,
42
+ out_act="None",
43
+ forwardV=0,
44
+ upinterp_algo='sinc',
45
+ post_interp_convtrans=False,
46
+ is3D=False,
47
+ **kwargs):
48
+ self.in_channels = in_channels
49
+ self.n_classes = n_classes
50
+ self.res_blocks = res_blocks
51
+ self.starting_nfeatures = starting_nfeatures
52
+ self.updown_blocks = updown_blocks
53
+ self.is_relu_leaky = is_relu_leaky
54
+ self.do_batchnorm = do_batchnorm
55
+ self.res_drop_prob = res_drop_prob
56
+ self.out_act = out_act
57
+ self.forwardV = forwardV
58
+ self.upinterp_algo = upinterp_algo
59
+ self.post_interp_convtrans = post_interp_convtrans
60
+ self.is3D = is3D
61
+ super().__init__(**kwargs)
62
+
63
+ class GPShuffleUNetConfig(PretrainedConfig):
64
+ model_type = "GPShuffleUNet"
65
+ def __init__(
66
+ self,
67
+ d=2,
68
+ in_ch=1,
69
+ num_features=64,
70
+ n_levels=3,
71
+ out_ch=3,
72
+ kernel_size=3,
73
+ stride=1,
74
+ dropout=True,
75
+ out_act="None",
76
+ **kwargs):
77
+ self.d = d
78
+ self.in_ch = in_ch
79
+ self.num_features = num_features
80
+ self.n_levels = n_levels
81
+ self.out_ch = out_ch
82
+ self.kernel_size = kernel_size
83
+ self.stride = stride
84
+ self.dropout = dropout
85
+ self.out_act = out_act
86
+ super().__init__(**kwargs)
GPModels.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+
3
+ from .GP_UNet import GP_UNet
4
+ from .GP_ReconResNet import GP_ReconResNet
5
+ from .GP_ShuffleUNet import GP_ShuffleUNet
6
+
7
+ from .GPModelConfigs import GPUNetConfig, GPReconResNetConfig, GPShuffleUNetConfig
8
+
9
+ class GPUNet(PreTrainedModel):
10
+ config_class = GPUNetConfig
11
+ def __init__(self, config):
12
+ super().__init__(config)
13
+ self.model = GP_UNet(
14
+ in_channels=config.in_channels,
15
+ n_classes=config.n_classes,
16
+ depth=config.depth,
17
+ wf=config.wf,
18
+ padding=config.padding,
19
+ batch_norm=config.batch_norm,
20
+ up_mode=config.up_mode,
21
+ dropout=config.dropout,
22
+ Relu=config.Relu,
23
+ out_act=config.out_act)
24
+ def forward(self, x):
25
+ return self.model(x)
26
+
27
+
28
+ class GPReconResNet(PreTrainedModel):
29
+ config_class = GPReconResNetConfig
30
+ def __init__(self, config):
31
+ super().__init__(config)
32
+ self.model = GP_ReconResNet(
33
+ in_channels=config.in_channels,
34
+ n_classes=config.n_classes,
35
+ res_blocks=config.res_blocks,
36
+ starting_nfeatures=config.starting_nfeatures,
37
+ updown_blocks=config.updown_blocks,
38
+ is_relu_leaky=config.is_relu_leaky,
39
+ do_batchnorm=config.do_batchnorm,
40
+ res_drop_prob=config.res_drop_prob,
41
+ out_act=config.out_act,
42
+ forwardV=config.forwardV,
43
+ upinterp_algo=config.upinterp_algo,
44
+ post_interp_convtrans=config.post_interp_convtrans,
45
+ is3D=config.is3D)
46
+ def forward(self, x):
47
+ return self.model(x)
48
+
49
+ class GPShuffleUNet(PreTrainedModel):
50
+ config_class = GPShuffleUNetConfig
51
+ def __init__(self, config):
52
+ super().__init__(config)
53
+ self.model = GP_ShuffleUNet(
54
+ d=config.d,
55
+ in_ch=config.in_ch,
56
+ num_features=config.num_features,
57
+ n_levels=config.n_levels,
58
+ out_ch=config.out_ch,
59
+ kernel_size=config.kernel_size,
60
+ stride=config.stride,
61
+ dropout=config.dropout,
62
+ out_act=config.out_act)
63
+ def forward(self, x):
64
+ return self.model(x)
GP_ReconResNet.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://raw.githubusercontent.com/soumickmj/NCC1701/main/Bridge/models/ResNet/MickResNet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import sys
7
+ import torch.nn.functional as F
8
+ from tricorder.torch.transforms import Interpolator
9
+
10
+ __author__ = "Soumick Chatterjee"
11
+ __copyright__ = "Copyright 2019, Soumick Chatterjee & OvGU:ESF:MEMoRIAL"
12
+ __credits__ = ["Soumick Chatterjee"]
13
+
14
+ __license__ = "GPL"
15
+ __version__ = "1.0.0"
16
+ __email__ = "[email protected]"
17
+ __status__ = "Published"
18
+
19
+ class ResidualBlock(nn.Module):
20
+ def __init__(self, in_features, drop_prob=0.2): #drop_prob=0.2
21
+ super(ResidualBlock, self).__init__()
22
+
23
+ conv_block = [ layer_pad(1),
24
+ layer_conv(in_features, in_features, 3),
25
+ layer_norm(in_features),
26
+ act_relu(),
27
+ layer_drop(p=drop_prob, inplace=True),
28
+ layer_pad(1),
29
+ layer_conv(in_features, in_features, 3) ,
30
+ layer_norm(in_features) ]
31
+
32
+ self.conv_block = nn.Sequential(*conv_block)
33
+
34
+ def forward(self, x):
35
+ return x + self.conv_block(x)
36
+
37
+ class DownsamplingBlock(nn.Module):
38
+ def __init__(self, in_features, out_features):
39
+ super(DownsamplingBlock, self).__init__()
40
+
41
+ conv_block = [ layer_conv(in_features, out_features, 3, stride=2, padding=1),
42
+ layer_norm(out_features),
43
+ act_relu() ]
44
+ self.conv_block = nn.Sequential(*conv_block)
45
+
46
+ def forward(self, x):
47
+ return self.conv_block(x)
48
+
49
+ class UpsamplingBlock(nn.Module):
50
+ def __init__(self, in_features, out_features, mode="upconv", interpolator=None, post_interp_convtrans=False):
51
+ super(UpsamplingBlock, self).__init__()
52
+
53
+ self.interpolator = interpolator
54
+ self.mode = mode
55
+ self.post_interp_convtrans = post_interp_convtrans
56
+ if self.post_interp_convtrans:
57
+ self.post_conv = layer_conv(out_features, out_features, 1)
58
+
59
+ if mode == "upconv":
60
+ conv_block = [ layer_convtrans(in_features, out_features, 3, stride=2, padding=1, output_padding=1), ]
61
+ else:
62
+ conv_block = [ layer_pad(1),
63
+ layer_conv(in_features, out_features, 3), ]
64
+ conv_block += [ layer_norm(out_features),
65
+ act_relu() ]
66
+ self.conv_block = nn.Sequential(*conv_block)
67
+
68
+ def forward(self, x, out_shape=None):
69
+ if self.mode == "upconv":
70
+ if self.post_interp_convtrans:
71
+ x = self.conv_block(x)
72
+ if x.shape[2:] != out_shape:
73
+ return self.post_conv(self.interpolator(x, out_shape))
74
+ else:
75
+ return x
76
+ else:
77
+ return self.conv_block(x)
78
+ else:
79
+ return self.conv_block(self.interpolator(x, out_shape))
80
+
81
+ class GP_ReconResNet(nn.Module):
82
+ def __init__(self, in_channels=1, n_classes=1, res_blocks=14, starting_nfeatures=64, updown_blocks=2, is_relu_leaky=True, do_batchnorm=False, res_drop_prob=0.2, #res_drop_prob=0.2
83
+ out_act="softmax", forwardV=0, upinterp_algo='upconv', post_interp_convtrans=False, is3D=False): #should use 14 as that gives number of trainable parameters close to number of possible pixel values in a image 256x256
84
+ super(GP_ReconResNet, self).__init__()
85
+
86
+ layers = {}
87
+ if is3D:
88
+ sys.exit("ResNet: for implemented for 3D, ReflectionPad3d code is required")
89
+ layers["layer_conv"] = nn.Conv3d
90
+ layers["layer_convtrans"] = nn.ConvTranspose3d
91
+ if do_batchnorm:
92
+ layers["layer_norm"] = nn.BatchNorm3d
93
+ else:
94
+ layers["layer_norm"] = nn.InstanceNorm3d
95
+ layers["layer_drop"] = nn.Dropout3d
96
+ layers["layer_pad"] = ReflectionPad3d
97
+ layers["interp_mode"] = 'trilinear'
98
+ else:
99
+ layers["layer_conv"] = nn.Conv2d
100
+ layers["layer_convtrans"] = nn.ConvTranspose2d
101
+ if do_batchnorm:
102
+ layers["layer_norm"] = nn.BatchNorm2d
103
+ else:
104
+ layers["layer_norm"] = nn.InstanceNorm2d
105
+ layers["layer_drop"] = nn.Dropout2d
106
+ layers["layer_pad"] = nn.ReflectionPad2d
107
+ layers["interp_mode"] = 'bilinear'
108
+ if is_relu_leaky:
109
+ layers["act_relu"] = nn.PReLU
110
+ else:
111
+ layers["act_relu"] = nn.ReLU
112
+ globals().update(layers)
113
+
114
+ self.forwardV = forwardV
115
+ self.upinterp_algo = upinterp_algo
116
+
117
+ interpolator = Interpolator(mode=layers["interp_mode"] if self.upinterp_algo == "upconv" else self.upinterp_algo)
118
+
119
+ in_channels = in_channels
120
+ out_channels = n_classes
121
+ # Initial convolution block
122
+ intialConv = [ layer_pad(3),
123
+ layer_conv(in_channels, starting_nfeatures, 7),
124
+ layer_norm(starting_nfeatures),
125
+ act_relu() ]
126
+
127
+ # Downsampling [need to save the shape for upsample]
128
+ downsam = []
129
+ in_features = starting_nfeatures
130
+ out_features = in_features*2
131
+ for _ in range(updown_blocks):
132
+ downsam.append(DownsamplingBlock(in_features, out_features))
133
+ in_features = out_features
134
+ out_features = in_features*2
135
+
136
+ # Residual blocks
137
+ resblocks = []
138
+ for _ in range(res_blocks):
139
+ resblocks += [ResidualBlock(in_features, res_drop_prob)]
140
+
141
+ # Upsampling
142
+ upsam = []
143
+ out_features = in_features//2
144
+ for _ in range(updown_blocks):
145
+ upsam.append(UpsamplingBlock(in_features, out_features, self.upinterp_algo, interpolator, post_interp_convtrans))
146
+ in_features = out_features
147
+ out_features = in_features//2
148
+
149
+ # Output layer
150
+ finalconv = [ layer_conv(starting_nfeatures, out_channels, 1), ] #kernel size changed from 7 to 1 to make GMP work
151
+
152
+ if out_act == "sigmoid":
153
+ finalconv += [ nn.Sigmoid(), ]
154
+ elif out_act == "relu":
155
+ finalconv += [ act_relu(), ]
156
+ elif out_act == "tanh":
157
+ finalconv += [ nn.Tanh(), ]
158
+ elif out_act == "softmax":
159
+ finalconv += [ nn.Softmax2d(), ]
160
+
161
+
162
+ self.intialConv = nn.Sequential(*intialConv)
163
+ self.downsam = nn.ModuleList(downsam)
164
+ self.resblocks = nn.Sequential(*resblocks)
165
+ self.upsam = nn.ModuleList(upsam)
166
+ self.finalconv = nn.Sequential(*finalconv)
167
+
168
+ ### For Classification, following Florian's GP-UNet
169
+ self.GMP = nn.AdaptiveMaxPool2d((1, 1))
170
+
171
+ if self.forwardV == 0:
172
+ self.forward = self.forwardV0
173
+ elif self.forwardV == 1:
174
+ sys.exit("ResNet: its identical to V0 in case of GP_ResNet")
175
+ elif self.forwardV == 2:
176
+ self.forward = self.forwardV2
177
+ elif self.forwardV == 3:
178
+ self.forward = self.forwardV3
179
+ elif self.forwardV == 4:
180
+ self.forward = self.forwardV4
181
+ elif self.forwardV == 5:
182
+ self.forward = self.forwardV5
183
+
184
+ def final_step(self, x):
185
+ if self.training:
186
+ x = self.GMP(x)
187
+ return self.finalconv(x).view(x.shape[0],-1)
188
+ else:
189
+ mask = self.finalconv(x)
190
+ x = self.GMP(x)
191
+ pred = self.finalconv(x).view(x.shape[0],-1)
192
+ return pred, mask
193
+
194
+ def forwardV0(self, x):
195
+ #v0: Original Version
196
+ x = self.intialConv(x)
197
+ shapes = []
198
+ for downblock in self.downsam:
199
+ shapes.append(x.shape[2:])
200
+ x = downblock(x)
201
+ x = self.resblocks(x)
202
+ for i, upblock in enumerate(self.upsam):
203
+ x = upblock(x, shapes[-1-i])
204
+ return self.final_step(x)
205
+
206
+ def forwardV2(self, x):
207
+ #v2: residual of v1 + input to the residual blocks added back with the output
208
+ out = self.intialConv(x)
209
+ shapes = []
210
+ for downblock in self.downsam:
211
+ shapes.append(out.shape[2:])
212
+ out = downblock(out)
213
+ out = out + self.resblocks(out)
214
+ for i, upblock in enumerate(self.upsam):
215
+ out = upblock(out, shapes[-1-i])
216
+ return self.final_step(out)
217
+
218
+ def forwardV3(self, x):
219
+ #v3: residual of v2 + input of the initial conv added back with the output
220
+ out = x + self.intialConv(x)
221
+ shapes = []
222
+ for downblock in self.downsam:
223
+ shapes.append(out.shape[2:])
224
+ out = downblock(out)
225
+ out = out + self.resblocks(out)
226
+ for i, upblock in enumerate(self.upsam):
227
+ out = upblock(out, shapes[-1-i])
228
+ return self.final_step(out)
229
+
230
+ def forwardV4(self, x):
231
+ #v4: residual of v3 + output of the initial conv added back with the input of final conv
232
+ iniconv = x + self.intialConv(x)
233
+ shapes = []
234
+ if len(self.downsam) > 0:
235
+ for i, downblock in enumerate(self.downsam):
236
+ if i == 0:
237
+ shapes.append(iniconv.shape[2:])
238
+ out = downblock(iniconv)
239
+ else:
240
+ shapes.append(out.shape[2:])
241
+ out = downblock(out)
242
+ else:
243
+ out = iniconv
244
+ out = out + self.resblocks(out)
245
+ for i, upblock in enumerate(self.upsam):
246
+ out = upblock(out, shapes[-1-i])
247
+ out = iniconv + out
248
+ return self.final_step(out)
249
+
250
+ def forwardV5(self, x):
251
+ #v5: residual of v4 + individual down blocks with individual up blocks
252
+ outs = [x + self.intialConv(x)]
253
+ shapes = []
254
+ for i, downblock in enumerate(self.downsam):
255
+ shapes.append(outs[-1].shape[2:])
256
+ outs.append(downblock(outs[-1]))
257
+ outs[-1] = outs[-1] + self.resblocks(outs[-1])
258
+ for i, upblock in enumerate(self.upsam):
259
+ outs[-1] = upblock(outs[-1], shapes[-1-i])
260
+ outs[-1] = outs[-2] + outs.pop()
261
+ return self.final_step(outs.pop())
262
+
263
+ #to run it here from this script, uncomment the following
264
+
265
+ if __name__ == "__main__": #to run it
266
+ image = torch.rand(2, 1, 240, 240) #specify your image: batch size, Channel, height, width
267
+ model = GP_ReconResNet(in_channels=1, n_classes=3, upinterp_algo='sinc') #Initialize the model
268
+ # model.eval()
269
+ out = model(image)
270
+ print(model(image))
GP_ShuffleUNet.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##Dropout and out_act was added by hadya
2
+
3
+ import sys
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from . import GP_ShuffleUNet_pixel_shuffle, GP_ShuffleUNet_pixel_unshuffle
8
+ # import pixel_shuffle, pixel_unshuffle
9
+
10
+ # -------------------------------------------------------------------------------------------------------------------------------------------------##
11
+
12
+ class _double_conv(nn.Module):
13
+ """
14
+ Double Convolution Block
15
+ """
16
+
17
+ def __init__(self, in_channels, out_channels, k_size, stride, bias=True, conv_layer=nn.Conv3d):
18
+ super(_double_conv, self).__init__()
19
+ self.conv_1 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
20
+ stride=stride, padding=k_size // 2, bias=bias)
21
+ self.conv_2 = conv_layer(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
22
+ stride=stride, padding=k_size // 2, bias=bias)
23
+
24
+ self.relu = nn.ReLU(inplace=True)
25
+
26
+ def forward(self, x):
27
+ x = self.conv_1(x)
28
+ x = self.relu((x))
29
+ x = self.conv_2(x)
30
+ x = self.relu((x))
31
+
32
+ return x
33
+
34
+
35
+ class _conv_decomp(nn.Module):
36
+ """
37
+ Convolutional Decomposition Block
38
+ """
39
+
40
+ def __init__(self, in_channels, out_channels, k_size, stride, bias=True, conv_layer=nn.Conv3d):
41
+ super(_conv_decomp, self).__init__()
42
+ self.conv1 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
43
+ stride=stride, padding=k_size // 2, bias=bias)
44
+ self.conv2 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
45
+ stride=stride, padding=k_size // 2, bias=bias)
46
+ self.conv3 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
47
+ stride=stride, padding=k_size // 2, bias=bias)
48
+ self.conv4 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
49
+ stride=stride, padding=k_size // 2, bias=bias)
50
+ self.relu = nn.ReLU(inplace=True)
51
+
52
+ def forward(self, x):
53
+ x1 = self.conv1(x)
54
+ x1 = self.relu((x1))
55
+ x2 = self.conv2(x)
56
+ x2 = self.relu((x2))
57
+ x3 = self.conv3(x)
58
+ x3 = self.relu((x3))
59
+ x4 = self.conv4(x)
60
+ x4 = self.relu((x4))
61
+ return x1, x2, x3, x4
62
+
63
+
64
+ class _concat(nn.Module):
65
+ """
66
+ Skip-Addition block
67
+ """
68
+
69
+ def __init__(self):
70
+ super(_concat, self).__init__()
71
+
72
+ def forward(self, e1, e2, e3, e4, d1, d2, d3, d4):
73
+ self.X1 = e1 + d1
74
+ self.X2 = e2 + d2
75
+ self.X3 = e3 + d3
76
+ self.X4 = e4 + d4
77
+ x = torch.cat([self.X1, self.X2, self.X3, self.X4], dim=1)
78
+
79
+ return x
80
+
81
+ # -------------------------------------------------------------------------------------------------------------------------------------------------##
82
+
83
+ class GP_ShuffleUNet(nn.Module):
84
+
85
+ def __init__(self, d=3, in_ch=1, num_features=64, n_levels=3, out_ch=1, kernel_size=3, stride=1, dropout=False, out_act="softmax"):
86
+ super(GP_ShuffleUNet, self).__init__()
87
+
88
+ self.n_levels = n_levels
89
+ self.dropout = nn.Dropout2d() if dropout else nn.Sequential() #added by Hadya
90
+
91
+ num_features = num_features
92
+ filters = [num_features]
93
+ for _ in range(n_levels):
94
+ filters.append(filters[-1]*2)
95
+
96
+ if d==3:
97
+ conv_layer = nn.Conv3d
98
+ ps_fact = (2 ** 2)
99
+ elif d==2:
100
+ conv_layer = nn.Conv2d
101
+ ps_fact = 2
102
+ else:
103
+ sys.exit("Invalid d")
104
+
105
+ # Input
106
+ self.conv_inp = _double_conv(in_ch, filters[0], kernel_size, stride, conv_layer=conv_layer)
107
+
108
+ #Contraction path
109
+ self.wave_down = nn.ModuleList()
110
+ self.pix_unshuff = nn.ModuleList()
111
+ self.conv_enc = nn.ModuleList()
112
+ for i in range(0, n_levels):
113
+ self.wave_down.append(_conv_decomp(filters[i], filters[i], kernel_size, stride, conv_layer=conv_layer))
114
+ self.pix_unshuff.append(pixel_unshuffle.PixelUnshuffle(num_features * (2**i), num_features * (2**i), kernel_size, stride, d=d))
115
+ self.conv_enc.append(_double_conv(filters[i], filters[i+1], kernel_size, stride, conv_layer=conv_layer))
116
+
117
+ #Expansion path
118
+ self.cat = _concat()
119
+ self.pix_shuff = nn.ModuleList()
120
+ self.wave_up = nn.ModuleList()
121
+ self.convup = nn.ModuleList()
122
+ for i in range(n_levels-1,-1,-1):
123
+ self.pix_shuff.append(pixel_shuffle.PixelShuffle(num_features * (2**(i+1)), num_features * (2**(i+1)) * ps_fact, kernel_size, stride, d=d))
124
+ self.wave_up.append(_conv_decomp(filters[i], filters[i], kernel_size, stride, conv_layer=conv_layer))
125
+ self.convup.append(_double_conv(filters[i] * 5, filters[i], kernel_size, stride, conv_layer=conv_layer))
126
+
127
+ ###added For Classification, following Florian's GP-UNet
128
+ self.GMP = nn.AdaptiveMaxPool2d((1, 1))
129
+
130
+ #FC
131
+ if out_act == "softmax": #added by Hadya
132
+ self.last = nn.Sequential(
133
+ conv_layer(filters[0], out_ch, kernel_size=1, stride=1, padding=0, bias=True),
134
+ nn.Softmax2d()
135
+ )
136
+ else:
137
+ self.out = conv_layer(filters[0], out_ch, kernel_size=1, stride=1, padding=0, bias=True) #original
138
+
139
+ #Weight init
140
+ for m in self.modules():
141
+ if isinstance(m, conv_layer):
142
+ weight = nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
143
+ m.weight.data.copy_(weight)
144
+ if m.bias is not None:
145
+ m.bias.data.zero_()
146
+
147
+ def forward(self, x):
148
+ encs = [self.conv_inp(x)]
149
+
150
+ waves = []
151
+ for i in range(self.n_levels):
152
+ waves.append(self.wave_down[i](encs[-1]))
153
+ _tmp = self.pix_unshuff[i](waves[-1][-1])
154
+ encs.append(self.conv_enc[i](_tmp))
155
+
156
+ dec = encs.pop()
157
+
158
+ dec = self.dropout(dec) #added by hadya
159
+
160
+ for i in range(self.n_levels):
161
+ _tmp = self.pix_shuff[i](dec)
162
+ _tmp_waves = self.wave_up[i](_tmp) + waves.pop()
163
+ _tmp_cat = self.cat(*_tmp_waves)
164
+ dec = self.convup[i](torch.cat([encs.pop(), _tmp_cat], dim=1))
165
+
166
+ ###added section to make it GP-UNet
167
+ if self.training:
168
+ x = self.GMP(dec)
169
+ return self.out(x).view(x.shape[0],-1)
170
+ else:
171
+ mask = self.out(dec)
172
+ x = self.GMP(dec)
173
+ pred = self.out(x).view(x.shape[0],-1)
174
+ return pred, mask
175
+
176
+
177
+ # return self.out(dec) #####replace by line 154-161 to make it GP_ShuffleUNet
178
+
179
+
180
+ #to run it here from this script, uncomment the following
181
+
182
+ if __name__ == "__main__": #to run it
183
+ image = torch.rand(2, 1, 512, 512) #specify your image: batch size, Channel, height, width
184
+ model = GP_ShuffleUNet(d=2, in_ch=1, num_features=64, n_levels=3, out_ch=3, kernel_size=3, stride=1) #Initialize the model, d=3 default is for dimensionality conv2d or 3d, default out channel = 1 but in gp we need 3
185
+ model.eval()
186
+ out = model(image)
187
+ print(model(image))
GP_UNet.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/soumickmj/FTSuperResDynMRI/blob/main/models/unet2d.py
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ import torchcomplex.nn.functional as cF
7
+
8
+
9
+
10
+ __author__ = "Soumick Chatterjee"
11
+ __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
12
+ __credits__ = ["Soumick Chatterjee"]
13
+ __license__ = "GPL"
14
+ __version__ = "1.0.0"
15
+ __maintainer__ = "Soumick Chatterjee"
16
+ __email__ = "[email protected]"
17
+ __status__ = "Production"
18
+
19
+
20
+ class GP_UNet(nn.Module):
21
+ """
22
+ Implementation of
23
+ U-Net: Convolutional Networks for Biomedical Image Segmentation
24
+ (Ronneberger et al., 2015)
25
+ https://arxiv.org/abs/1505.04597
26
+
27
+ Using the default arguments will yield the exact version used
28
+ in the original paper
29
+
30
+ Args:
31
+ in_channels (int): number of input channels
32
+ n_classes (int): number of output channels
33
+ depth (int): depth of the network
34
+ wf (int): number of filters in the first layer is 2**wf
35
+ padding (bool): if True, apply padding such that the input shape
36
+ is the same as the output.
37
+ This may introduce artifacts
38
+ batch_norm (bool): Use BatchNorm after layers with an
39
+ activation function
40
+ up_mode (str): one of 'upconv' or 'upsample'.
41
+ 'upconv' will use transposed convolutions for
42
+ learned upsampling.
43
+ 'upsample_Bi' will use bilinear upsampling.
44
+ 'upsample_Sinc' will use sinc upsampling.
45
+ """
46
+ def __init__(self, in_channels=1, n_classes=1, depth=3, wf=6, padding=True,
47
+ batch_norm=False, up_mode='upconv', dropout=False, Relu = "Relu", out_act="None"): #dropout=False
48
+ super(GP_UNet, self).__init__()
49
+ assert up_mode in ('upconv', 'bilinear', 'sinc', "upsample_Sinc")
50
+ assert out_act in ("softmax", "None", "sigmoid", "relu")
51
+ self.padding = padding
52
+ self.depth = depth
53
+ self.Relu = Relu
54
+ self.dropout = nn.Dropout2d() if dropout else nn.Sequential()
55
+ prev_channels = in_channels
56
+ self.down_path = nn.ModuleList()
57
+ for i in range(depth):
58
+ self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i),
59
+ padding, batch_norm, Relu))
60
+ prev_channels = 2**(wf+i)
61
+
62
+ self.up_path = nn.ModuleList()
63
+ for i in reversed(range(depth - 1)):
64
+ self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode,
65
+ padding, batch_norm, Relu))
66
+ prev_channels = 2**(wf+i)
67
+
68
+ if out_act == "softmax":
69
+ self.last = nn.Sequential(
70
+ nn.Conv2d(prev_channels, n_classes, kernel_size=1),
71
+ nn.Softmax2d()
72
+ )
73
+
74
+ elif out_act == "sigmoid":
75
+ self.last = nn.Sequential(
76
+ nn.Conv2d(prev_channels, n_classes, kernel_size=1),
77
+ nn.Sigmoid()
78
+ )
79
+
80
+ elif out_act == "relu":
81
+ self.last = nn.Sequential(
82
+ nn.Conv2d(prev_channels, n_classes, kernel_size=1),
83
+ nn.ReLU()
84
+ )
85
+
86
+ else:
87
+ self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)
88
+
89
+
90
+ ### For Classification, following Florian's GP-UNet
91
+ self.GMP = nn.AdaptiveMaxPool2d((1, 1))
92
+
93
+ def forward(self, x):
94
+ blocks = []
95
+ for i, down in enumerate(self.down_path):
96
+ x = down(x)
97
+ if i != len(self.down_path)-1:
98
+ blocks.append(x)
99
+ #x = nn.AvgPool2d(x, 2)
100
+ x = F.avg_pool2d(x, 2)
101
+ x = self.dropout(x)
102
+
103
+ for i, up in enumerate(self.up_path):
104
+ x = up(x, blocks[-i-1])
105
+
106
+ if self.training:
107
+ x = self.GMP(x)
108
+ return self.last(x).view(x.shape[0],-1)
109
+ else:
110
+ mask = self.last(x)
111
+ x = self.GMP(x)
112
+ pred = self.last(x).view(x.shape[0],-1)
113
+ return pred, mask
114
+
115
+ class UNetConvBlock(nn.Module):
116
+ def __init__(self, in_size, out_size, padding, batch_norm, Relu):
117
+ super(UNetConvBlock, self).__init__()
118
+ block = []
119
+
120
+ block.append(nn.Conv2d(in_size, out_size, kernel_size=3,
121
+ padding=int(padding)))
122
+ if Relu == "Relu":
123
+ block.append(nn.ReLU())
124
+ else:
125
+ block.append(nn.PReLU())
126
+
127
+ if batch_norm:
128
+ block.append(nn.BatchNorm2d(out_size))
129
+
130
+ block.append(nn.Conv2d(out_size, out_size, kernel_size=3,
131
+ padding=int(padding)))
132
+
133
+ if Relu == "Relu":
134
+ block.append(nn.ReLU())
135
+ else:
136
+ block.append(nn.PReLU())
137
+
138
+ if batch_norm:
139
+ block.append(nn.BatchNorm2d(out_size))
140
+
141
+ self.block = nn.Sequential(*block)
142
+
143
+ def forward(self, x):
144
+ out = self.block(x)
145
+ return out
146
+
147
+
148
+ class UNetUpBlock(nn.Module):
149
+ def __init__(self, in_size, out_size, up_mode, padding, batch_norm, Relu):
150
+ super(UNetUpBlock, self).__init__()
151
+
152
+ self.up_mode = up_mode
153
+
154
+ if up_mode == 'upconv':
155
+ self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2,
156
+ stride=2)
157
+ elif up_mode == 'bilinear':
158
+ self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2), #'trilinear'
159
+ nn.Conv2d(in_size, out_size, kernel_size=1))
160
+ elif 'inc' in up_mode:
161
+ self.up = nn.Conv2d(in_size, out_size, kernel_size=1)
162
+
163
+
164
+ self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm, Relu)
165
+
166
+ def forward(self, x, bridge):
167
+ if self.up_mode == 'upconv': # 'upconv'
168
+ up = self.up(x)
169
+ elif self.up_mode == 'bilinear':
170
+ up = self.up(x)
171
+ elif 'inc' in self.up_mode:
172
+ x = cF._sinc_interpolate(x, size=[int(x.shape[2]*2), int(x.shape[3]*2)]) #'sinc' ###sth wrong
173
+ up = self.up(x)
174
+
175
+ # bridge = self.center_crop(bridge, up.shape[2:]) #sending shape ignoring 2 digit, so target size start with 0,1,2
176
+ up = F.interpolate(up, size=bridge.shape[2:], mode='bilinear')
177
+ out = torch.cat([up, bridge], 1)
178
+ out = self.conv_block(out)
179
+
180
+ return out
181
+
182
+ #to run it here from this script, uncomment the following
183
+
184
+ if __name__ == "__main__": #to run it
185
+ image = torch.rand(2, 4, 240, 240) #specify your image: batch size, Channel, height, width
186
+ model = GP_UNet(in_channels=4, n_classes=3, depth=4, wf=6, up_mode="upsample_Sinc", Relu = "Relu") #Initialize the model, up_mode = "upconv" or "upsample1" == interpolate mode Bilinear or "upsample" == interpolate mode sinc
187
+ model.eval()
188
+ out = model(image)
189
+ print(model(image))
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "GPReconResNet"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "GPModelConfigs.GPReconResNetConfig",
7
+ "AutoModel": "GPModels.GPReconResNet"
8
+ },
9
+ "do_batchnorm": false,
10
+ "forwardV": 0,
11
+ "in_channels": 1,
12
+ "is3D": false,
13
+ "is_relu_leaky": true,
14
+ "model_type": "GPReconResNet",
15
+ "n_classes": 3,
16
+ "out_act": "None",
17
+ "post_interp_convtrans": false,
18
+ "res_blocks": 14,
19
+ "res_drop_prob": 0.5,
20
+ "starting_nfeatures": 64,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.44.2",
23
+ "updown_blocks": 2,
24
+ "upinterp_algo": "sinc"
25
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f86cb2396ae9c392ad99f6de61459f8764a368a3073a1244c780dfecb05ced54
3
+ size 69063232