noeedc committed
Commit 7d97c60 · 1 Parent(s): be815d2

Add initial model configuration, implementation, and inference example

Files changed (4)
  1. config.json +9 -0
  2. dehazeformer.py +1018 -0
  3. inference_example.py +30 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "model_type": "dehazeformer",
+   "architectures": ["DehazeFormerMCTWrapper"],
+   "auto_map": {
+     "AutoModel": "dehazeformer.DehazeFormerMCTWrapper",
+     "AutoConfig": "dehazeformer.DehazeFormerConfig"
+   },
+   "trust_remote_code": true
+ }
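
With auto_map pointing the Auto classes at dehazeformer.py and trust_remote_code enabled, the checkpoint loads through the standard Transformers entry points (inference_example.py below does exactly this). A minimal sketch, where the local repository path is a placeholder:

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("path/to/this/repo", trust_remote_code=True)
    model = AutoModel.from_pretrained("path/to/this/repo", trust_remote_code=True)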
dehazeformer.py ADDED
@@ -0,0 +1,1018 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel
+ from transformers.configuration_utils import PretrainedConfig
+ from torchvision import transforms
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from torch.nn.init import _calculate_fan_in_and_fan_out
+ from timm.models.layers import trunc_normal_
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class RLN(nn.Module):
+     r"""Revised LayerNorm"""
+
+     def __init__(self, dim, eps=1e-5, detach_grad=False):
+         super(RLN, self).__init__()
+         self.eps = eps
+         self.detach_grad = detach_grad
+
+         self.weight = nn.Parameter(torch.ones((1, dim, 1, 1)))
+         self.bias = nn.Parameter(torch.zeros((1, dim, 1, 1)))
+
+         self.meta1 = nn.Conv2d(1, dim, 1)
+         self.meta2 = nn.Conv2d(1, dim, 1)
+
+         trunc_normal_(self.meta1.weight, std=0.02)
+         nn.init.constant_(self.meta1.bias, 1)
+
+         trunc_normal_(self.meta2.weight, std=0.02)
+         nn.init.constant_(self.meta2.bias, 0)
+
+     def forward(self, input):
+         mean = torch.mean(input, dim=(1, 2, 3), keepdim=True)
+         std = torch.sqrt(
+             (input - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.eps
+         )
+
+         normalized_input = (input - mean) / std
+
+         if self.detach_grad:
+             rescale, rebias = self.meta1(std.detach()), self.meta2(mean.detach())
+         else:
+             rescale, rebias = self.meta1(std), self.meta2(mean)
+
+         out = normalized_input * self.weight + self.bias
+         return out, rescale, rebias
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self, network_depth, in_features, hidden_features=None, out_features=None
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+
+         self.network_depth = network_depth
+
+         self.mlp = nn.Sequential(
+             nn.Conv2d(in_features, hidden_features, 1),
+             nn.ReLU(True),
+             nn.Conv2d(hidden_features, out_features, 1),
+         )
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Conv2d):
+             gain = (8 * self.network_depth) ** (-1 / 4)
+             fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
+             std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+             trunc_normal_(m.weight, std=std)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         return self.mlp(x)
+
+
+ def window_partition(x, window_size):
+     B, H, W, C = x.shape
+     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, C)
+     return windows
+
+
+ def window_reverse(windows, window_size, H, W):
+     B = int(windows.shape[0] / (H * W / window_size / window_size))
+     x = windows.view(
+         B, H // window_size, W // window_size, window_size, window_size, -1
+     )
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+     return x
+
+
+ def get_relative_positions(window_size):
+     coords_h = torch.arange(window_size)
+     coords_w = torch.arange(window_size)
+
+     coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
+     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+     relative_positions = (
+         coords_flatten[:, :, None] - coords_flatten[:, None, :]
+     )  # 2, Wh*Ww, Wh*Ww
+
+     relative_positions = relative_positions.permute(
+         1, 2, 0
+     ).contiguous()  # Wh*Ww, Wh*Ww, 2
+     relative_positions_log = torch.sign(relative_positions) * torch.log(
+         1.0 + relative_positions.abs()
+     )
+
+     return relative_positions_log
+
+
+ class WindowAttention(nn.Module):
+     def __init__(self, dim, window_size, num_heads):
+
+         super().__init__()
+         self.dim = dim
+         self.window_size = window_size  # Wh, Ww
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         relative_positions = get_relative_positions(self.window_size)
+         self.register_buffer("relative_positions", relative_positions)
+         self.meta = nn.Sequential(
+             nn.Linear(2, 256, bias=True),
+             nn.ReLU(True),
+             nn.Linear(256, num_heads, bias=True),
+         )
+
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, qkv):
+         B_, N, _ = qkv.shape
+
+         qkv = qkv.reshape(B_, N, 3, self.num_heads, self.dim // self.num_heads).permute(
+             2, 0, 3, 1, 4
+         )
+
+         q, k, v = (
+             qkv[0],
+             qkv[1],
+             qkv[2],
+         )  # make torchscript happy (cannot use tensor as tuple)
+
+         q = q * self.scale
+         attn = q @ k.transpose(-2, -1)
+
+         relative_position_bias = self.meta(self.relative_positions)
+         relative_position_bias = relative_position_bias.permute(
+             2, 0, 1
+         ).contiguous()  # nH, Wh*Ww, Wh*Ww
+         attn = attn + relative_position_bias.unsqueeze(0)
+
+         attn = self.softmax(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         network_depth,
+         dim,
+         num_heads,
+         window_size,
+         shift_size,
+         use_attn=False,
+         conv_type=None,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.head_dim = int(dim // num_heads)
+         self.num_heads = num_heads
+
+         self.window_size = window_size
+         self.shift_size = shift_size
+
+         self.network_depth = network_depth
+         self.use_attn = use_attn
+         self.conv_type = conv_type
+
+         if self.conv_type == "Conv":
+             self.conv = nn.Sequential(
+                 nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect"),
+                 nn.ReLU(True),
+                 nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect"),
+             )
+
+         if self.conv_type == "DWConv":
+             self.conv = nn.Conv2d(
+                 dim, dim, kernel_size=5, padding=2, groups=dim, padding_mode="reflect"
+             )
+
+         if self.conv_type == "DWConv" or self.use_attn:
+             self.V = nn.Conv2d(dim, dim, 1)
+             self.proj = nn.Conv2d(dim, dim, 1)
+
+         if self.use_attn:
+             self.QK = nn.Conv2d(dim, dim * 2, 1)
+             self.attn = WindowAttention(dim, window_size, num_heads)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Conv2d):
+             w_shape = m.weight.shape
+
+             if w_shape[0] == self.dim * 2:  # QK
+                 fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
+                 std = math.sqrt(2.0 / float(fan_in + fan_out))
+                 trunc_normal_(m.weight, std=std)
+             else:
+                 gain = (8 * self.network_depth) ** (-1 / 4)
+                 fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
+                 std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+                 trunc_normal_(m.weight, std=std)
+
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def check_size(self, x, shift=False):
+         _, _, h, w = x.size()
+         mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+         mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+
+         if shift:
+             x = F.pad(
+                 x,
+                 (
+                     self.shift_size,
+                     (self.window_size - self.shift_size + mod_pad_w) % self.window_size,
+                     self.shift_size,
+                     (self.window_size - self.shift_size + mod_pad_h) % self.window_size,
+                 ),
+                 mode="reflect",
+             )
+         else:
+             x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
+         return x
+
+     def forward(self, X):
+         B, C, H, W = X.shape
+
+         if self.conv_type == "DWConv" or self.use_attn:
+             V = self.V(X)
+
+         if self.use_attn:
+             QK = self.QK(X)
+             QKV = torch.cat([QK, V], dim=1)
+
+             # shift
+             shifted_QKV = self.check_size(QKV, self.shift_size > 0)
+             Ht, Wt = shifted_QKV.shape[2:]
+
+             # partition windows
+             shifted_QKV = shifted_QKV.permute(0, 2, 3, 1)
+             qkv = window_partition(
+                 shifted_QKV, self.window_size
+             )  # nW*B, window_size**2, C
+
+             attn_windows = self.attn(qkv)
+
+             # merge windows
+             shifted_out = window_reverse(
+                 attn_windows, self.window_size, Ht, Wt
+             )  # B H' W' C
+
+             # reverse cyclic shift
+             out = shifted_out[
+                 :,
+                 self.shift_size : (self.shift_size + H),
+                 self.shift_size : (self.shift_size + W),
+                 :,
+             ]
+             attn_out = out.permute(0, 3, 1, 2)
+
+             if self.conv_type in ["Conv", "DWConv"]:
+                 conv_out = self.conv(V)
+                 out = self.proj(conv_out + attn_out)
+             else:
+                 out = self.proj(attn_out)
+
+         else:
+             if self.conv_type == "Conv":
+                 out = self.conv(X)  # no attention and use conv, no projection
+             elif self.conv_type == "DWConv":
+                 out = self.proj(self.conv(V))
+
+         return out
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(
+         self,
+         network_depth,
+         dim,
+         num_heads,
+         mlp_ratio=4.0,
+         norm_layer=nn.LayerNorm,
+         mlp_norm=False,
+         window_size=8,
+         shift_size=0,
+         use_attn=True,
+         conv_type=None,
+     ):
+         super().__init__()
+         self.use_attn = use_attn
+         self.mlp_norm = mlp_norm
+
+         self.norm1 = norm_layer(dim) if use_attn else nn.Identity()
+         self.attn = Attention(
+             network_depth,
+             dim,
+             num_heads=num_heads,
+             window_size=window_size,
+             shift_size=shift_size,
+             use_attn=use_attn,
+             conv_type=conv_type,
+         )
+
+         self.norm2 = norm_layer(dim) if use_attn and mlp_norm else nn.Identity()
+         self.mlp = Mlp(network_depth, dim, hidden_features=int(dim * mlp_ratio))
+
+     def forward(self, x):
+         identity = x
+         if self.use_attn:
+             x, rescale, rebias = self.norm1(x)
+         x = self.attn(x)
+         if self.use_attn:
+             x = x * rescale + rebias
+         x = identity + x
+
+         identity = x
+         if self.use_attn and self.mlp_norm:
+             x, rescale, rebias = self.norm2(x)
+         x = self.mlp(x)
+         if self.use_attn and self.mlp_norm:
+             x = x * rescale + rebias
+         x = identity + x
+         return x
+
+
+ class BasicLayer(nn.Module):
+     def __init__(
+         self,
+         network_depth,
+         dim,
+         depth,
+         num_heads,
+         mlp_ratio=4.0,
+         norm_layer=nn.LayerNorm,
+         window_size=8,
+         attn_ratio=0.0,
+         attn_loc="last",
+         conv_type=None,
+     ):
+
+         super().__init__()
+         self.dim = dim
+         self.depth = depth
+
+         attn_depth = attn_ratio * depth
+
+         if attn_loc == "last":
+             use_attns = [i >= depth - attn_depth for i in range(depth)]
+         elif attn_loc == "first":
+             use_attns = [i < attn_depth for i in range(depth)]
+         elif attn_loc == "middle":
+             use_attns = [
+                 i >= (depth - attn_depth) // 2 and i < (depth + attn_depth) // 2
+                 for i in range(depth)
+             ]
+
+         # build blocks
+         self.blocks = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     network_depth=network_depth,
+                     dim=dim,
+                     num_heads=num_heads,
+                     mlp_ratio=mlp_ratio,
+                     norm_layer=norm_layer,
+                     window_size=window_size,
+                     shift_size=0 if (i % 2 == 0) else window_size // 2,
+                     use_attn=use_attns[i],
+                     conv_type=conv_type,
+                 )
+                 for i in range(depth)
+             ]
+         )
+
+     def forward(self, x):
+         for blk in self.blocks:
+             x = blk(x)
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None):
+         super().__init__()
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         if kernel_size is None:
+             kernel_size = patch_size
+
+         self.proj = nn.Conv2d(
+             in_chans,
+             embed_dim,
+             kernel_size=kernel_size,
+             stride=patch_size,
+             padding=(kernel_size - patch_size + 1) // 2,
+             padding_mode="reflect",
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class PatchUnEmbed(nn.Module):
+     def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None):
+         super().__init__()
+         self.out_chans = out_chans
+         self.embed_dim = embed_dim
+
+         if kernel_size is None:
+             kernel_size = 1
+
+         self.proj = nn.Sequential(
+             nn.Conv2d(
+                 embed_dim,
+                 out_chans * patch_size**2,
+                 kernel_size=kernel_size,
+                 padding=kernel_size // 2,
+                 padding_mode="reflect",
+             ),
+             nn.PixelShuffle(patch_size),
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class SKFusion(nn.Module):
+     def __init__(self, dim, height=2, reduction=8):
+         super(SKFusion, self).__init__()
+
+         self.height = height
+         d = max(int(dim / reduction), 4)
+
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.mlp = nn.Sequential(
+             nn.Conv2d(dim, d, 1, bias=False),
+             nn.ReLU(),
+             nn.Conv2d(d, dim * height, 1, bias=False),
+         )
+
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, in_feats):
+         B, C, H, W = in_feats[0].shape
+
+         in_feats = torch.cat(in_feats, dim=1)
+         in_feats = in_feats.view(B, self.height, C, H, W)
+
+         feats_sum = torch.sum(in_feats, dim=1)
+         attn = self.mlp(self.avg_pool(feats_sum))
+         attn = self.softmax(attn.view(B, self.height, C, 1, 1))
+
+         out = torch.sum(in_feats * attn, dim=1)
+         return out
+
+
+ class DehazeFormer(nn.Module):
+     def __init__(
+         self,
+         in_chans=3,
+         out_chans=4,
+         window_size=8,
+         embed_dims=[24, 48, 96, 48, 24],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[16, 16, 16, 8, 8],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[1 / 4, 1 / 2, 3 / 4, 0, 0],
+         conv_type=["DWConv", "DWConv", "DWConv", "DWConv", "DWConv"],
+         norm_layer=[RLN, RLN, RLN, RLN, RLN],
+     ):
+         super(DehazeFormer, self).__init__()
+
+         # setting
+         self.patch_size = 4
+         self.window_size = window_size
+         self.mlp_ratios = mlp_ratios
+
+         # split image into non-overlapping patches
+         self.patch_embed = PatchEmbed(
+             patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3
+         )
+
+         # backbone
+         self.layer1 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[0],
+             depth=depths[0],
+             num_heads=num_heads[0],
+             mlp_ratio=mlp_ratios[0],
+             norm_layer=norm_layer[0],
+             window_size=window_size,
+             attn_ratio=attn_ratio[0],
+             attn_loc="last",
+             conv_type=conv_type[0],
+         )
+
+         self.patch_merge1 = PatchEmbed(
+             patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]
+         )
+
+         self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)
+
+         self.layer2 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[1],
+             depth=depths[1],
+             num_heads=num_heads[1],
+             mlp_ratio=mlp_ratios[1],
+             norm_layer=norm_layer[1],
+             window_size=window_size,
+             attn_ratio=attn_ratio[1],
+             attn_loc="last",
+             conv_type=conv_type[1],
+         )
+
+         self.patch_merge2 = PatchEmbed(
+             patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]
+         )
+
+         self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)
+
+         self.layer3 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[2],
+             depth=depths[2],
+             num_heads=num_heads[2],
+             mlp_ratio=mlp_ratios[2],
+             norm_layer=norm_layer[2],
+             window_size=window_size,
+             attn_ratio=attn_ratio[2],
+             attn_loc="last",
+             conv_type=conv_type[2],
+         )
+
+         self.patch_split1 = PatchUnEmbed(
+             patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2]
+         )
+
+         assert embed_dims[1] == embed_dims[3]
+         self.fusion1 = SKFusion(embed_dims[3])
+
+         self.layer4 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[3],
+             depth=depths[3],
+             num_heads=num_heads[3],
+             mlp_ratio=mlp_ratios[3],
+             norm_layer=norm_layer[3],
+             window_size=window_size,
+             attn_ratio=attn_ratio[3],
+             attn_loc="last",
+             conv_type=conv_type[3],
+         )
+
+         self.patch_split2 = PatchUnEmbed(
+             patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3]
+         )
+
+         assert embed_dims[0] == embed_dims[4]
+         self.fusion2 = SKFusion(embed_dims[4])
+
+         self.layer5 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[4],
+             depth=depths[4],
+             num_heads=num_heads[4],
+             mlp_ratio=mlp_ratios[4],
+             norm_layer=norm_layer[4],
+             window_size=window_size,
+             attn_ratio=attn_ratio[4],
+             attn_loc="last",
+             conv_type=conv_type[4],
+         )
+
+         # merge non-overlapping patches into image
+         self.patch_unembed = PatchUnEmbed(
+             patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3
+         )
+
+     def check_image_size(self, x):
+         # NOTE: for I2I test
+         _, _, h, w = x.size()
+         mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
+         mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
+         x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
+         return x
+
+     def forward_features(self, x):
+         x = self.patch_embed(x)
+         x = self.layer1(x)
+         skip1 = x
+
+         x = self.patch_merge1(x)
+         x = self.layer2(x)
+         skip2 = x
+
+         x = self.patch_merge2(x)
+         x = self.layer3(x)
+         x = self.patch_split1(x)
+
+         x = self.fusion1([x, self.skip2(skip2)]) + x
+         x = self.layer4(x)
+         x = self.patch_split2(x)
+
+         x = self.fusion2([x, self.skip1(skip1)]) + x
+         x = self.layer5(x)
+         x = self.patch_unembed(x)
+         return x
+
+     def forward(self, x):
+         H, W = x.shape[2:]
+         x = self.check_image_size(x)
+
+         feat = self.forward_features(x)
+         K, B = torch.split(feat, (1, 3), dim=1)
+
+         x = K * x - B + x
+         x = x[:, :, :H, :W]
+         return x
+
+
+ def dehazeformer_t():
+     return DehazeFormer(
+         embed_dims=[24, 48, 96, 48, 24],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[4, 4, 4, 2, 2],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[0, 1 / 2, 1, 0, 0],
+         conv_type=["DWConv", "DWConv", "DWConv", "DWConv", "DWConv"],
+     )
+
+
+ def dehazeformer_s():
+     return DehazeFormer(
+         embed_dims=[24, 48, 96, 48, 24],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[8, 8, 8, 4, 4],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[1 / 4, 1 / 2, 3 / 4, 0, 0],
+         conv_type=["DWConv", "DWConv", "DWConv", "DWConv", "DWConv"],
+     )
+
+
+ def dehazeformer_b():
+     return DehazeFormer(
+         embed_dims=[24, 48, 96, 48, 24],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[16, 16, 16, 8, 8],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[1 / 4, 1 / 2, 3 / 4, 0, 0],
+         conv_type=["DWConv", "DWConv", "DWConv", "DWConv", "DWConv"],
+     )
+
+
+ def dehazeformer_d():
+     return DehazeFormer(
+         embed_dims=[24, 48, 96, 48, 24],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[32, 32, 32, 16, 16],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[1 / 4, 1 / 2, 3 / 4, 0, 0],
+         conv_type=["DWConv", "DWConv", "DWConv", "DWConv", "DWConv"],
+     )
+
+
+ def dehazeformer_w():
+     return DehazeFormer(
+         embed_dims=[48, 96, 192, 96, 48],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[16, 16, 16, 8, 8],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[1 / 4, 1 / 2, 3 / 4, 0, 0],
+         conv_type=["DWConv", "DWConv", "DWConv", "DWConv", "DWConv"],
+     )
+
+
+ def dehazeformer_m():
+     return DehazeFormer(
+         embed_dims=[24, 48, 96, 48, 24],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[12, 12, 12, 6, 6],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[1 / 4, 1 / 2, 3 / 4, 0, 0],
+         conv_type=["Conv", "Conv", "Conv", "Conv", "Conv"],
+     )
+
+
+ def dehazeformer_l():
+     return DehazeFormer(
+         embed_dims=[48, 96, 192, 96, 48],
+         mlp_ratios=[2.0, 4.0, 4.0, 2.0, 2.0],
+         depths=[16, 16, 16, 12, 12],
+         num_heads=[2, 4, 6, 1, 1],
+         attn_ratio=[1 / 4, 1 / 2, 3 / 4, 0, 0],
+         conv_type=["Conv", "Conv", "Conv", "Conv", "Conv"],
+     )
+
+
+ class DehazeFormerMCT(nn.Module):
+     def __init__(
+         self,
+         in_chans=3,
+         out_chans=3,
+         window_size=8,
+         embed_dims=[24, 48, 96, 48, 24],
+         mlp_ratios=[2.0, 2.0, 4.0, 2.0, 2.0],
+         depths=[4, 4, 8, 4, 4],
+         num_heads=[2, 4, 6, 4, 2],
+         attn_ratio=[1.0, 1.0, 1.0, 1.0, 1.0],
+         conv_type=["DWConv", "DWConv", "DWConv", "DWConv", "DWConv"],
+         norm_layer=[RLN, RLN, RLN, RLN, RLN],
+     ):
+         super(DehazeFormerMCT, self).__init__()
+
+         # setting
+         self.patch_size = 4
+         self.window_size = window_size
+         self.mlp_ratios = mlp_ratios
+
+         # split image into non-overlapping patches
+         self.patch_embed = PatchEmbed(
+             patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3
+         )
+
+         # backbone
+         self.layer1 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[0],
+             depth=depths[0],
+             num_heads=num_heads[0],
+             mlp_ratio=mlp_ratios[0],
+             norm_layer=norm_layer[0],
+             window_size=window_size,
+             attn_ratio=attn_ratio[0],
+             attn_loc="last",
+             conv_type=conv_type[0],
+         )
+
+         self.patch_merge1 = PatchEmbed(
+             patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]
+         )
+
+         self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)
+
+         self.layer2 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[1],
+             depth=depths[1],
+             num_heads=num_heads[1],
+             mlp_ratio=mlp_ratios[1],
+             norm_layer=norm_layer[1],
+             window_size=window_size,
+             attn_ratio=attn_ratio[1],
+             attn_loc="last",
+             conv_type=conv_type[1],
+         )
+
+         self.patch_merge2 = PatchEmbed(
+             patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]
+         )
+
+         self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)
+
+         self.layer3 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[2],
+             depth=depths[2],
+             num_heads=num_heads[2],
+             mlp_ratio=mlp_ratios[2],
+             norm_layer=norm_layer[2],
+             window_size=window_size,
+             attn_ratio=attn_ratio[2],
+             attn_loc="last",
+             conv_type=conv_type[2],
+         )
+
+         self.patch_split1 = PatchUnEmbed(
+             patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2]
+         )
+
+         assert embed_dims[1] == embed_dims[3]
+         self.fusion1 = SKFusion(embed_dims[3])
+
+         self.layer4 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[3],
+             depth=depths[3],
+             num_heads=num_heads[3],
+             mlp_ratio=mlp_ratios[3],
+             norm_layer=norm_layer[3],
+             window_size=window_size,
+             attn_ratio=attn_ratio[3],
+             attn_loc="last",
+             conv_type=conv_type[3],
+         )
+
+         self.patch_split2 = PatchUnEmbed(
+             patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3]
+         )
+
+         assert embed_dims[0] == embed_dims[4]
+         self.fusion2 = SKFusion(embed_dims[4])
+
+         self.layer5 = BasicLayer(
+             network_depth=sum(depths),
+             dim=embed_dims[4],
+             depth=depths[4],
+             num_heads=num_heads[4],
+             mlp_ratio=mlp_ratios[4],
+             norm_layer=norm_layer[4],
+             window_size=window_size,
+             attn_ratio=attn_ratio[4],
+             attn_loc="last",
+             conv_type=conv_type[4],
+         )
+
+         # merge non-overlapping patches into image
+         self.patch_unembed = PatchUnEmbed(
+             patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3
+         )
+
+     def forward(self, x, x_ref=None):
+         x = self.patch_embed(x)
+         if x_ref is not None:
+             x_ref = self.patch_embed(x_ref)
+             x = torch.cat([x, x_ref], dim=3)
+
+         x = self.layer1(x)
+         skip1 = x
+
+         x = self.patch_merge1(x)
+         x = self.layer2(x)
+         skip2 = x
+
+         x = self.patch_merge2(x)
+         x = self.layer3(x)
+         x = self.patch_split1(x)
+
+         x = self.fusion1([x, self.skip2(skip2)]) + x
+         x = self.layer4(x)
+         x = self.patch_split2(x)
+
+         x = self.fusion2([x, self.skip1(skip1)]) + x
+         x = self.layer5(x)
+         if x_ref is not None:
+             x, x_ref = torch.split(x, (x.shape[3] // 2, x.shape[3] // 2), dim=3)
+         x = self.patch_unembed(x)
+         return x
+
+
+ class dehazeformer_mct(nn.Module):
+     def __init__(self, rf_combine_type=None):
+         super(dehazeformer_mct, self).__init__()
+         self.ts = 256
+         self.l = 8
+
+         self.dims = 3 * 3 * self.l
+         self.rf_combine_type = rf_combine_type
+
+         ## Reference frame combination type if enabled
+         if self.rf_combine_type == 'concat-channel':
+             print('Loading Reference Frame model of type: Channel Concat!!')
+             self.basenet = DehazeFormerMCT(6, self.dims)
+         elif self.rf_combine_type == 'concat-spatial':
+             print('Loading Reference Frame model of type: Spatial Concat!!')
+             self.basenet = DehazeFormerMCT(3, self.dims)
+         else:  ## default
+             print('Loading default MCT model without reference frame')
+             self.basenet = DehazeFormerMCT(3, self.dims)
+
+     def get_coord(self, x):
+         B, _, H, W = x.size()
+
+         coordh, coordw = torch.meshgrid(
+             [torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)], indexing="ij"
+         )
+         coordh = coordh.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1)
+         coordw = coordw.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1)
+
+         return coordw.detach(), coordh.detach()
+
+     def mapping(self, x, param):
+         # curves
+         curve = torch.stack(torch.chunk(param, 3, dim=1), dim=1)
+         curve_list = list(torch.chunk(curve, 3, dim=2))
+
+         # grid: x, y, z -> w, h, d ~[-1 ,1]
+         x_list = list(torch.chunk(x.detach(), 3, dim=1))
+         coordw, coordh = self.get_coord(x)
+         coordh, coordw = coordh.to(device), coordw.to(device)
+         grid_list = [torch.stack([coordw, coordh, x_i], dim=4) for x_i in x_list]
+
+         # mapping
+         out = sum(
+             [
+                 F.grid_sample(curve_i, grid_i, "bilinear", "border", True)
+                 for curve_i, grid_i in zip(curve_list, grid_list)
+             ]
+         ).squeeze(2)
+
+         return out  # no Tanh is much better than using Tanh
+
+     def forward(self, x, ref=None):
+         # param input
+         x_d = F.interpolate(x, (self.ts, self.ts), mode='area')
+         if ref is not None:
+             r_d = F.interpolate(ref, (self.ts, self.ts), mode='area')
+
+         # Reference frame at input
+         if self.rf_combine_type == 'concat-channel' and ref is not None:
+             inputs = torch.cat([x_d, r_d], dim=1)
+             param = self.basenet(inputs)
+         elif self.rf_combine_type == 'concat-spatial' and ref is not None:
+             param = self.basenet(x_d, r_d)
+         else:  # default
+             param = self.basenet(x_d)
+
+         return self.mapping(x, param)
+
+ # Dehazeformer configuration class
+ class DehazeFormerConfig(PretrainedConfig):
+     model_type = "dehazeformer"
+
+     def __init__(
+         self,
+         rf_combine_type="concat-channel",
+         ts=256,
+         l=8,
+         **kwargs
+     ):
+         self.rf_combine_type = rf_combine_type
+         self.ts = ts
+         self.l = l
+         super().__init__(**kwargs)
+
+ class DehazeFormerMCTWrapper(PreTrainedModel):
+     config_class = DehazeFormerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = dehazeformer_mct(rf_combine_type=config.rf_combine_type)
+         self.normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
+
+     def preprocess(self, img):
+         """Preprocess input image to tensor format"""
+         if isinstance(img, Image.Image):
+             tensor = transforms.ToTensor()(img).unsqueeze(0)
+         elif isinstance(img, torch.Tensor):
+             tensor = img.unsqueeze(0) if img.dim() == 3 else img
+         else:
+             raise TypeError(f"Unsupported input type: {type(img)}. Expected PIL.Image or torch.Tensor.")
+         return self.normalize(tensor).to(self.device)
+
+     def forward(self, input_img, ref_img=None, **kwargs):
+         """
+         Forward pass for the DehazeFormer model
+
+         Args:
+             input_img: Input hazy image (PIL.Image or torch.Tensor)
+             ref_img: Reference frame image (PIL.Image or torch.Tensor)
+
+         Returns:
+             torch.Tensor: Dehazed output image
+         """
+         # Preprocess inputs
+         x = self.preprocess(input_img)
+
+         if ref_img is not None:
+             ref_x = self.preprocess(ref_img)
+
+             # Forward pass with reference frame
+             if self.model.rf_combine_type == 'concat-channel':
+                 # Pass original image and reference separately to the model
+                 # The model will handle the concatenation internally
+                 output = self.model(x, ref_x)
+             elif self.model.rf_combine_type == 'concat-spatial':
+                 # Spatial concatenation handled inside model
+                 output = self.model(x, ref_x)
+             else:
+                 # Default: no reference frame
+                 output = self.model(x)
+         else:
+             # Forward pass without reference frame
+             output = self.model(x)
+
+         # Denormalize output: [-1, 1] → [0, 1]
+         output = ((output + 1) / 2).clamp(0, 1)
+
+         # Remove batch dimension if single image
+         return output.squeeze(0) if output.size(0) == 1 else output
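
For a quick smoke test of the classes above without the checkpoint, the wrapper can be instantiated directly with randomly initialized weights. A minimal sketch, assuming it runs next to dehazeformer.py with timm installed; the model is moved to the same device that dehazeformer.py selects at import time, since mapping() builds its curve-sampling grid on that device:

    import torch
    from dehazeformer import DehazeFormerConfig, DehazeFormerMCTWrapper

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = DehazeFormerConfig(rf_combine_type="concat-channel")
    model = DehazeFormerMCTWrapper(config).to(device).eval()

    hazy = torch.rand(3, 256, 256)  # dummy hazy frame in [0, 1]
    ref = torch.rand(3, 256, 256)   # dummy reference frame in [0, 1]
    with torch.no_grad():
        out = model(hazy, ref)      # untrained weights, so the output is not meaningful
    print(out.shape)                # torch.Size([3, 256, 256])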
inference_example.py ADDED
@@ -0,0 +1,30 @@
+ from transformers import AutoModel
+ import torch
+ from PIL import Image
+ import os
+ from torchvision import transforms
+
+ # Change working directory to the script’s folder
+ # os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the model
+ model = AutoModel.from_pretrained("./claris_rf_channel", trust_remote_code=True)
+ model.to(device)
+ model.eval()
+
+ # Load input + reference frames
+ input_img = Image.open("claris_rf_channel/sample_img.png").convert("RGB")
+ ref_img = Image.open("claris_rf_channel/ref_img.png").convert("RGB")
+
+ # Inference
+ with torch.no_grad():
+     output = model(input_img, ref_img)
+
+ # Convert to PIL and save
+ output_pil = transforms.ToPILImage()(output.cpu())
+ output_pil.save("output_img_rfchannel.png")
+
+ print("Saved output as 'output_img_rfchannel.png'")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8cce89f744b2e8a7344e387dbaafd1c34ce3122fd05f349c2f7517d4d97534e2
+ size 5907859