update files
Browse files
- README.md +53 -17
- freqfusion.py → densefusion.py +68 -75
- get_depth_normap.py +8 -8
- model.py +8 -8
- test_shadow.py +8 -45
- utils/model_utils.py +3 -3
README.md
CHANGED
<h1 align="center">[ACMMM 2025] DenseSR: Image Shadow Removal as Dense Prediction</h1>
<p align="center">Yu-Fan Lin<sup>1</sup>, Chia-ming Lee<sup>1</sup>, Chih-Chung Hsu<sup>2</sup></p>
<p align="center"><sup>1</sup>National Cheng Kung University <sup>2</sup>National Yang Ming Chiao Tung University</p>

<div align="center">

[arXiv](https://www.arxiv.org/abs/2507.16472)

</div>

<details>
<summary>Abstract</summary>
Shadows are a common factor degrading image quality. Single-image shadow removal (SR), particularly under challenging indirect illumination, is hampered by non-uniform content degradation and inherent ambiguity. Consequently, traditional methods often fail to simultaneously recover intra-shadow details and maintain sharp boundaries, resulting in inconsistent restoration and blurring that negatively affect both downstream applications and the overall viewing experience. To overcome these limitations, we propose DenseSR, which approaches the problem from a dense prediction perspective to emphasize restoration quality. The framework uniquely synergizes two key strategies: (1) deep scene understanding guided by geometric-semantic priors to resolve ambiguity and implicitly localize shadows, and (2) high-fidelity restoration via a novel Dense Fusion Block (DFB) in the decoder. The DFB employs adaptive component processing, using an Adaptive Content Smoothing Module (ACSM) for consistent appearance and a Texture-Boundary Recuperation Module (TBRM) for fine textures and sharp boundaries, thereby directly tackling the inconsistent restoration and blurring issues. These purposefully processed components are effectively fused, yielding an optimized feature representation that preserves both consistency and fidelity. Extensive experimental results demonstrate the merits of our approach over existing methods.
</details>

## ⭐ Citation
If you find this project useful, please consider citing us and giving us a star.
```bash
@misc{lin2025densesrimageshadowremoval,
      title={DenseSR: Image Shadow Removal as Dense Prediction},
      author={Yu-Fan Lin and Chia-Ming Lee and Chih-Chung Hsu},
      year={2025},
      eprint={2507.16472},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2507.16472},
}
```

## 🌱 Environments
```bash
conda create -n ntire_shadow python=3.9 -y
...
pip install -r requirements.txt
```

## 📂 Folder Structure
You can download the WSRD dataset from [here](https://github.com/fvasluianu97/WSRD-DNSR).
```bash
test_dir
├── origin   <- Put the shadow-affected images in this folder
│   ├── 0000.png
│   ├── 0001.png
│   ├── ...
├── depth
├── normal

output_dir
├── ...
```

## ✨ How to test?
1. Clone [Depth anything v2](https://github.com/DepthAnything/Depth-Anything-V2.git)

```bash
...
python get_depth_normap.py
```

Now the folder structure will be as below (a quick sanity check for these generated files is sketched after this list):
```bash
test_dir
├── origin
│   ├── 0000.png
│   ├── 0001.png
│   ├── ...
├── depth
│   ├── 0000.npy
│   ├── 0001.npy
│   ├── ...
├── normal
│   ├── 0000.npy
│   ├── 0001.npy
│   ├── ...

output_dir
├── ...
```

4. Clone [DINOv2](https://github.com/facebookresearch/dinov2.git)
```bash
git clone https://github.com/facebookresearch/dinov2.git
```

5. Download the [shadow removal weights](https://drive.google.com/file/d/1of3KLSVhaXlsX3jasuwdPKBwb4O4hGZD/view?usp=drive_link)

```bash
gdown 1of3KLSVhaXlsX3jasuwdPKBwb4O4hGZD
```

6. Run ```run_test.sh``` to get inference results.

```bash
bash run_test.sh
```
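Before moving on to step 4, you can verify that the generated maps look sane. A minimal check, assuming the file names from the tree above and that the depth maps are the uint8 arrays saved by `get_depth_normap.py` (the exact dtype/shape of the normal maps is not shown in this commit):

```python
import numpy as np

depth = np.load('test_dir/depth/0000.npy')
normal = np.load('test_dir/normal/0000.npy')
print(depth.shape, depth.dtype)    # (H, W), uint8 relative depth in [0, 255]
print(normal.shape, normal.dtype)  # format produced by calculate_normal_map()
```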

## 📰 News
✔ 2025/08/11 Release WSRD pretrained model

✔ 2025/08/11 Release inference code

✔ 2025/07/05 Paper accepted by ACMMM'25

## 🛠️ TODO
◻ Release training code

◻ Release other pretrained models

## 📜 License
This code repository is released under the [MIT License](https://github.com/VanLinLin/NTIRE25_Shadow_Removal?tab=MIT-1-ov-file#readme).
freqfusion.py → densefusion.py
RENAMED
...
import warnings
import numpy as np


def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    assert distribution in ['uniform', 'normal']
    if hasattr(module, 'weight') and module.weight is not None:
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def carafe(x, normed_mask, kernel_size, group=1, up=1):
    # Content-aware reassembly: every output location is a weighted sum of a
    # kernel_size x kernel_size neighborhood, with the weights read from normed_mask.
    b, c, h, w = x.shape
    _, m_c, m_h, m_w = normed_mask.shape
    assert m_h == up * h
    assert m_w == up * w
    pad = kernel_size // 2
    pad_x = F.pad(x, pad=[pad] * 4, mode='reflect')
    unfold_x = F.unfold(pad_x, kernel_size=(kernel_size, kernel_size), stride=1, padding=0)
    unfold_x = unfold_x.reshape(b, c * kernel_size * kernel_size, h, w)
    unfold_x = F.interpolate(unfold_x, scale_factor=up, mode='nearest')
    unfold_x = unfold_x.reshape(b, c, kernel_size * kernel_size, m_h, m_w)
    normed_mask = normed_mask.reshape(b, 1, kernel_size * kernel_size, m_h, m_w)
    res = unfold_x * normed_mask
    res = res.sum(dim=2).reshape(b, c, m_h, m_w)
    return res


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    ...

def resize(input,
           ...):
    ...
    return F.interpolate(input, size, scale_factor, mode, align_corners)


def hamming2D(M, N):
    hamming_x = np.hamming(M)
    hamming_y = np.hamming(N)
    hamming_2d = np.outer(hamming_x, hamming_y)
    return hamming_2d


class DesneFusion(nn.Module):
    def __init__(self,
                 hr_channels,
                 lr_channels,
                 ...
                 compressed_channels=64,
                 align_corners=False,
                 upsample_mode='nearest',
                 feature_resample=False,
                 feature_resample_group=4,
                 comp_feat_upsample=True,
                 use_high_pass=True,
                 use_low_pass=True,
                 hr_residual=True,
                 semi_conv=True,
                 hamming_window=True,
                 feature_resample_norm=True,
                 **kwargs):
        super().__init__()
        ...
        self.compressed_channels = compressed_channels
        self.hr_channel_compressor = nn.Conv2d(hr_channels, self.compressed_channels, 1)
        self.lr_channel_compressor = nn.Conv2d(lr_channels, self.compressed_channels, 1)
        self.content_encoder = nn.Conv2d(
            self.compressed_channels,
            lowpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
            self.encoder_kernel,
            ...)
        ...
        self.register_buffer('hamming_lowpass', torch.FloatTensor([1.0]))
        self.register_buffer('hamming_highpass', torch.FloatTensor([1.0]))
        self.init_weights()
        self.intermediate_results = {}

    def init_weights(self):
        for m in self.modules():
            ...

    ...
        return self._forward(hr_feat, lr_feat)

    def _forward(self, hr_feat, lr_feat):
        # <<< only modification: store intermediate features without affecting the computation >>>

        # clear at the start of every forward pass so stale results are not kept
        self.intermediate_results.clear()

        # 1. store the raw inputs
        self.intermediate_results['hr_feat_before'] = hr_feat.clone()
        self.intermediate_results['lr_feat_before'] = lr_feat.clone()

        compressed_hr_feat = self.hr_channel_compressor(hr_feat)
        compressed_lr_feat = self.lr_channel_compressor(lr_feat)
        if self.semi_conv:
            ...
        mask_hr = self.content_encoder2(compressed_x)

        mask_lr = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)

        # 2. store the feature after low-pass processing
        lr_feat_after = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
        self.intermediate_results['lr_feat_after'] = lr_feat_after.clone()

        if self.semi_conv:
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
        else:
            ...

        if self.use_high_pass:
            mask_hr = self.kernel_normalizer(mask_hr, self.highpass_kernel, hamming=self.hamming_highpass)
            hr_feat_hf = hr_feat - carafe(hr_feat, mask_hr, self.highpass_kernel, self.up_group, 1)
            self.intermediate_results['hr_feat_hf_component'] = hr_feat_hf.clone()
            if self.hr_residual:
                # print('using hr_residual')
                hr_feat = hr_feat_hf + hr_feat
            else:
                hr_feat = hr_feat_hf
            self.intermediate_results['hr_feat_after'] = hr_feat.clone()
        else:
            # when high-pass processing is skipped, store matching entries so the keys still exist
            final_hr_feat = hr_feat
            self.intermediate_results['hr_feat_hf_component'] = torch.zeros_like(final_hr_feat)
            self.intermediate_results['hr_feat_after'] = final_hr_feat.clone()

        if self.feature_resample:
            # print(lr_feat.shape)
            lr_feat = self.dysampler(hr_x=compressed_hr_feat,
                                     lr_x=compressed_lr_feat, feat2sample=lr_feat)
            self.intermediate_results['lr_feat_after'] = lr_feat.clone()  # overwrite when the dysampler is used

        return mask_lr, hr_feat, lr_feat


class LocalSimGuidedSampler(nn.Module):
    """
    offset generator in DesneFusion
    """
    def __init__(self, in_channels, scale=2, style='lp', groups=4, use_direct_scale=True, kernel_size=1, local_window=3, sim_type='cos', norm=True, direction_feat='sim_concat'):
        super().__init__()
        ...


if __name__ == '__main__':
    hr_feat = torch.rand(1, 128, 512, 512)
    lr_feat = torch.rand(1, 128, 256, 256)
    model = DesneFusion(hr_channels=128, lr_channels=128)
    mask_lr, hr_feat, lr_feat = model(hr_feat=hr_feat, lr_feat=lr_feat)
    print(mask_lr.shape)
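Two things above are worth seeing in action: the `carafe` reassembly itself, and the `intermediate_results` dictionary that `_forward` now fills. A minimal sketch continuing the `__main__` example (the hand-made softmax mask only stands in for what `kernel_normalizer` produces, and the listed keys assume the defaults `use_high_pass=True`, `feature_resample=False`):

```python
import torch
import torch.nn.functional as F

# 1) carafe on its own: reassemble a (1, 8, 16, 16) map into (1, 8, 32, 32)
x = torch.rand(1, 8, 16, 16)
k, up = 5, 2
mask = F.softmax(torch.rand(1, k * k, up * 16, up * 16), dim=1)  # k*k weights per output pixel
print(carafe(x, mask, kernel_size=k, group=1, up=up).shape)      # torch.Size([1, 8, 32, 32])

# 2) the stored intermediates after the forward pass run in __main__ above
for name, feat in model.intermediate_results.items():
    print(name, tuple(feat.shape))
# hr_feat_before / hr_feat_hf_component / hr_feat_after: (1, 128, 512, 512)
# lr_feat_before: (1, 128, 256, 256); lr_feat_after: (1, 128, 512, 512) after 2x reassembly
```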
get_depth_normap.py
CHANGED
...
def generate_depth_maps(source_root, model_path):
    source_root = Path(source_root)
    origin = source_root / 'origin'
    to_depth_list = [origin]

    model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).cuda()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    depth_path = source_root / 'depth'
    depth_path.mkdir(parents=True, exist_ok=True)

    with torch.inference_mode():
        for to_depth_item in to_depth_list:
            folder_name = to_depth_item.stem
            dst_path = depth_path

            dst_path.mkdir(parents=True, exist_ok=True)

            bar = tqdm(to_depth_item.glob('*'))

            for image_path in bar:
                try:
                    ...
                    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
                    depth = depth.astype(np.uint8)

                    np.save(f'{dst_path}/{image_path.stem}.npy', depth)

                except Exception as e:
                    print(e)
                    continue

    return depth_path


def calculate_normal_map(img_path: Path, ksize=5):
    ...
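Note that the depth values written above are min-max normalized per image and cast to uint8, so each saved .npy file holds relative depth in [0, 255] rather than metric depth. A small numeric illustration of that normalization step:

```python
import numpy as np

depth = np.array([[0.2, 0.5], [0.9, 1.7]], dtype=np.float32)  # raw relative depth from the network
norm = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
print(norm.astype(np.uint8))  # [[  0  51]
                              #  [119 255]]
```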
model.py
CHANGED
...
import math
from utils import grid_sample

from densefusion import DesneFusion

#########################################

...

class DenseSR(nn.Module):
    def __init__(self, img_size=256, in_chans=3,
                 embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
                 win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 ...):
        ...
        self.relu = nn.LeakyReLU()
        self.apply(self._init_weights)

        self.densefusion1 = DesneFusion(hr_channels=256,
                                        lr_channels=512)

        self.densefusion2 = DesneFusion(hr_channels=128,
                                        lr_channels=256)

        self.densefusion3 = DesneFusion(hr_channels=64,
                                        lr_channels=128)

    def _init_weights(self, m):
        ...

    ...  # decoder forward path, excerpted

        deconv0_B_C_H_W = deconv0.view(deconv0.shape[0], int(deconv0.shape[1]**0.5), int(deconv0.shape[1]**0.5), 256).permute(0, 3, 1, 2)
        # print(f'1.{deconv0_B_C_H_W.shape=}') # 1, 256, 64, 64

        _, deconv0_B_C_H_W, lr_feat = self.densefusion1(hr_feat=deconv0_B_C_H_W, lr_feat=conv3_B_C_H_W) # 1, 256, 64, 64 & 1, 512, 32, 32
        # print(f'1.{deconv0.shape=}, {lr_feat.shape=}') # deconv0.shape=torch.Size([1, 256, 64, 64]), lr_feat.shape=torch.Size([1, 512, 64, 64])

        deconv0 = deconv0_B_C_H_W.view(deconv0_B_C_H_W.shape[0], 256, -1).permute(0, 2, 1)
        ...

        deconv1_B_C_H_W = deconv1.view(deconv1.shape[0], int(deconv1.shape[1]**0.5), int(deconv1.shape[1]**0.5), 128).permute(0, 3, 1, 2)
        # print(f'2.{deconv1_B_C_H_W.shape=}') # 1, 128, 128, 128

        _, deconv1_B_C_H_W, lr_feat = self.densefusion2(hr_feat=deconv1_B_C_H_W, lr_feat=deconv0_B_C_H_W) # 1, 128, 128, 128 & 1, 256, 64, 64
        # print(f'2.{deconv1_B_C_H_W.shape=}, {lr_feat.shape=}') # hr_feat.shape=torch.Size([1, 128, 128, 128]), lr_feat.shape=torch.Size([1, 256, 128, 128])
        ...

        deconv2_B_C_H_W = deconv2.view(deconv2.shape[0], int(deconv2.shape[1]**0.5), int(deconv2.shape[1]**0.5), 64).permute(0, 3, 1, 2)
        # print(f'3.{deconv2_B_C_H_W.shape=}')

        _, deconv2_B_C_H_W, lr_feat = self.densefusion3(hr_feat=deconv2_B_C_H_W, lr_feat=deconv1_B_C_H_W) # 1, 64, 256, 256 & 1, 128, 128, 128
        # print('*'*5, f'3.{deconv2_B_C_H_W.shape=}, {lr_feat.shape=}')
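The view/permute pairs above are just the round trip between the transformer's token layout and the spatial layout DesneFusion expects. A minimal shape walk-through with the sizes from the first fusion stage (illustrative only):

```python
import torch

tokens = torch.rand(1, 64 * 64, 256)                        # (B, L, C) decoder tokens, L = H * W
side = int(tokens.shape[1] ** 0.5)                          # 64
fmap = tokens.view(1, side, side, 256).permute(0, 3, 1, 2)  # (1, 256, 64, 64) fed to densefusion1
tokens_back = fmap.view(1, 256, -1).permute(0, 2, 1)        # (1, 4096, 256) back to tokens
print(fmap.shape, tokens_back.shape)
```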
test_shadow.py
CHANGED
...
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import random
from utils.loader import get_test_data
import utils
import torch.distributed as dist
from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss
...
parser.add_argument('--input_dir', default='test_dir',
                    type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./output_dir',
                    type=str, help='Directory for results')
parser.add_argument('--weights', default='best_model_densefusion.pth',
                    type=str, help='Path to weights')
parser.add_argument('--arch', type=str, default='DenseSR', help='architecture')
parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader')
parser.add_argument('--save_images', action='store_true', default=False, help='Save denoised images in result directory')
parser.add_argument('--cal_metrics', action='store_true', default=False, help='Measure denoised images with GT')
...

class SlidingWindowInference:
    ...
        self.img_multiple_of = img_multiple_of

    def _pad_input(self, x, h_pad, w_pad):
        return F.pad(x, (0, w_pad, 0, h_pad), 'reflect')

    def __call__(self, model, input_, point, normal, dino_net, device):
        original_height, original_width = input_.shape[2], input_.shape[3]

        H = max(self.window_size,
                ((original_height + self.img_multiple_of - 1) // self.img_multiple_of) * self.img_multiple_of)
        W = max(self.window_size,
                ((original_width + self.img_multiple_of - 1) // self.img_multiple_of) * self.img_multiple_of)

        padh = H - original_height
        padw = W - original_width

        # Pad all inputs
        input_pad = self._pad_input(input_, padh, padw)
        point_pad = self._pad_input(point, padh, padw)
        normal_pad = self._pad_input(normal, padh, padw)

        if original_height <= self.window_size and original_width <= self.window_size:

            DINO_patch_size = 14
            h_size = H * DINO_patch_size // 8
            w_size = W * DINO_patch_size // 8

            UpSample_window = torch.nn.UpsamplingBilinear2d(size=(h_size, w_size))

            with torch.no_grad():
                input_DINO = UpSample_window(input_pad)
                dino_features = dino_net.module.get_intermediate_layers(input_DINO, 4, True)

            # Model inference
            with torch.amp.autocast(device_type='cuda'):
                restored = model(input_pad, dino_features, point_pad, normal_pad)

            # Crop back to original size
            ...

        stride = self.window_size - self.overlap
        h_steps = (H - self.window_size + stride - 1) // stride + 1
        w_steps = (W - self.window_size + stride - 1) // stride + 1

        # Create output tensor and counter
        output = torch.zeros_like(input_pad)
        ...
                point_window = point_pad[:, :, h_start:h_end, w_start:w_end]
                normal_window = normal_pad[:, :, h_start:h_end, w_start:w_end]

                # For DINO features
                DINO_patch_size = 14
                h_size = self.window_size * DINO_patch_size // 8
                ...
                dino_features = dino_net.module.get_intermediate_layers(input_DINO, 4, True)

                # Model inference
                with torch.amp.autocast(device_type='cuda'):
                    restored = model(input_window, dino_features, point_window, normal_window)

                # Create weight mask for smooth transition
                ...

g = torch.Generator()
g.manual_seed(1234)

torch.backends.cudnn.benchmark = True

######### Model ###########
model_restoration = utils.get_arch(args)
model_restoration.to(device)
...

with torch.no_grad():
    ...
    ssim_val_rgb_list = []
    rmse_val_rgb_list = []
    for ii, data_test in enumerate(tqdm(test_loader), 0):
        rgb_noisy = data_test[1].to(device)
        point = data_test[2].to(device)
        normal = data_test[3].to(device)
        filenames = data_test[4]

        sliding_window = SlidingWindowInference(
            window_size=512,
            overlap=64,
            img_multiple_of=8 * args.win_size
        )

        with torch.amp.autocast(device_type='cuda'):
            rgb_restored = sliding_window(
                model=model_restoration,
                input_=rgb_noisy,
                ...
            )

        rgb_restored = torch.clamp(rgb_restored, 0.0, 1.0)
        rgb_restored = torch.clamp(rgb_restored, 0, 1).cpu().numpy().squeeze().transpose((1, 2, 0))
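To make the padding and tiling arithmetic above concrete, here is the same computation evaluated by hand for an example 1440x1920 input with the settings used in the test loop (window 512, overlap 64, multiple of 8 * win_size = 64):

```python
window_size, overlap, img_multiple_of = 512, 64, 8 * 8
h, w = 1440, 1920   # example input size

H = max(window_size, ((h + img_multiple_of - 1) // img_multiple_of) * img_multiple_of)  # 1472
W = max(window_size, ((w + img_multiple_of - 1) // img_multiple_of) * img_multiple_of)  # 1920

stride = window_size - overlap                           # 448
h_steps = (H - window_size + stride - 1) // stride + 1   # 4
w_steps = (W - window_size + stride - 1) // stride + 1   # 5
print(H, W, h_steps * w_steps)                           # 1472 1920 20 windows per image
```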
utils/model_utils.py
CHANGED
...
    return lr

def get_arch(opt):
    from model import ShadowFormer, DenseSR
    arch = opt.arch

    print('You choose '+arch+'...')
    ...
        model_restoration = ShadowFormer(img_size=opt.train_ps,embed_dim=opt.embed_dim,
                                         win_size=opt.win_size,token_projection=opt.token_projection,
                                         token_mlp=opt.token_mlp)
    elif arch == 'DenseSR':
        model_restoration = DenseSR(img_size=opt.train_ps,embed_dim=opt.embed_dim,
                                    win_size=opt.win_size,token_projection=opt.token_projection,
                                    token_mlp=opt.token_mlp)
    else:
        ...
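For context, a minimal sketch of how this dispatcher is driven from the test script (attribute names match the constructor calls above; the token_projection / token_mlp values here are placeholders, not necessarily the repo's defaults):

```python
from types import SimpleNamespace
import utils

opt = SimpleNamespace(arch='DenseSR', train_ps=256, embed_dim=32, win_size=8,
                      token_projection='linear', token_mlp='leff')  # placeholder option values
model_restoration = utils.get_arch(opt)   # prints "You choose DenseSR..." and builds the model
```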