vedant-jumle committed
Commit 79b5829 · 1 Parent(s): 657d518

Initial CosAE release

Files changed (4)
  1. cosae/__init__.py +0 -0
  2. cosae/config.py +57 -0
  3. cosae/cosae.py +53 -0
  4. cosae/modules.py +267 -0
cosae/__init__.py ADDED
File without changes
cosae/config.py ADDED
@@ -0,0 +1,57 @@
from transformers import PretrainedConfig, PreTrainedModel


class CosAEConfig(PretrainedConfig):
    model_type = "cosae"

    def __init__(
        self,
        image_size: tuple[int, int] = (256, 256),
        # Encoder parameters
        in_channels: int = 3,
        hidden_dims: list[int] = (64, 128, 256, 512),
        num_res_blocks: int = 2,
        downsample_strides: list[int] = (2, 2, 2, 2),
        use_encoder_attention: bool = True,
        encoder_attention_heads: int = 8,
        encoder_attention_layers: int = 1,
        bottleneck_channels: int = 256,
        basis_size: int = 32,
        norm_type: str = "gn",  # "gn" (GroupNorm) or "ln" (LayerNorm)
        activation: str = "gelu",  # "gelu" or "silu"

        # Decoder parameters
        decoder_hidden_dim: int = 256,
        decoder_upsample_strides: list[int] = (2,),  # e.g. (2,) for one 2× upsample
        use_decoder_attention: bool = False,
        decoder_attention_heads: int = 8,
        decoder_attention_layers: int = 0,

        **kwargs,
    ):
        """
        Configuration for CosAEModel, including encoder, HCM, and decoder settings.
        """
        super().__init__(**kwargs)

        # Encoder settings
        self.in_channels = in_channels
        self.hidden_dims = list(hidden_dims)
        self.num_res_blocks = num_res_blocks
        self.downsample_strides = list(downsample_strides)
        self.use_encoder_attention = use_encoder_attention
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_attention_layers = encoder_attention_layers
        self.bottleneck_channels = bottleneck_channels
        self.basis_size = basis_size
        self.norm_type = norm_type
        self.activation = activation
        self.image_size = image_size

        # Decoder settings
        self.decoder_hidden_dim = decoder_hidden_dim
        self.decoder_upsample_strides = list(decoder_upsample_strides)
        self.use_decoder_attention = use_decoder_attention
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_attention_layers = decoder_attention_layers
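As a quick illustration of how this configuration might be used, here is a minimal sketch. It assumes the files in this commit are importable as a local `cosae` package; the `cosae-small` directory name is only an example. The config behaves like any other `PretrainedConfig`:

from cosae.config import CosAEConfig

# Default configuration: 256x256 RGB input, a (64, 128, 256, 512) encoder,
# a 256-channel bottleneck split into amplitude/phase halves, and 32x32 cosine bases.
config = CosAEConfig()
print(config.bottleneck_channels, config.basis_size)  # 256 32

# Fields can be overridden like any PretrainedConfig, e.g. a lighter encoder,
# and round-tripped through the usual save/load helpers.
small = CosAEConfig(hidden_dims=(32, 64, 128, 256), bottleneck_channels=128)
small.save_pretrained("cosae-small")                  # writes config.json
reloaded = CosAEConfig.from_pretrained("cosae-small")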
cosae/cosae.py ADDED
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from transformers import PreTrainedModel
from .modules import CosAEEncoder, CosAEDecoder, HarmonicConstructionModule
from .config import CosAEConfig


class CosAEModel(PreTrainedModel):
    config_class = CosAEConfig
    base_model_prefix = "cosae"

    def __init__(self, config: CosAEConfig):
        super().__init__(config)
        # 1) Encoder
        self.encoder = CosAEEncoder(config)

        # 2) Harmonic Construction Module
        # Derive P, the intended total downsampling factor, from the encoder:
        # the stem (stride-2 conv + stride-2 max-pool) contributes a factor of 4,
        # and the stages contribute the product of downsample_strides.
        stem_ds = 2 * 2
        P = stem_ds * math.prod(config.downsample_strides)
        # Basis size T = P // 2; with the default config this equals
        # config.basis_size (32), which is what the HCM actually receives.
        T = P // 2
        self.T = T
        self.hcm = HarmonicConstructionModule(
            bottleneck_channels=config.bottleneck_channels,
            basis_size=config.basis_size,
        )

        # 3) Decoder
        self.decoder = CosAEDecoder(config)

        # Initialize weights, etc.
        self.post_init()

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values: [B, C_in, H, W] (C_in = 3, or 9 if using FFT features)
        Returns:
            recon: [B, 3, H, W] reconstructed image
        """
        # Encode to get amplitudes & phases
        bottleneck = self.encoder(pixel_values)      # [B, 2c, H', W']
        amp, ph = torch.chunk(bottleneck, 2, dim=1)  # each [B, c, H', W']

        # Build harmonics
        harmonics = self.hcm(amp, ph)                # [B, c, H'*T, W'*T]

        # Decode to reconstruct
        recon = self.decoder(harmonics)              # [B, 3, H, W]
        return recon
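For a quick end-to-end check of the forward pass described above, here is a minimal smoke-test sketch (again assuming the local `cosae` package layout). The small `image_size` override is only there to keep a CPU run cheap, and the output spatial size depends on how the encoder strides, `basis_size`, and `decoder_upsample_strides` combine, so the shape is printed rather than asserted:

import torch

from cosae.config import CosAEConfig
from cosae.cosae import CosAEModel

# Small input so the example runs quickly on CPU (the default norm_type="gn" is size-agnostic).
config = CosAEConfig(image_size=(64, 64))
model = CosAEModel(config).eval()

# Random RGB batch; set in_channels=9 in the config if FFT features are concatenated.
pixel_values = torch.randn(1, config.in_channels, *config.image_size)

with torch.no_grad():
    recon = model(pixel_values)

print(recon.shape)  # [1, 3, H_out, W_out], determined by strides, basis_size, and upsampling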
cosae/modules.py ADDED
@@ -0,0 +1,267 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PretrainedConfig, PreTrainedModel
from .config import CosAEConfig

"""This code has partially been generated by ChatGPT"""


class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, norm_type="gn", activation="gelu"):
        super().__init__()
        Norm = nn.GroupNorm if norm_type == "gn" else nn.LayerNorm
        act = nn.GELU if activation == "gelu" else nn.SiLU

        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm1 = Norm(8, out_ch) if norm_type == "gn" else Norm(out_ch)
        self.act1 = act()

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm2 = Norm(8, out_ch) if norm_type == "gn" else Norm(out_ch)
        self.act2 = act()

        if in_ch != out_ch:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        identity = self.skip(x)
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = out + identity
        return self.act2(out)


class CosAEEncoder(PreTrainedModel):
    config_class = CosAEConfig
    base_model_prefix = "encoder"

    def __init__(self, config: CosAEConfig):
        super().__init__(config)
        c = config
        # Stem (stride-2 conv + stride-2 max-pool).
        # Note: the LayerNorm branch hard-codes 128x128 feature maps (i.e. a 256x256
        # input); the GroupNorm branch ("gn", the default) is size-agnostic.
        self.stem = nn.Sequential(
            nn.Conv2d(c.in_channels, c.hidden_dims[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.GroupNorm(8, c.hidden_dims[0]) if c.norm_type == "gn" else nn.LayerNorm([c.hidden_dims[0], 128, 128]),
            nn.GELU() if c.activation == "gelu" else nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Downsampling stages
        dims = c.hidden_dims
        self.stages = nn.ModuleList()
        in_ch = dims[0]
        for i, out_ch in enumerate(dims[1:]):
            blocks = []
            for _ in range(c.num_res_blocks):
                blocks.append(ResBlock(in_ch, out_ch, norm_type=c.norm_type, activation=c.activation))
                in_ch = out_ch
            # downsample conv
            # (the LayerNorm shape below is a placeholder; norm_type="gn" is the working path)
            blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=c.downsample_strides[i], padding=1, bias=False),
                    nn.GroupNorm(8, out_ch) if c.norm_type == "gn" else nn.LayerNorm([out_ch, -1, -1]),
                    nn.GELU() if c.activation == "gelu" else nn.SiLU(),
                )
            )
            self.stages.append(nn.Sequential(*blocks))

        # Optional global attention
        if c.use_encoder_attention:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dims[-1],
                nhead=c.encoder_attention_heads,
                dim_feedforward=dims[-1] * 4,
                activation=c.activation,
                batch_first=True,
            )
            self.attn = nn.TransformerEncoder(encoder_layer, num_layers=c.encoder_attention_layers)
        else:
            self.attn = None

        # Head: project to 2 * bottleneck_channels (amplitude and phase halves)
        self.head = nn.Conv2d(dims[-1], 2 * c.bottleneck_channels, kernel_size=1)

        # Initialize weights
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C_in, H, W]
        Returns:
            bottleneck: [B, 2c, H/P, W/P]
        """
        # Stem
        x = self.stem(x)
        # Stages
        for stage in self.stages:
            x = stage(x)
        # x: [B, dims[-1], H/P, W/P]
        # Optional attention
        if self.attn is not None:
            B, C, H, W = x.shape
            seq = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
            seq = self.attn(seq)
            x = seq.transpose(1, 2).view(B, C, H, W)
        # Head
        bottleneck = self.head(x)
        return bottleneck


class HarmonicConstructionModule(nn.Module):
    """
    Given:
      - amplitudes: Tensor of shape [B, c, H', W']
      - phases:     Tensor of shape [B, c, H', W']
    and learnable frequencies (u, v) of shape [c, 2],
    this module builds a [B, c, H'*T, W'*T] tensor of harmonics:
        H[b,k,i*T + x, j*T + y]
            = A[b,k,i,j] * cos( 2π/T * (u[k]*x + v[k]*y) - Φ[b,k,i,j] )
    """
    def __init__(self, bottleneck_channels: int, basis_size: int):
        """
        Args:
            bottleneck_channels: c, number of frequency components
            basis_size: T, size of each cosine basis (e.g. 32 or 64)
        """
        super().__init__()
        self.c = bottleneck_channels
        self.T = basis_size

        # Learnable frequencies in [0, T/2)
        self.freqs = nn.Parameter(
            torch.rand(self.c, 2) * (self.T / 2)
        )  # shape [c, 2] for (u, v)

        # Precompute the x, y grid of size [T, T]
        x = torch.arange(self.T, dtype=torch.float32)
        y = torch.arange(self.T, dtype=torch.float32)
        xs, ys = torch.meshgrid(x, y, indexing="ij")  # both shape [T, T]

        # Register as buffers so they move with .to(device)
        self.register_buffer("xs", xs)  # [T, T]
        self.register_buffer("ys", ys)  # [T, T]

    def forward(self, amplitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
        """
        Args:
            amplitude: [B, c, H', W']
            phase:     [B, c, H', W']
        Returns:
            harmonics: [B, c, H'*T, W'*T]
        """
        B, c, Hp, Wp = amplitude.shape
        assert c == self.c, "Channel mismatch"

        # 1) compute spatial_phase for each frequency: [c, T, T]
        #    2π/T * (u[k]*xs + v[k]*ys)
        u = self.freqs[:, 0].view(c, 1, 1)  # [c, 1, 1]
        v = self.freqs[:, 1].view(c, 1, 1)  # [c, 1, 1]
        spatial_phase = (2 * math.pi / self.T) * (u * self.xs + v * self.ys)
        # reshape for broadcasting to [1, c, 1, 1, T, T]
        spatial_phase = spatial_phase.view(1, c, 1, 1, self.T, self.T)

        # 2) prepare amplitude & phase maps:
        #    [B, c, Hp, Wp] → [B, c, Hp, Wp, 1, 1]
        A = amplitude.view(B, c, Hp, Wp, 1, 1)
        Φ = phase.view(B, c, Hp, Wp, 1, 1)

        # 3) compute argument and harmonic:
        #    arg = spatial_phase - Φ
        #    H   = A * cos(arg)
        arg = spatial_phase - Φ  # [B, c, Hp, Wp, T, T]
        H = A * torch.cos(arg)   # same shape

        # 4) tile out to full spatial size [B, c, Hp*T, Wp*T]:
        #    first permute to [B, c, Hp, T, Wp, T], then reshape
        H = H.permute(0, 1, 2, 4, 3, 5)  # [B, c, Hp, T, Wp, T]
        H = H.reshape(B, c, Hp * self.T, Wp * self.T)

        return H


class CosAEDecoder(PreTrainedModel):
    config_class = CosAEConfig
    base_model_prefix = "decoder"

    def __init__(self, config: CosAEConfig):
        super().__init__(config)
        c = config

        # 1×1 projection from HCM channels → decoder hidden dim
        self.proj = nn.Conv2d(
            c.bottleneck_channels,
            c.decoder_hidden_dim,
            kernel_size=1,
            bias=False,
        )
        # normalization + activation after proj
        # (as in the encoder, the LayerNorm shapes are placeholders; "gn" is the working path)
        Norm = nn.GroupNorm if c.norm_type == "gn" else nn.LayerNorm
        self.norm0 = Norm(8, c.decoder_hidden_dim) if c.norm_type == "gn" else Norm([c.decoder_hidden_dim, -1, -1])
        self.act0 = nn.GELU() if c.activation == "gelu" else nn.SiLU()

        # upsampling blocks
        self.upsamples = nn.ModuleList()
        for scale in c.decoder_upsample_strides:
            block = nn.Sequential(
                nn.Upsample(scale_factor=scale, mode="bilinear", align_corners=False),
                nn.Conv2d(c.decoder_hidden_dim, c.decoder_hidden_dim, kernel_size=3, padding=1, bias=False),
                Norm(8, c.decoder_hidden_dim) if c.norm_type == "gn" else Norm([c.decoder_hidden_dim, -1, -1]),
                nn.GELU() if c.activation == "gelu" else nn.SiLU(),
            )
            self.upsamples.append(block)

        # optional global attention in decoder
        if c.use_decoder_attention:
            enc_layer = nn.TransformerEncoderLayer(
                d_model=c.decoder_hidden_dim,
                nhead=c.decoder_attention_heads,
                dim_feedforward=c.decoder_hidden_dim * 4,
                activation=c.activation,
                batch_first=True,
            )
            self.attn = nn.TransformerEncoder(enc_layer, num_layers=c.decoder_attention_layers)
        else:
            self.attn = None

        # final conv to RGB
        self.final_conv = nn.Conv2d(
            c.decoder_hidden_dim,
            3,
            kernel_size=3,
            padding=1,
        )

        # initialize weights
        self.post_init()

    def forward(self, harmonics: torch.Tensor) -> torch.Tensor:
        """
        Args:
            harmonics: Tensor from HCM, shape [B, c, H*, W*]
        Returns:
            recon: Reconstructed image, shape [B, 3, H, W]
        """
        x = self.proj(harmonics)  # [B, hidden_dim, H*, W*]
        x = self.norm0(x)
        x = self.act0(x)

        # upsample to higher resolution
        for up in self.upsamples:
            x = up(x)  # scales H*, W* by the configured factor

        # optional global attention
        if self.attn is not None:
            B, C, H, W = x.shape
            seq = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
            seq = self.attn(seq)
            x = seq.transpose(1, 2).view(B, C, H, W)

        # final RGB projection
        recon = self.final_conv(x)
        return recon
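To make the HarmonicConstructionModule's tiling concrete, here is a tiny standalone sketch (again assuming the local `cosae` package layout). The channel count and basis size are deliberately small so the shape bookkeeping is easy to follow:

import torch
from cosae.modules import HarmonicConstructionModule

# Each of the c bottleneck channels carries one learned frequency pair (u, v).
# For every bottleneck position (i, j), the module stamps a T x T cosine patch
# A[b,k,i,j] * cos(2*pi/T * (u*x + v*y) - phi[b,k,i,j]), so an [B, c, H', W']
# amplitude/phase pair becomes a [B, c, H'*T, W'*T] harmonic map.
hcm = HarmonicConstructionModule(bottleneck_channels=4, basis_size=8)

amp = torch.rand(1, 4, 3, 3)                # [B, c, H', W']
phase = torch.rand(1, 4, 3, 3) * 2 * torch.pi
harmonics = hcm(amp, phase)

print(harmonics.shape)  # torch.Size([1, 4, 24, 24]) = [B, c, H'*8, W'*8]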