init
- app.py +113 -0
- images/0801x4.png +0 -0
- images/0804x4.png +0 -0
- images/0809x4.png +0 -0
- images/lion.jpg +0 -0
- images/logo.png +0 -0
- requirements.txt +27 -0
- sam_diffsr/configs/base/config_base.yaml +41 -0
- sam_diffsr/configs/base/diffsr_base.yaml +41 -0
- sam_diffsr/configs/base/sr_base.yaml +11 -0
- sam_diffsr/configs/data/df2k4x.yaml +11 -0
- sam_diffsr/configs/data/df2k4x_sam.yaml +11 -0
- sam_diffsr/configs/diffsr_df2k4x.yaml +18 -0
- sam_diffsr/configs/rrdb/df2k4x_pretrain.yaml +14 -0
- sam_diffsr/configs/sam/sam_diffsr_df2k4x.yaml +26 -0
- sam_diffsr/models_sr/__init__.py +0 -0
- sam_diffsr/models_sr/commons.py +317 -0
- sam_diffsr/models_sr/diffsr_modules.py +177 -0
- sam_diffsr/models_sr/diffusion.py +291 -0
- sam_diffsr/models_sr/diffusion_sam.py +90 -0
- sam_diffsr/models_sr/module_util.py +58 -0
- sam_diffsr/tasks/__init__.py +0 -0
- sam_diffsr/tasks/infer.py +81 -0
- sam_diffsr/tasks/rrdb.py +68 -0
- sam_diffsr/tasks/rrdb_sam.py +49 -0
- sam_diffsr/tasks/srdiff.py +76 -0
- sam_diffsr/tasks/srdiff_df2k.py +119 -0
- sam_diffsr/tasks/srdiff_df2k_sam.py +211 -0
- sam_diffsr/tasks/trainer.py +346 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709283169.wangchengchengdeMacBook-Pro.local.99018.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709284054.wangchengchengdeMacBook-Pro.local.99188.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709284076.wangchengchengdeMacBook-Pro.local.99198.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709284101.wangchengchengdeMacBook-Pro.local.99211.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709284193.wangchengchengdeMacBook-Pro.local.99233.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709284415.wangchengchengdeMacBook-Pro.local.99289.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709284460.wangchengchengdeMacBook-Pro.local.99308.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709284491.wangchengchengdeMacBook-Pro.local.99315.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709285127.wangchengchengdeMacBook-Pro.local.785.0 +3 -0
- sam_diffsr/tb_logs/events.out.tfevents.1709285146.wangchengchengdeMacBook-Pro.local.901.0 +3 -0
- sam_diffsr/tools/caculate_iqa.py +136 -0
- sam_diffsr/tools/visualize_sam_mask.py +20 -0
- sam_diffsr/utils_sr/__init__.py +0 -0
- sam_diffsr/utils_sr/dataset.py +50 -0
- sam_diffsr/utils_sr/hparams.py +157 -0
- sam_diffsr/utils_sr/indexed_datasets.py +72 -0
- sam_diffsr/utils_sr/matlab_resize.py +181 -0
- sam_diffsr/utils_sr/plt_img.py +109 -0
- sam_diffsr/utils_sr/sr_utils.py +171 -0
- sam_diffsr/utils_sr/utils.py +269 -0
- sam_diffsr/weight/model_ckpt_steps_400000.ckpt +3 -0
app.py
ADDED
@@ -0,0 +1,113 @@
import importlib
from collections import OrderedDict
from pathlib import Path

import gradio as gr
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize


def get_img_data(img_PIL, hparams, sr_scale=4):
    img_lr = img_PIL.convert('RGB')
    img_lr = np.uint8(np.asarray(img_lr))

    h, w, c = img_lr.shape
    h, w = h * sr_scale, w * sr_scale
    h = h - h % (sr_scale * 2)
    w = w - w % (sr_scale * 2)
    h_l = h // sr_scale
    w_l = w // sr_scale

    img_lr = img_lr[:h_l, :w_l]

    to_tensor_norm = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
    img_lr, img_lr_up = [to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]

    img_lr = torch.unsqueeze(img_lr, dim=0)
    img_lr_up = torch.unsqueeze(img_lr_up, dim=0)

    return img_lr, img_lr_up


def load_checkpoint(model, ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    print(f'loading checkpoint from: {ckpt_path}')
    stat_dict = checkpoint['state_dict']['model']

    new_state_dict = OrderedDict()
    for k, v in stat_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # strip the `module.` prefix added by DataParallel
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    model.cuda()
    del checkpoint
    torch.cuda.empty_cache()


def model_init(ckpt_path):
    set_hparams()

    from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer

    trainer = trainer()

    trainer.build_model()
    load_checkpoint(trainer.model, ckpt_path)

    torch.backends.cudnn.benchmark = False

    return trainer


def image_infer(img_PIL):
    with torch.no_grad():
        trainer.model.eval()
        img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)

        img_lr = img_lr.to('cuda')
        img_lr_up = img_lr_up.to('cuda')

        img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)

        img_sr = img_sr.clamp(-1, 1)
        img_sr = trainer.tensor2img(img_sr)[0]
        img_sr = Image.fromarray(img_sr)

    return img_sr


# cheetah = os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg")

root_path = os.path.dirname(__file__)

cheetah = os.path.join(root_path, "images/lion.jpg")
print(cheetah)

demo = gr.Interface(image_infer, gr.Image(type="pil", value=cheetah), "image",
                    # flagging_options=["blurry", "incorrect", "other"],
                    examples=[
                        os.path.join(root_path, "images/0801x4.png"),
                        os.path.join(root_path, "images/0809x4.png"),
                        os.path.join(root_path, "images/0809x4.png"),
                    ]
                    )

if __name__ == "__main__":
    parent_path = Path(__file__).absolute().parent
    fill_root = os.path.abspath(parent_path)
    ckpt_path = os.path.join(fill_root, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
    trainer = model_init(ckpt_path)
    demo.launch()
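
The cropping arithmetic in get_img_data is easiest to check with concrete numbers. The snippet below is a standalone illustration (the sizes are made up, not taken from the repo): the LR input is trimmed so that the upscaled output size is divisible by sr_scale * 2.

# Standalone sketch of the crop arithmetic above (sizes are illustrative):
sr_scale = 4
h_lr, w_lr = 123, 250                      # arbitrary LR input size
h, w = h_lr * sr_scale, w_lr * sr_scale    # target SR size: 492 x 1000
h = h - h % (sr_scale * 2)                 # 488  -> divisible by 8
w = w - w % (sr_scale * 2)                 # 1000 -> already divisible by 8
h_l, w_l = h // sr_scale, w // sr_scale    # LR crop: 122 x 250
assert (h_l * sr_scale) % (sr_scale * 2) == 0
assert (w_l * sr_scale) % (sr_scale * 2) == 0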
images/0801x4.png
ADDED
images/0804x4.png
ADDED
images/0809x4.png
ADDED
images/lion.jpg
ADDED
images/logo.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,27 @@
torch
torchvision
Cython
matplotlib
tqdm
numpy
scipy
PyYAML
tensorboardX
tensorboard
scikit-learn
scikit-image
seaborn
pillow
opencv-contrib-python
einops
lpips
natsort
timm
openpyxl
kornia
xlwt==1.3.0
xlrd==1.2.0
pyiqa
rotary_embedding_torch
opencv-python>=4.8.0.76
opencv-python-headless>=4.5.5.64
sam_diffsr/configs/base/config_base.yaml
ADDED
@@ -0,0 +1,41 @@
# task
binary_data_dir: ''
work_dir: '' # experiment directory.
infer: false # infer
seed: 1234
debug: false
save_codes:
  - configs
  - models_sr
  - tasks
  - utils_sr

#############
# dataset
#############
ds_workers: 1
endless: false

#########
# train and eval
#########
print_nan_grads: false
load_ckpt: ''
save_best: true
num_ckpt_keep: 100
clip_grad_norm: 0
accumulate_grad_batches: 1
tb_log_interval: 100
num_sanity_val_steps: 5 # steps of validation at the beginning
check_val_every_n_epoch: 10
val_check_interval: 4000
valid_monitor_key: 'val_loss'
valid_monitor_mode: 'min'
max_epochs: 1000
max_updates: 600000
amp: false
batch_size: 32
eval_batch_size: 32
num_workers: 8
test_input_dir: ''
resume_from_checkpoint: 0
sam_diffsr/configs/base/diffsr_base.yaml
ADDED
@@ -0,0 +1,41 @@
base_config:
  - ./config_base.yaml
  - ./sr_base.yaml
# model
beta_schedule: cosine
beta_s: 0.008
beta_end: 0.02
hidden_size: 64
timesteps: 100
res: true
res_rescale: 2.0
up_input: false
use_wn: false
gn_groups: 0
use_rrdb: true
#rrdb_num_block: 8
#rrdb_num_feat: 32
rrdb_num_block: 17
rrdb_num_feat: 64
rrdb_ckpt: ''
unet_dim_mults: 1|2|2|4
clip_input: true
denoise_fn: unet
use_attn: false
aux_l1_loss: true
aux_ssim_loss: false
aux_percep_loss: false
loss_type: l1
pred_noise: true
clip_grad_norm: 10
weight_init: false
fix_rrdb: true

# train and eval
lr: 0.0002
decay_steps: 100000
accumulate_grad_batches: 1
style_interp: false
save_intermediate: false
show_training_process: false
print_arch: false
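
Note that unet_dim_mults is stored as a pipe-separated string rather than a YAML list. A minimal sketch of how such a field would be turned into the tuple the Unet constructor expects (the actual parse happens in the task code, which is not part of this commit):

unet_dim_mults = '1|2|2|4'                                    # value from this config
dim_mults = tuple(int(m) for m in unet_dim_mults.split('|'))
print(dim_mults)                                              # (1, 2, 2, 4)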
sam_diffsr/configs/base/sr_base.yaml
ADDED
@@ -0,0 +1,11 @@
base_config: ./config_base.yaml
data_interp: bicubic # bilinear | bicubic
data_augmentation: false
max_updates: 300000
batch_size: 16
eval_batch_size: 1
test_batch_size: 1
valid_steps: 3
num_sanity_val_steps: 3
test_save_png: false
gen_dir_name: ''
sam_diffsr/configs/data/df2k4x.yaml
ADDED
@@ -0,0 +1,11 @@
binary_data_dir: data/train/df2k4x
patch_size: 160
crop_size: 320
thresh_size: 160
test_crop_size: [ 2040, 2040 ]
test_thresh_size: 0
valid_steps: 4
num_sanity_val_steps: 4
eval_batch_size: 1
test_batch_size: 1
sr_scale: 4
sam_diffsr/configs/data/df2k4x_sam.yaml
ADDED
@@ -0,0 +1,11 @@
binary_data_dir: data/train/df2k4x_sam
patch_size: 160
crop_size: 320
thresh_size: 160
test_crop_size: [ 2040, 2040 ]
test_thresh_size: 0
valid_steps: 4
num_sanity_val_steps: 4
eval_batch_size: 1
test_batch_size: 1
sr_scale: 4
sam_diffsr/configs/diffsr_df2k4x.yaml
ADDED
@@ -0,0 +1,18 @@
base_config:
  - ./base/diffsr_base.yaml
  - ./data/df2k4x.yaml
trainer_cls: tasks.srdiff_df2k.SRDiffDf2k

# model
unet_dim_mults: 1|2|3|4
decay_steps: 200000

# train and test
batch_size: 64
max_updates: 400000

sam_config:
  cond_sam: False
  p_losses_sam: False
  p_sample_sam: False
  q_sample_sam: False
sam_diffsr/configs/rrdb/df2k4x_pretrain.yaml
ADDED
@@ -0,0 +1,14 @@
base_config:
  - ../sr_base.yaml
  - ../df2k4x.yaml
trainer_cls: tasks.rrdb.RRDBDf2kTask
# model
hidden_size: 64
lr: 0.0002
num_block: 17

# train and eval
max_updates: 100000
batch_size: 64
eval_batch_size: 1
valid_steps: 3
sam_diffsr/configs/sam/sam_diffsr_df2k4x.yaml
ADDED
@@ -0,0 +1,26 @@
base_config:
  - ../base/diffsr_base.yaml
  - ../data/df2k4x_sam.yaml
trainer_cls: tasks.srdiff_df2k_sam.SRDiffDf2k_sam

# model
unet_dim_mults: 1|2|3|4
decay_steps: 200000

# train and test
batch_size: 64
max_updates: 400000

rrdb_num_feat: 64

sam_config:
  cond_sam: False
  p_losses_sam: True
  mask_coefficient: True

sam_data_config:
  all_same_mask_to_zero: False
  normalize_01: False
  normalize_11: False

num_sanity_val_steps: 2
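
These configs chain together through base_config: this file pulls in diffsr_base.yaml and df2k4x_sam.yaml, which in turn pull in config_base.yaml and sr_base.yaml. A minimal sketch of the intended precedence, assuming child values override base values (the real merge lives in sam_diffsr/utils_sr/hparams.py, whose implementation is not shown in this commit):

import os
import yaml

def load_config(path):
    # Hypothetical recursive merge: base configs are applied first,
    # then keys from the including file override them.
    with open(path) as f:
        cfg = yaml.safe_load(f)
    bases = cfg.get('base_config', [])
    if isinstance(bases, str):
        bases = [bases]
    merged = {}
    for base in bases:
        base_path = os.path.join(os.path.dirname(path), base)
        merged.update(load_config(base_path))
    merged.update({k: v for k, v in cfg.items() if k != 'base_config'})
    return merged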
sam_diffsr/models_sr/__init__.py
ADDED
File without changes
sam_diffsr/models_sr/commons.py
ADDED
@@ -0,0 +1,317 @@
import math
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import Parameter


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


# building block modules

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        if groups == 0:
            self.block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim_out, 3),
                Mish()
            )
        else:
            self.block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim_out, 3),
                nn.GroupNorm(groups, dim_out),
                Mish()
            )

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=0, groups=8):
        super().__init__()
        if time_emb_dim > 0:
            self.mlp = nn.Sequential(
                Mish(),
                nn.Linear(time_emb_dim, dim_out)
            )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, cond=None):
        h = self.block1(x)
        if time_emb is not None:
            h += self.mlp(time_emb)[:, :, None, None]
        if cond is not None:
            h += cond
        h = self.block2(h)
        return h + self.res_conv(x)


class Upsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
        )

    def forward(self, x):
        return self.conv(x)


class Downsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, 2),
        )

    def forward(self, x):
        return self.conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.enable_torch_version = False
        if hasattr(F, "multi_head_attention_forward"):
            self.enable_torch_version = True
        else:
            self.enable_torch_version = False
        self.last_attn_probs = None

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
    ):
        """Input shape: [B, T, C]

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v,
            self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias,
            self.training, key_padding_mask, need_weights, attn_mask)
        attn_output = attn_output.transpose(0, 1)
        return attn_output, attn_output_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x
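
As a quick shape check of the timestep embedding defined above (a standalone snippet, assuming the SinusoidalPosEmb class above is in scope): it maps a batch of scalar timesteps to a [batch, dim] embedding whose two halves are sine and cosine features.

import torch

emb = SinusoidalPosEmb(64)            # dim should be even
t = torch.tensor([0, 10, 99])         # a batch of diffusion timesteps
out = emb(t)
print(out.shape)                      # torch.Size([3, 64]): sin half | cos half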
sam_diffsr/models_sr/diffsr_modules.py
ADDED
@@ -0,0 +1,177 @@
import functools

import torch
import torch.nn.functional as F
from torch import nn

from sam_diffsr.utils_sr.hparams import hparams
from .commons import Mish, SinusoidalPosEmb, RRDB, Residual, Rezero, LinearAttention
from .commons import ResnetBlock, Upsample, Block, Downsample
from .module_util import make_layer, initialize_weights


class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if hparams['sr_scale'] == 8:
            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x, get_fea=False):
        feas = []
        x = (x + 1) / 2
        fea_first = fea = self.conv_first(x)
        for l in self.RRDB_trunk:
            fea = l(fea)
            feas.append(fea)
        trunk = self.trunk_conv(fea)
        fea = fea_first + trunk
        feas.append(fea)

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if hparams['sr_scale'] == 8:
            fea = self.lrelu(self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea_hr = self.HRconv(fea)
        out = self.conv_last(self.lrelu(fea_hr))
        out = out.clamp(0, 1)
        out = out * 2 - 1
        if get_fea:
            return out, feas
        else:
            return out


class Unet(nn.Module):
    def __init__(self, dim, out_dim=None, dim_mults=(1, 2, 4, 8), cond_dim=32):
        super().__init__()
        dims = [3, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        groups = 0

        self.sam_config = hparams['sam_config']

        cond_proj_in = cond_dim * ((hparams['rrdb_num_block'] + 1) // 3)
        if self.sam_config['cond_sam']:
            # cond_proj_in += 1
            self.sam_conv = nn.Sequential(
                nn.Conv2d(dim + 1, dim, 1, 1, 0, bias=True),
                nn.Conv2d(dim, dim, 1, 1, 0, bias=True),
                nn.Conv2d(dim, dim, 1, 1, 0, bias=True)
            )
        else:
            self.sam_conv = None

        self.cond_proj = nn.ConvTranspose2d(cond_proj_in, dim, hparams['sr_scale'] * 2, hparams['sr_scale'],
                                            hparams['sr_scale'] // 2)

        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResnetBlock(dim_in, dim_out, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_out, dim_out, time_emb_dim=dim, groups=groups),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
        if hparams['use_attn']:
            self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_in, dim_in, time_emb_dim=dim, groups=groups),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        self.final_conv = nn.Sequential(
            Block(dim, dim, groups=groups),
            nn.Conv2d(dim, out_dim, 1)
        )

        if hparams['res'] and hparams['up_input']:
            self.up_proj = nn.Sequential(
                nn.ReflectionPad2d(1), nn.Conv2d(3, dim, 3),
            )
        if hparams['use_wn']:
            self.apply_weight_norm()
        if hparams['weight_init']:
            self.apply(initialize_weights)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                # print(f"| Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def forward(self, x, time, cond, img_lr_up, sam_mask=None):
        t = self.time_pos_emb(time)
        t = self.mlp(t)
        h = []

        cond = self.cond_proj(torch.cat(cond[2::3], 1))

        if self.sam_config['cond_sam']:
            cond = torch.cat([cond, sam_mask], 1)
            cond = self.sam_conv(cond)

        for i, (resnet, resnet2, downsample) in enumerate(self.downs):
            x = resnet(x, t)
            x = resnet2(x, t)
            if i == 0:
                x = x + cond
                if hparams['res'] and hparams['up_input']:
                    x = x + self.up_proj(img_lr_up)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        if hparams['use_attn']:
            x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        return self.final_conv(x)

    def make_generation_fast_(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)
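
One non-obvious wiring detail in Unet: cond_proj consumes every third RRDB feature map (cond[2::3] in forward), so its input channel count is cond_dim * ((rrdb_num_block + 1) // 3). With the values in this repo's configs (rrdb_num_block: 17, rrdb_num_feat: 64), and assuming the task code passes cond_dim equal to rrdb_num_feat when it builds the U-Net (that builder is not part of this commit), the numbers line up as in this standalone check:

rrdb_num_block = 17
rrdb_num_feat = 64                           # channels of each RRDB feature map
n_feats = rrdb_num_block + 1                 # 17 block outputs + 1 fused trunk output
selected = list(range(n_feats))[2::3]        # indices 2, 5, 8, 11, 14, 17
cond_proj_in = rrdb_num_feat * ((rrdb_num_block + 1) // 3)
assert len(selected) == (rrdb_num_block + 1) // 3 == 6
assert cond_proj_in == len(selected) * rrdb_num_feat == 384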
sam_diffsr/models_sr/diffusion.py
ADDED
@@ -0,0 +1,291 @@
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

from sam_diffsr.utils_sr.plt_img import plt_tensor_img
from .module_util import default
from sam_diffsr.utils_sr.sr_utils import SSIM, PerceptualLoss
from sam_diffsr.utils_sr.hparams import hparams


# gaussian diffusion trainer class
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    return betas


def get_beta_schedule(num_diffusion_timesteps, beta_schedule='linear', beta_start=0.0001, beta_end=0.02):
    if beta_schedule == 'quad':
        betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2
    elif beta_schedule == 'linear':
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == 'warmup10':
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == 'warmup50':
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == 'const':
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


class GaussianDiffusion(nn.Module):
    def __init__(self, denoise_fn, rrdb_net, timesteps=1000, loss_type='l1'):
        super().__init__()
        self.denoise_fn = denoise_fn
        # condition net
        self.rrdb = rrdb_net
        self.ssim_loss = SSIM(window_size=11)

        if hparams['beta_schedule'] == 'cosine':
            betas = cosine_beta_schedule(timesteps, s=hparams['beta_s'])
        if hparams['beta_schedule'] == 'linear':
            betas = get_beta_schedule(timesteps, beta_end=hparams['beta_end'])
        if hparams['res']:
            betas[-1] = 0.999

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
        self.sample_tqdm = True

        self.mask_coefficient = to_torch(np.sqrt(1. - alphas_cumprod) * betas)

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, noise_pred, clip_denoised: bool):
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)

        if clip_denoised:
            x_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_recon

    def forward(self, img_hr, img_lr, img_lr_up, t=None, *args, **kwargs):
        x = img_hr
        b, *_, device = *x.shape, x.device
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long() \
            if t is None else torch.LongTensor([t]).repeat(b).to(device)
        if hparams['use_rrdb']:
            if hparams['fix_rrdb']:
                self.rrdb.eval()
                with torch.no_grad():
                    rrdb_out, cond = self.rrdb(img_lr, True)
            else:
                rrdb_out, cond = self.rrdb(img_lr, True)
        else:
            rrdb_out = img_lr_up
            cond = img_lr
        x = self.img2res(x, img_lr_up)
        p_losses, x_tp1, noise_pred, x_t, x_t_gt, x_0 = self.p_losses(x, t, cond, img_lr_up, *args, **kwargs)
        ret = {'q': p_losses}
        if not hparams['fix_rrdb']:
            if hparams['aux_l1_loss']:
                ret['aux_l1'] = F.l1_loss(rrdb_out, img_hr)
            if hparams['aux_ssim_loss']:
                ret['aux_ssim'] = 1 - self.ssim_loss(rrdb_out, img_hr)
            if hparams['aux_percep_loss']:
                ret['aux_percep'] = self.percep_loss_fn[0](img_hr, rrdb_out)

        x_tp1 = self.res2img(x_tp1, img_lr_up)
        x_t = self.res2img(x_t, img_lr_up)
        x_t_gt = self.res2img(x_t_gt, img_lr_up)
        return ret, (x_tp1, x_t_gt, x_t), t

    def p_losses(self, x_start, t, cond, img_lr_up, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
        noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up)
        x_t_pred, x0_pred = self.p_sample(x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred)

        if self.loss_type == 'l1':
            loss = (noise - noise_pred).abs().mean()
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, noise_pred)
        elif self.loss_type == 'ssim':
            loss = (noise - noise_pred).abs().mean()
            loss = loss + (1 - self.ssim_loss(noise, noise_pred))
        else:
            raise NotImplementedError()
        return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        t_cond = (t[:, None, None, None] >= 0).float()
        t = t.clamp_min(0)
        return (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        ) * t_cond + x_start * (1 - t_cond)

    @torch.no_grad()
    def p_sample(self, x, t, cond, img_lr_up, noise_pred=None, clip_denoised=True, repeat_noise=False):
        if noise_pred is None:
            noise_pred = self.denoise_fn(x, t, cond=cond, img_lr_up=img_lr_up)
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance, x0_pred = self.p_mean_variance(
            x=x, t=t, noise_pred=noise_pred, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0_pred

    @torch.no_grad()
    def sample(self, img_lr, img_lr_up, shape, save_intermediate=False):
        device = self.betas.device
        b = shape[0]
        if not hparams['res']:
            t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
            img = self.q_sample(img_lr_up, t)
        else:
            img = torch.randn(shape, device=device)
        if hparams['use_rrdb']:
            rrdb_out, cond = self.rrdb(img_lr, True)
        else:
            rrdb_out = img_lr_up
            cond = img_lr
        it = reversed(range(0, self.num_timesteps))
        if self.sample_tqdm:
            it = tqdm(it, desc='sampling loop time step', total=self.num_timesteps)
        images = []
        for i in it:
            img, x_recon = self.p_sample(
                img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up)
            if save_intermediate:
                img_ = self.res2img(img, img_lr_up)
                x_recon_ = self.res2img(x_recon, img_lr_up)
                images.append((img_.cpu(), x_recon_.cpu()))
        img = self.res2img(img, img_lr_up)
        if save_intermediate:
            return img, rrdb_out, images
        else:
            return img, rrdb_out

    @torch.no_grad()
    def interpolate(self, x1, x2, img_lr, img_lr_up, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)
        if hparams['use_rrdb']:
            rrdb_out, cond = self.rrdb(img_lr, True)
        else:
            cond = img_lr

        assert x1.shape == x2.shape

        x1 = self.img2res(x1, img_lr_up)
        x2 = self.img2res(x2, img_lr_up)

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
            img, x_recon = self.p_sample(
                img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up)

        img = self.res2img(img, img_lr_up)
        return img

    def res2img(self, img_, img_lr_up, clip_input=None):
        if clip_input is None:
            clip_input = hparams['clip_input']
        if hparams['res']:
            if clip_input:
                img_ = img_.clamp(-1, 1)
            img_ = img_ / hparams['res_rescale'] + img_lr_up
        return img_

    def img2res(self, x, img_lr_up, clip_input=None):
        if clip_input is None:
            clip_input = hparams['clip_input']
        if hparams['res']:
            x = (x - img_lr_up) * hparams['res_rescale']
            if clip_input:
                x = x.clamp(-1, 1)
        return x
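
A standalone way to sanity-check the cosine schedule used by GaussianDiffusion with the 100 timesteps configured in diffsr_base.yaml (this snippet simply re-evaluates the formula from cosine_beta_schedule above; the resulting arrays are what the constructor turns into buffers):

import numpy as np

timesteps, s = 100, 0.008                       # values from diffsr_base.yaml
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = np.clip(1 - alphas_cumprod[1:] / alphas_cumprod[:-1], 0, 0.999)

print(betas[0], betas[-1])               # small first step, 0.999 at the last (clipped) step
print(float(np.cumprod(1 - betas)[-1]))  # ~0: by t = T the signal is essentially destroyed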
sam_diffsr/models_sr/diffusion_sam.py
ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm

from sam_diffsr.utils_sr.hparams import hparams
from .diffusion import GaussianDiffusion, noise_like, extract
from .module_util import default


class GaussianDiffusion_sam(GaussianDiffusion):
    def __init__(self, denoise_fn, rrdb_net, timesteps=1000, loss_type='l1', sam_config=None):
        super().__init__(denoise_fn, rrdb_net, timesteps, loss_type)
        self.sam_config = sam_config

    def p_losses(self, x_start, t, cond, img_lr_up, noise=None, sam_mask=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        if self.sam_config['p_losses_sam']:
            _sam_mask = F.interpolate(sam_mask, noise.shape[2:], mode='bilinear')
            if self.sam_config.get('mask_coefficient', False):
                _sam_mask *= extract(self.mask_coefficient.to(_sam_mask.device), t, x_start.shape)
            noise += _sam_mask

        x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
        noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up, sam_mask=sam_mask)
        x_t_pred, x0_pred = self.p_sample(x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred, sam_mask=sam_mask)

        if self.loss_type == 'l1':
            loss = (noise - noise_pred).abs().mean()
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, noise_pred)
        elif self.loss_type == 'ssim':
            loss = (noise - noise_pred).abs().mean()
            loss = loss + (1 - self.ssim_loss(noise, noise_pred))
        else:
            raise NotImplementedError()
        return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred

    @torch.no_grad()
    def p_sample(self, x, t, cond, img_lr_up, noise_pred=None, clip_denoised=True, repeat_noise=False, sam_mask=None):
        if noise_pred is None:
            noise_pred = self.denoise_fn(x, t, cond=cond, img_lr_up=img_lr_up, sam_mask=sam_mask)
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance, x0_pred = self.p_mean_variance(
            x=x, t=t, noise_pred=noise_pred, clip_denoised=clip_denoised)

        noise = noise_like(x.shape, device, repeat_noise)

        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0_pred

    @torch.no_grad()
    def sample(self, img_lr, img_lr_up, shape, sam_mask=None, save_intermediate=False):
        device = self.betas.device
        b = shape[0]

        if not hparams['res']:
            t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
            noise = None
            img = self.q_sample(img_lr_up, t, noise=noise)
        else:
            img = torch.randn(shape, device=device)

        if hparams['use_rrdb']:
            rrdb_out, cond = self.rrdb(img_lr, True)
        else:
            rrdb_out = img_lr_up
            cond = img_lr

        it = reversed(range(0, self.num_timesteps))

        if self.sample_tqdm:
            it = tqdm(it, desc='sampling loop time step', total=self.num_timesteps)

        images = []
        for i in it:
            img, x_recon = self.p_sample(
                img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up, sam_mask=sam_mask)
            if save_intermediate:
                img_ = self.res2img(img, img_lr_up)
                x_recon_ = self.res2img(x_recon, img_lr_up)
                images.append((img_.cpu(), x_recon_.cpu()))
        img = self.res2img(img, img_lr_up)

        if save_intermediate:
            return img, rrdb_out, images
        else:
            return img, rrdb_out
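
The SAM-specific change above is confined to p_losses: the training noise is shifted by the bilinearly upsampled segmentation mask, optionally scaled by the per-step coefficient sqrt(1 - alpha_bar_t) * beta_t registered in the parent class. A minimal shape-only sketch of that modulation (all tensors below are dummies, not repo data):

import torch
import torch.nn.functional as F

b, c, h, w = 2, 3, 160, 160
x_start = torch.randn(b, c, h, w)               # residual target in HR space
sam_mask = torch.rand(b, 1, 40, 40)             # low-resolution SAM mask
noise = torch.randn_like(x_start)

mask_coefficient = torch.rand(100)              # stand-in for sqrt(1 - alpha_bar) * beta
t = torch.randint(0, 100, (b,))

_sam_mask = F.interpolate(sam_mask, noise.shape[2:], mode='bilinear')
_sam_mask = _sam_mask * mask_coefficient.gather(-1, t).reshape(b, 1, 1, 1)
noise = noise + _sam_mask                       # broadcasts over the channel dimension
print(noise.shape)                              # torch.Size([2, 3, 160, 160])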
sam_diffsr/models_sr/module_util.py
ADDED
@@ -0,0 +1,58 @@
from inspect import isfunction
from torch import nn
from torch.nn import init


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    while True:
        for data in dl:
            yield data


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def initialize_weights(net_l, scale=0.1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers, seq=False):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    if seq:
        return nn.Sequential(*layers)
    else:
        return nn.ModuleList(layers)
sam_diffsr/tasks/__init__.py
ADDED
File without changes
sam_diffsr/tasks/infer.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from collections import OrderedDict
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from tasks.srdiff_df2k import InferDataSet
|
8 |
+
|
9 |
+
parent_path = Path(__file__).absolute().parent.parent
|
10 |
+
sys.path.append(os.path.abspath(parent_path))
|
11 |
+
os.chdir(parent_path)
|
12 |
+
print(f'>-------------> parent path {parent_path}')
|
13 |
+
print(f'>-------------> current work dir {os.getcwd()}')
|
14 |
+
|
15 |
+
cache_path = os.path.join(parent_path, 'cache')
|
16 |
+
os.environ["HF_DATASETS_CACHE"] = cache_path
|
17 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
18 |
+
os.environ["torch_HOME"] = cache_path
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from PIL import Image
|
22 |
+
from tqdm import tqdm
|
23 |
+
from torch.utils.tensorboard import SummaryWriter
|
24 |
+
from utils_sr.hparams import hparams, set_hparams
|
25 |
+
|
26 |
+
|
27 |
+
def load_ckpt(ckpt_path, model):
|
28 |
+
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
29 |
+
stat_dict = checkpoint['state_dict']['model']
|
30 |
+
|
31 |
+
new_state_dict = OrderedDict()
|
32 |
+
for k, v in stat_dict.items():
|
33 |
+
if k[:7] == 'module.':
|
34 |
+
k = k[7:] # 去掉 `module.`
|
35 |
+
new_state_dict[k] = v
|
36 |
+
|
37 |
+
model.load_state_dict(new_state_dict)
|
38 |
+
model.cuda()
|
39 |
+
|
40 |
+
|
41 |
+
def infer(trainer, ckpt_path, img_dir, save_dir):
|
42 |
+
trainer.build_model()
|
43 |
+
load_ckpt(ckpt_path, trainer.model)
|
44 |
+
|
45 |
+
dataset = InferDataSet(img_dir)
|
46 |
+
test_dataloader = torch.utils.data.DataLoader(
|
47 |
+
dataset, batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
|
48 |
+
|
49 |
+
torch.backends.cudnn.benchmark = False
|
50 |
+
|
51 |
+
with torch.no_grad():
|
52 |
+
trainer.model.eval()
|
53 |
+
pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
|
54 |
+
for batch_idx, batch in pbar:
|
55 |
+
img_lr, img_lr_up, img_name = batch
|
56 |
+
|
57 |
+
img_lr = img_lr.to('cuda')
|
58 |
+
img_lr_up = img_lr_up.to('cuda')
|
59 |
+
|
60 |
+
img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
|
61 |
+
|
62 |
+
img_sr = img_sr.clamp(-1, 1)
|
63 |
+
img_sr = trainer.tensor2img(img_sr)[0]
|
64 |
+
img_sr = Image.fromarray(img_sr)
|
65 |
+
img_sr.save(os.path.join(save_dir, img_name[0]))
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == '__main__':
|
69 |
+
set_hparams()
|
70 |
+
|
71 |
+
img_dir = hparams['img_dir']
|
72 |
+
save_dir = hparams['save_dir']
|
73 |
+
ckpt_path = hparams['ckpt_path']
|
74 |
+
|
75 |
+
pkg = ".".join(hparams["trainer_cls"].split(".")[:-1])
|
76 |
+
cls_name = hparams["trainer_cls"].split(".")[-1]
|
77 |
+
trainer = getattr(importlib.import_module(pkg), cls_name)()
|
78 |
+
|
79 |
+
os.makedirs(save_dir, exist_ok=True)
|
80 |
+
|
81 |
+
infer(trainer, ckpt_path, img_dir, save_dir)
|
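Note: a plausible command-line invocation of the inference entry point above, with placeholder input/output directories; it assumes the sam_diffsr directory is the working directory and is on PYTHONPATH so that the tasks package resolves, and that a CUDA device is available.

PYTHONPATH=. python tasks/infer.py \
    --config configs/sam/sam_diffsr_df2k4x.yaml \
    --img_dir ./demo_lr_images \
    --save_dir ./demo_sr_out \
    --ckpt_path weight/model_ckpt_steps_400000.ckpt

The --img_dir, --save_dir and --ckpt_path flags are the ones registered in utils_sr/hparams.py; the super-resolved PNGs are written to --save_dir under their original file names.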
sam_diffsr/tasks/rrdb.py
ADDED
@@ -0,0 +1,68 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from models_sr.diffsr_modules import RRDBNet
|
5 |
+
from tasks.srdiff_df2k import Df2kDataSet
|
6 |
+
from tasks.trainer import Trainer
|
7 |
+
from utils_sr.hparams import hparams
|
8 |
+
from utils_sr.sr_utils import PerceptualLoss
|
9 |
+
|
10 |
+
|
11 |
+
class RRDBTask(Trainer):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
if 'rrdb_loss' in hparams and hparams['rrdb_loss']['percep_loss']:
|
15 |
+
self.percep_loss_fn = PerceptualLoss()
|
16 |
+
self.percep_loss_weight = hparams['rrdb_loss']['percep_loss_weight']
|
17 |
+
else:
|
18 |
+
self.percep_loss_fn = None
|
19 |
+
self.percep_loss_weight = 0
|
20 |
+
|
21 |
+
def build_model(self):
|
22 |
+
hidden_size = hparams['hidden_size']
|
23 |
+
self.model = RRDBNet(3, 3, hidden_size, hparams['num_block'], hidden_size // 2)
|
24 |
+
return self.model
|
25 |
+
|
26 |
+
def build_optimizer(self, model):
|
27 |
+
return torch.optim.Adam(model.parameters(), lr=hparams['lr'])
|
28 |
+
|
29 |
+
def build_scheduler(self, optimizer):
|
30 |
+
return torch.optim.lr_scheduler.StepLR(optimizer, 200000, 0.5)
|
31 |
+
|
32 |
+
def training_step(self, sample):
|
33 |
+
img_hr = sample['img_hr']
|
34 |
+
img_lr = sample['img_lr']
|
35 |
+
p = self.model(img_lr)
|
36 |
+
total_loss = 0
|
37 |
+
loss = F.l1_loss(p, img_hr, reduction='mean')
|
38 |
+
total_loss += loss
|
39 |
+
|
40 |
+
if self.percep_loss_fn:
|
41 |
+
loss_percep = self.percep_loss_fn(img_hr, p) * self.percep_loss_weight
|
42 |
+
total_loss += loss_percep
|
43 |
+
return {'l': loss, 'loss_percep': loss_percep, 'total_loss': total_loss,
|
44 |
+
'lr': self.scheduler.get_last_lr()[0]}, total_loss
|
45 |
+
else:
|
46 |
+
return {'l': loss, 'lr': self.scheduler.get_last_lr()[0]}, total_loss
|
47 |
+
|
48 |
+
def sample_and_test(self, sample):
|
49 |
+
ret = {k: 0 for k in self.metric_keys}
|
50 |
+
ret['n_samples'] = 0
|
51 |
+
img_hr = sample['img_hr']
|
52 |
+
img_lr = sample['img_lr']
|
53 |
+
img_sr = self.model(img_lr)
|
54 |
+
img_sr = img_sr.clamp(-1, 1)
|
55 |
+
for b in range(img_sr.shape[0]):
|
56 |
+
s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
|
57 |
+
ret['psnr'] += s['psnr']
|
58 |
+
ret['ssim'] += s['ssim']
|
59 |
+
ret['lpips'] += s['lpips']
|
60 |
+
ret['lr_psnr'] += s['lr_psnr']
|
61 |
+
ret['n_samples'] += 1
|
62 |
+
return img_sr, img_sr, ret
|
63 |
+
|
64 |
+
|
65 |
+
class RRDBDf2kTask(RRDBTask):
|
66 |
+
def __init__(self):
|
67 |
+
super().__init__()
|
68 |
+
self.dataset_cls = Df2kDataSet
|
sam_diffsr/tasks/rrdb_sam.py
ADDED
@@ -0,0 +1,49 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from models_sr.diffsr_modules import RRDBNet
|
5 |
+
from tasks.srdiff_df2k_sam import Df2kDataSet_sam
|
6 |
+
from tasks.trainer import Trainer
|
7 |
+
from utils_sr.hparams import hparams
|
8 |
+
|
9 |
+
|
10 |
+
class RRDBTask_sam(Trainer):
|
11 |
+
def build_model(self):
|
12 |
+
hidden_size = hparams['hidden_size']
|
13 |
+
self.model = RRDBNet(3, 3, hidden_size, hparams['num_block'], hidden_size // 2)
|
14 |
+
return self.model
|
15 |
+
|
16 |
+
def build_optimizer(self, model):
|
17 |
+
return torch.optim.Adam(model.parameters(), lr=hparams['lr'])
|
18 |
+
|
19 |
+
def build_scheduler(self, optimizer):
|
20 |
+
return torch.optim.lr_scheduler.StepLR(optimizer, 200000, 0.5)
|
21 |
+
|
22 |
+
def training_step(self, sample):
|
23 |
+
img_hr = sample['img_hr']
|
24 |
+
img_lr = sample['img_lr']
|
25 |
+
p = self.model(img_lr)
|
26 |
+
loss = F.l1_loss(p, img_hr, reduction='mean')
|
27 |
+
return {'l': loss, 'lr': self.scheduler.get_last_lr()[0]}, loss
|
28 |
+
|
29 |
+
def sample_and_test(self, sample):
|
30 |
+
ret = {k: 0 for k in self.metric_keys}
|
31 |
+
ret['n_samples'] = 0
|
32 |
+
img_hr = sample['img_hr']
|
33 |
+
img_lr = sample['img_lr']
|
34 |
+
img_sr = self.model(img_lr)
|
35 |
+
img_sr = img_sr.clamp(-1, 1)
|
36 |
+
for b in range(img_sr.shape[0]):
|
37 |
+
s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
|
38 |
+
ret['psnr'] += s['psnr']
|
39 |
+
ret['ssim'] += s['ssim']
|
40 |
+
ret['lpips'] += s['lpips']
|
41 |
+
ret['lr_psnr'] += s['lr_psnr']
|
42 |
+
ret['n_samples'] += 1
|
43 |
+
return img_sr, img_sr, ret
|
44 |
+
|
45 |
+
|
46 |
+
class RRDBDf2kTask_sam(RRDBTask_sam):
|
47 |
+
def __init__(self):
|
48 |
+
super().__init__()
|
49 |
+
self.dataset_cls = Df2kDataSet_sam
|
sam_diffsr/tasks/srdiff.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
import os.path
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from sam_diffsr.models_sr.diffsr_modules import Unet, RRDBNet
|
6 |
+
from sam_diffsr.models_sr.diffusion import GaussianDiffusion
|
7 |
+
from sam_diffsr.tasks.trainer import Trainer
|
8 |
+
from sam_diffsr.utils_sr.hparams import hparams
|
9 |
+
from sam_diffsr.utils_sr.utils import load_ckpt
|
10 |
+
|
11 |
+
|
12 |
+
class SRDiffTrainer(Trainer):
|
13 |
+
def build_model(self):
|
14 |
+
hidden_size = hparams['hidden_size']
|
15 |
+
dim_mults = hparams['unet_dim_mults']
|
16 |
+
dim_mults = [int(x) for x in dim_mults.split('|')]
|
17 |
+
denoise_fn = Unet(
|
18 |
+
hidden_size, out_dim=3, cond_dim=hparams['rrdb_num_feat'], dim_mults=dim_mults)
|
19 |
+
if hparams['use_rrdb']:
|
20 |
+
rrdb = RRDBNet(3, 3, hparams['rrdb_num_feat'], hparams['rrdb_num_block'],
|
21 |
+
hparams['rrdb_num_feat'] // 2)
|
22 |
+
if hparams['rrdb_ckpt'] != '' and os.path.exists(hparams['rrdb_ckpt']):
|
23 |
+
load_ckpt(rrdb, hparams['rrdb_ckpt'])
|
24 |
+
else:
|
25 |
+
rrdb = None
|
26 |
+
self.model = GaussianDiffusion(
|
27 |
+
denoise_fn=denoise_fn,
|
28 |
+
rrdb_net=rrdb,
|
29 |
+
timesteps=hparams['timesteps'],
|
30 |
+
loss_type=hparams['loss_type']
|
31 |
+
)
|
32 |
+
self.global_step = 0
|
33 |
+
return self.model
|
34 |
+
|
35 |
+
def sample_and_test(self, sample):
|
36 |
+
ret = {k: 0 for k in self.metric_keys}
|
37 |
+
ret['n_samples'] = 0
|
38 |
+
img_hr = sample['img_hr']
|
39 |
+
img_lr = sample['img_lr']
|
40 |
+
img_lr_up = sample['img_lr_up']
|
41 |
+
img_sr, rrdb_out = self.model.sample(img_lr, img_lr_up, img_hr.shape)
|
42 |
+
for b in range(img_sr.shape[0]):
|
43 |
+
s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
|
44 |
+
ret['psnr'] += s['psnr']
|
45 |
+
ret['ssim'] += s['ssim']
|
46 |
+
ret['lpips'] += s['lpips']
|
47 |
+
ret['lr_psnr'] += s['lr_psnr']
|
48 |
+
ret['n_samples'] += 1
|
49 |
+
return img_sr, rrdb_out, ret
|
50 |
+
|
51 |
+
def build_optimizer(self, model):
|
52 |
+
params = list(model.named_parameters())
|
53 |
+
if not hparams['fix_rrdb']:
|
54 |
+
params = [p for p in params if 'rrdb' not in p[0]]
|
55 |
+
params = [p[1] for p in params]
|
56 |
+
return torch.optim.Adam(params, lr=hparams['lr'])
|
57 |
+
|
58 |
+
def build_scheduler(self, optimizer):
|
59 |
+
if 'scheduler' in hparams:
|
60 |
+
scheduler_config = hparams['scheduler']
|
61 |
+
if scheduler_config['type'] == 'cosine':
|
62 |
+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, hparams['max_updates'],
|
63 |
+
eta_min=scheduler_config['eta_min'])
|
64 |
+
|
65 |
+
else:
|
66 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
|
67 |
+
|
68 |
+
return lr_scheduler
|
69 |
+
|
70 |
+
def training_step(self, batch):
|
71 |
+
img_hr = batch['img_hr']
|
72 |
+
img_lr = batch['img_lr']
|
73 |
+
img_lr_up = batch['img_lr_up']
|
74 |
+
losses, _, _ = self.model(img_hr, img_lr, img_lr_up)
|
75 |
+
total_loss = sum(losses.values())
|
76 |
+
return losses, total_loss
|
sam_diffsr/tasks/srdiff_df2k.py
ADDED
@@ -0,0 +1,119 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
from sam_diffsr.tasks.srdiff import SRDiffTrainer
|
10 |
+
from sam_diffsr.utils_sr.dataset import SRDataSet
|
11 |
+
from sam_diffsr.utils_sr.hparams import hparams
|
12 |
+
from sam_diffsr.utils_sr.matlab_resize import imresize
|
13 |
+
|
14 |
+
|
15 |
+
class InferDataSet(Dataset):
|
16 |
+
def __init__(self, img_dir):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.img_path_list = [os.path.join(img_dir, img_name) for img_name in os.listdir(img_dir)]
|
20 |
+
self.to_tensor_norm = transforms.Compose([
|
21 |
+
transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
22 |
+
])
|
23 |
+
|
24 |
+
def __getitem__(self, index):
|
25 |
+
sr_scale = hparams['sr_scale']
|
26 |
+
|
27 |
+
img_path = self.img_path_list[index]
|
28 |
+
img_name = os.path.basename(img_path)
|
29 |
+
|
30 |
+
img_lr = Image.open(img_path).convert('RGB')
|
31 |
+
img_lr = np.uint8(np.asarray(img_lr))
|
32 |
+
|
33 |
+
h, w, c = img_lr.shape
|
34 |
+
h, w = h * sr_scale, w * sr_scale
|
35 |
+
h = h - h % (sr_scale * 2)
|
36 |
+
w = w - w % (sr_scale * 2)
|
37 |
+
h_l = h // sr_scale
|
38 |
+
w_l = w // sr_scale
|
39 |
+
|
40 |
+
img_lr = img_lr[:h_l, :w_l]
|
41 |
+
|
42 |
+
img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
|
43 |
+
img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]
|
44 |
+
|
45 |
+
return img_lr, img_lr_up, img_name
|
46 |
+
|
47 |
+
def __len__(self):
|
48 |
+
return len(self.img_path_list)
|
49 |
+
|
50 |
+
|
51 |
+
class Df2kDataSet(SRDataSet):
|
52 |
+
def __init__(self, prefix='train'):
|
53 |
+
if prefix == 'valid':
|
54 |
+
_prefix = 'test'
|
55 |
+
else:
|
56 |
+
_prefix = prefix
|
57 |
+
|
58 |
+
super().__init__(_prefix)
|
59 |
+
self.patch_size = hparams['patch_size']
|
60 |
+
self.patch_size_lr = hparams['patch_size'] // hparams['sr_scale']
|
61 |
+
if prefix == 'valid':
|
62 |
+
self.len = hparams['eval_batch_size'] * hparams['valid_steps']
|
63 |
+
|
64 |
+
self.data_aug_transforms = transforms.Compose([
|
65 |
+
transforms.RandomHorizontalFlip(),
|
66 |
+
transforms.RandomRotation(20, resample=Image.BICUBIC),
|
67 |
+
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
|
68 |
+
])
|
69 |
+
|
70 |
+
def __getitem__(self, index):
|
71 |
+
item = self._get_item(index)
|
72 |
+
hparams = self.hparams
|
73 |
+
sr_scale = hparams['sr_scale']
|
74 |
+
|
75 |
+
img_hr = np.uint8(item['img'])
|
76 |
+
img_lr = np.uint8(item['img_lr'])
|
77 |
+
|
78 |
+
# TODO: clip for SRFlow
|
79 |
+
h, w, c = img_hr.shape
|
80 |
+
h = h - h % (sr_scale * 2)
|
81 |
+
w = w - w % (sr_scale * 2)
|
82 |
+
h_l = h // sr_scale
|
83 |
+
w_l = w // sr_scale
|
84 |
+
img_hr = img_hr[:h, :w]
|
85 |
+
img_lr = img_lr[:h_l, :w_l]
|
86 |
+
# random crop
|
87 |
+
if self.prefix == 'train':
|
88 |
+
if self.data_augmentation and random.random() < 0.5:
|
89 |
+
img_hr, img_lr = self.data_augment(img_hr, img_lr)
|
90 |
+
i = random.randint(0, h - self.patch_size) // sr_scale * sr_scale
|
91 |
+
i_lr = i // sr_scale
|
92 |
+
j = random.randint(0, w - self.patch_size) // sr_scale * sr_scale
|
93 |
+
j_lr = j // sr_scale
|
94 |
+
img_hr = img_hr[i:i + self.patch_size, j:j + self.patch_size]
|
95 |
+
img_lr = img_lr[i_lr:i_lr + self.patch_size_lr, j_lr:j_lr + self.patch_size_lr]
|
96 |
+
img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
|
97 |
+
img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
|
98 |
+
return {
|
99 |
+
'img_hr': img_hr, 'img_lr': img_lr,
|
100 |
+
'img_lr_up': img_lr_up, 'item_name': item['item_name'],
|
101 |
+
'loc': np.array(item['loc']), 'loc_bdr': np.array(item['loc_bdr'])
|
102 |
+
}
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return self.len
|
106 |
+
|
107 |
+
def data_augment(self, img_hr, img_lr):
|
108 |
+
sr_scale = self.hparams['sr_scale']
|
109 |
+
img_hr = Image.fromarray(img_hr)
|
110 |
+
img_hr = self.data_aug_transforms(img_hr)
|
111 |
+
img_hr = np.asarray(img_hr) # np.uint8 [H, W, C]
|
112 |
+
img_lr = imresize(img_hr, 1 / sr_scale)
|
113 |
+
return img_hr, img_lr
|
114 |
+
|
115 |
+
|
116 |
+
class SRDiffDf2k(SRDiffTrainer):
|
117 |
+
def __init__(self):
|
118 |
+
super().__init__()
|
119 |
+
self.dataset_cls = Df2kDataSet
|
sam_diffsr/tasks/srdiff_df2k_sam.py
ADDED
@@ -0,0 +1,211 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
from rotary_embedding_torch import RotaryEmbedding
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
from sam_diffsr.models_sr.diffsr_modules import RRDBNet, Unet
|
12 |
+
from sam_diffsr.models_sr.diffusion_sam import GaussianDiffusion_sam
|
13 |
+
from sam_diffsr.tasks.srdiff import SRDiffTrainer
|
14 |
+
from sam_diffsr.utils_sr.dataset import SRDataSet
|
15 |
+
from sam_diffsr.utils_sr.hparams import hparams
|
16 |
+
from sam_diffsr.utils_sr.indexed_datasets import IndexedDataset
|
17 |
+
from sam_diffsr.utils_sr.matlab_resize import imresize
|
18 |
+
from sam_diffsr.utils_sr.utils import load_ckpt
|
19 |
+
|
20 |
+
|
21 |
+
def normalize_01(data):
|
22 |
+
mu = np.mean(data)
|
23 |
+
sigma = np.std(data)
|
24 |
+
|
25 |
+
if sigma == 0.:
|
26 |
+
return data - mu
|
27 |
+
else:
|
28 |
+
return (data - mu) / sigma
|
29 |
+
|
30 |
+
|
31 |
+
def normalize_11(data):
|
32 |
+
mu = np.mean(data)
|
33 |
+
sigma = np.std(data)
|
34 |
+
|
35 |
+
if sigma == 0.:
|
36 |
+
return data - mu
|
37 |
+
else:
|
38 |
+
return (data - mu) / sigma - 1
|
39 |
+
|
40 |
+
|
41 |
+
class Df2kDataSet_sam(SRDataSet):
|
42 |
+
def __init__(self, prefix='train'):
|
43 |
+
|
44 |
+
if prefix == 'valid':
|
45 |
+
_prefix = 'test'
|
46 |
+
else:
|
47 |
+
_prefix = prefix
|
48 |
+
|
49 |
+
super().__init__(_prefix)
|
50 |
+
|
51 |
+
self.patch_size = hparams['patch_size']
|
52 |
+
self.patch_size_lr = hparams['patch_size'] // hparams['sr_scale']
|
53 |
+
if prefix == 'valid':
|
54 |
+
self.len = hparams['eval_batch_size'] * hparams['valid_steps']
|
55 |
+
|
56 |
+
self.data_position_aug_transforms = transforms.Compose([
|
57 |
+
transforms.RandomHorizontalFlip(),
|
58 |
+
transforms.RandomRotation(20, interpolation=Image.BICUBIC),
|
59 |
+
])
|
60 |
+
|
61 |
+
self.data_color_aug_transforms = transforms.Compose([
|
62 |
+
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
|
63 |
+
])
|
64 |
+
|
65 |
+
self.sam_config = hparams.get('sam_config', False)
|
66 |
+
|
67 |
+
if self.sam_config.get('mask_RoPE', False):
|
68 |
+
h, w = map(int, self.sam_config['mask_RoPE_shape'].split('-'))
|
69 |
+
rotary_emb = RotaryEmbedding(dim=h)
|
70 |
+
sam_mask = rotary_emb.rotate_queries_or_keys(torch.ones(1, 1, w, h))
|
71 |
+
self.RoPE_mask = sam_mask.cpu().numpy()[0, 0, ...]
|
72 |
+
|
73 |
+
def _get_item(self, index):
|
74 |
+
if self.indexed_ds is None:
|
75 |
+
self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
|
76 |
+
return self.indexed_ds[index]
|
77 |
+
|
78 |
+
def __getitem__(self, index):
|
79 |
+
item = self._get_item(index)
|
80 |
+
hparams = self.hparams
|
81 |
+
sr_scale = hparams['sr_scale']
|
82 |
+
|
83 |
+
img_hr = np.uint8(item['img'])
|
84 |
+
img_lr = np.uint8(item['img_lr'])
|
85 |
+
|
86 |
+
if self.sam_config.get('mask_RoPE', False):
|
87 |
+
sam_mask = self.RoPE_mask
|
88 |
+
else:
|
89 |
+
if 'sam_mask' in item:
|
90 |
+
sam_mask = item['sam_mask']
|
91 |
+
if sam_mask.shape != img_hr.shape[:2]:
|
92 |
+
sam_mask = cv2.resize(sam_mask, dsize=img_hr.shape[:2][::-1])
|
93 |
+
else:
|
94 |
+
sam_mask = np.zeros_like(img_lr)
|
95 |
+
|
96 |
+
# TODO: clip for SRFlow
|
97 |
+
h, w, c = img_hr.shape
|
98 |
+
h = h - h % (sr_scale * 2)
|
99 |
+
w = w - w % (sr_scale * 2)
|
100 |
+
h_l = h // sr_scale
|
101 |
+
w_l = w // sr_scale
|
102 |
+
img_hr = img_hr[:h, :w]
|
103 |
+
sam_mask = sam_mask[:h, :w]
|
104 |
+
img_lr = img_lr[:h_l, :w_l]
|
105 |
+
|
106 |
+
# random crop
|
107 |
+
if self.prefix == 'train':
|
108 |
+
if self.data_augmentation and random.random() < 0.5:
|
109 |
+
img_hr, img_lr, sam_mask = self.data_augment(img_hr, img_lr, sam_mask)
|
110 |
+
i = random.randint(0, h - self.patch_size) // sr_scale * sr_scale
|
111 |
+
i_lr = i // sr_scale
|
112 |
+
j = random.randint(0, w - self.patch_size) // sr_scale * sr_scale
|
113 |
+
j_lr = j // sr_scale
|
114 |
+
img_hr = img_hr[i:i + self.patch_size, j:j + self.patch_size]
|
115 |
+
sam_mask = sam_mask[i:i + self.patch_size, j:j + self.patch_size]
|
116 |
+
img_lr = img_lr[i_lr:i_lr + self.patch_size_lr, j_lr:j_lr + self.patch_size_lr]
|
117 |
+
|
118 |
+
img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
|
119 |
+
img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
|
120 |
+
|
121 |
+
if hparams['sam_data_config']['all_same_mask_to_zero']:
|
122 |
+
if len(np.unique(sam_mask)) == 1:
|
123 |
+
sam_mask = np.zeros_like(sam_mask)
|
124 |
+
|
125 |
+
if hparams['sam_data_config']['normalize_01']:
|
126 |
+
if len(np.unique(sam_mask)) != 1:
|
127 |
+
sam_mask = normalize_01(sam_mask)
|
128 |
+
|
129 |
+
if hparams['sam_data_config']['normalize_11']:
|
130 |
+
if len(np.unique(sam_mask)) != 1:
|
131 |
+
sam_mask = normalize_11(sam_mask)
|
132 |
+
|
133 |
+
sam_mask = torch.FloatTensor(sam_mask).unsqueeze(dim=0)
|
134 |
+
|
135 |
+
return {
|
136 |
+
'img_hr': img_hr, 'img_lr': img_lr,
|
137 |
+
'img_lr_up': img_lr_up, 'item_name': item['item_name'],
|
138 |
+
'loc': np.array(item['loc']), 'loc_bdr': np.array(item['loc_bdr']),
|
139 |
+
'sam_mask': sam_mask
|
140 |
+
}
|
141 |
+
|
142 |
+
def __len__(self):
|
143 |
+
return self.len
|
144 |
+
|
145 |
+
def data_augment(self, img_hr, img_lr, sam_mask):
|
146 |
+
sr_scale = self.hparams['sr_scale']
|
147 |
+
img_hr = Image.fromarray(img_hr)
|
148 |
+
img_hr, sam_mask = self.data_position_aug_transforms([img_hr, sam_mask])
|
149 |
+
img_hr = self.data_color_aug_transforms(img_hr)
|
150 |
+
img_hr = np.asarray(img_hr) # np.uint8 [H, W, C]
|
151 |
+
img_lr = imresize(img_hr, 1 / sr_scale)
|
152 |
+
return img_hr, img_lr, sam_mask
|
153 |
+
|
154 |
+
|
155 |
+
class SRDiffDf2k_sam(SRDiffTrainer):
|
156 |
+
def __init__(self):
|
157 |
+
super().__init__()
|
158 |
+
self.dataset_cls = Df2kDataSet_sam
|
159 |
+
self.sam_config = hparams['sam_config']
|
160 |
+
|
161 |
+
def build_model(self):
|
162 |
+
hidden_size = hparams['hidden_size']
|
163 |
+
dim_mults = hparams['unet_dim_mults']
|
164 |
+
dim_mults = [int(x) for x in dim_mults.split('|')]
|
165 |
+
|
166 |
+
denoise_fn = Unet(
|
167 |
+
hidden_size, out_dim=3, cond_dim=hparams['rrdb_num_feat'], dim_mults=dim_mults)
|
168 |
+
if hparams['use_rrdb']:
|
169 |
+
rrdb = RRDBNet(3, 3, hparams['rrdb_num_feat'], hparams['rrdb_num_block'],
|
170 |
+
hparams['rrdb_num_feat'] // 2)
|
171 |
+
if hparams['rrdb_ckpt'] != '' and os.path.exists(hparams['rrdb_ckpt']):
|
172 |
+
load_ckpt(rrdb, hparams['rrdb_ckpt'])
|
173 |
+
else:
|
174 |
+
rrdb = None
|
175 |
+
self.model = GaussianDiffusion_sam(
|
176 |
+
denoise_fn=denoise_fn,
|
177 |
+
rrdb_net=rrdb,
|
178 |
+
timesteps=hparams['timesteps'],
|
179 |
+
loss_type=hparams['loss_type'],
|
180 |
+
sam_config=hparams['sam_config']
|
181 |
+
)
|
182 |
+
self.global_step = 0
|
183 |
+
return self.model
|
184 |
+
|
185 |
+
# def sample_and_test(self, sample):
|
186 |
+
# ret = {k: 0 for k in self.metric_keys}
|
187 |
+
# ret['n_samples'] = 0
|
188 |
+
# img_hr = sample['img_hr']
|
189 |
+
# img_lr = sample['img_lr']
|
190 |
+
# img_lr_up = sample['img_lr_up']
|
191 |
+
# sam_mask = sample['sam_mask']
|
192 |
+
#
|
193 |
+
# img_sr, rrdb_out = self.model.sample(img_lr, img_lr_up, img_hr.shape, sam_mask=sam_mask)
|
194 |
+
#
|
195 |
+
# for b in range(img_sr.shape[0]):
|
196 |
+
# s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
|
197 |
+
# ret['psnr'] += s['psnr']
|
198 |
+
# ret['ssim'] += s['ssim']
|
199 |
+
# ret['lpips'] += s['lpips']
|
200 |
+
# ret['lr_psnr'] += s['lr_psnr']
|
201 |
+
# ret['n_samples'] += 1
|
202 |
+
# return img_sr, rrdb_out, ret
|
203 |
+
|
204 |
+
def training_step(self, batch):
|
205 |
+
img_hr = batch['img_hr']
|
206 |
+
img_lr = batch['img_lr']
|
207 |
+
img_lr_up = batch['img_lr_up']
|
208 |
+
sam_mask = batch['sam_mask']
|
209 |
+
losses, _, _ = self.model(img_hr, img_lr, img_lr_up, sam_mask=sam_mask)
|
210 |
+
total_loss = sum(losses.values())
|
211 |
+
return losses, total_loss
|
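Note: a small standalone check of the two mask normalizations defined at the top of srdiff_df2k_sam.py (the functions are restated here so the snippet runs on its own; the input values are arbitrary).

import numpy as np

def normalize_01(data):
    mu, sigma = np.mean(data), np.std(data)
    return data - mu if sigma == 0. else (data - mu) / sigma

def normalize_11(data):
    mu, sigma = np.mean(data), np.std(data)
    return data - mu if sigma == 0. else (data - mu) / sigma - 1

mask = np.array([0., 1., 2., 3.])
print(normalize_01(mask))          # zero-mean, unit-std version of the mask
print(normalize_11(mask))          # the same values shifted down by 1
print(normalize_01(np.ones(4)))    # a constant mask collapses to all zeros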
sam_diffsr/tasks/trainer.py
ADDED
@@ -0,0 +1,346 @@
1 |
+
import importlib
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
from collections import OrderedDict
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
parent_path = Path(__file__).absolute().parent.parent
|
10 |
+
sys.path.append(os.path.abspath(parent_path))
|
11 |
+
os.chdir(parent_path)
|
12 |
+
print(f'>-------------> parent path {parent_path}')
|
13 |
+
print(f'>-------------> current work dir {os.getcwd()}')
|
14 |
+
|
15 |
+
cache_path = os.path.join(parent_path, 'cache')
|
16 |
+
os.environ["HF_DATASETS_CACHE"] = cache_path
|
17 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
18 |
+
os.environ["torch_HOME"] = cache_path
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from PIL import Image
|
22 |
+
from tqdm import tqdm
|
23 |
+
import numpy as np
|
24 |
+
from torch.utils.tensorboard import SummaryWriter
|
25 |
+
from sam_diffsr.utils_sr.hparams import hparams, set_hparams
|
26 |
+
from sam_diffsr.utils_sr.utils import plot_img, move_to_cuda, load_checkpoint, save_checkpoint, tensors_to_scalars, Measure, \
|
27 |
+
get_all_ckpts
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
class Trainer:
|
32 |
+
def __init__(self):
|
33 |
+
self.logger = self.build_tensorboard(save_dir=hparams['work_dir'], name='tb_logs')
|
34 |
+
self.measure = Measure()
|
35 |
+
self.dataset_cls = None
|
36 |
+
self.metric_keys = ['psnr', 'ssim', 'lpips', 'lr_psnr']
|
37 |
+
self.metric_2_keys = ['psnr-Y', 'ssim', 'fid']
|
38 |
+
self.work_dir = hparams['work_dir']
|
39 |
+
self.first_val = True
|
40 |
+
|
41 |
+
self.val_steps = hparams['val_steps']
|
42 |
+
|
43 |
+
def build_tensorboard(self, save_dir, name, **kwargs):
|
44 |
+
log_dir = os.path.join(save_dir, name)
|
45 |
+
os.makedirs(log_dir, exist_ok=True)
|
46 |
+
return SummaryWriter(log_dir=log_dir, **kwargs)
|
47 |
+
|
48 |
+
def build_train_dataloader(self):
|
49 |
+
dataset = self.dataset_cls('train')
|
50 |
+
return torch.utils.data.DataLoader(
|
51 |
+
dataset, batch_size=hparams['batch_size'], shuffle=True,
|
52 |
+
pin_memory=False, num_workers=hparams['num_workers'])
|
53 |
+
|
54 |
+
def build_val_dataloader(self):
|
55 |
+
return torch.utils.data.DataLoader(
|
56 |
+
self.dataset_cls('valid'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
|
57 |
+
|
58 |
+
def build_test_dataloader(self):
|
59 |
+
return torch.utils.data.DataLoader(
|
60 |
+
self.dataset_cls('test'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
|
61 |
+
|
62 |
+
def build_model(self):
|
63 |
+
raise NotImplementedError
|
64 |
+
|
65 |
+
def sample_and_test(self, sample):
|
66 |
+
raise NotImplementedError
|
67 |
+
|
68 |
+
def build_optimizer(self, model):
|
69 |
+
raise NotImplementedError
|
70 |
+
|
71 |
+
def build_scheduler(self, optimizer):
|
72 |
+
raise NotImplementedError
|
73 |
+
|
74 |
+
def training_step(self, batch):
|
75 |
+
raise NotImplementedError
|
76 |
+
|
77 |
+
def train(self):
|
78 |
+
model = self.build_model()
|
79 |
+
optimizer = self.build_optimizer(model)
|
80 |
+
self.global_step = training_step = load_checkpoint(model, optimizer, hparams['work_dir'], steps=self.val_steps)
|
81 |
+
self.scheduler = scheduler = self.build_scheduler(optimizer)
|
82 |
+
scheduler.step(training_step)
|
83 |
+
dataloader = self.build_train_dataloader()
|
84 |
+
|
85 |
+
train_pbar = tqdm(dataloader, initial=training_step, total=float('inf'),
|
86 |
+
dynamic_ncols=True, unit='step')
|
87 |
+
while self.global_step < hparams['max_updates']:
|
88 |
+
for batch in train_pbar:
|
89 |
+
if training_step % hparams['val_check_interval'] == 0:
|
90 |
+
with torch.no_grad():
|
91 |
+
model.eval()
|
92 |
+
self.validate(training_step)
|
93 |
+
save_checkpoint(model, optimizer, self.work_dir, training_step, hparams['num_ckpt_keep'])
|
94 |
+
model.train()
|
95 |
+
batch = move_to_cuda(batch)
|
96 |
+
losses, total_loss = self.training_step(batch)
|
97 |
+
optimizer.zero_grad()
|
98 |
+
|
99 |
+
total_loss.backward()
|
100 |
+
optimizer.step()
|
101 |
+
training_step += 1
|
102 |
+
scheduler.step(training_step)
|
103 |
+
self.global_step = training_step
|
104 |
+
if training_step % 100 == 0:
|
105 |
+
self.log_metrics({f'tr/{k}': v for k, v in losses.items()}, training_step)
|
106 |
+
train_pbar.set_postfix(**tensors_to_scalars(losses))
|
107 |
+
|
108 |
+
def validate(self, training_step):
|
109 |
+
val_dataloader = self.build_val_dataloader()
|
110 |
+
pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader))
|
111 |
+
metrics = {}
|
112 |
+
for batch_idx, batch in pbar:
|
113 |
+
# On the first validation of each run, only a small number of batches are evaluated, as a sanity check that the code runs end to end
|
114 |
+
if self.first_val and batch_idx > hparams['num_sanity_val_steps'] - 1:
|
115 |
+
break
|
116 |
+
batch = move_to_cuda(batch)
|
117 |
+
img, rrdb_out, ret = self.sample_and_test(batch)
|
118 |
+
img_hr = batch['img_hr']
|
119 |
+
img_lr = batch['img_lr']
|
120 |
+
img_lr_up = batch['img_lr_up']
|
121 |
+
if img is not None:
|
122 |
+
self.logger.add_image(f'Pred_{batch_idx}', plot_img(img[0]), self.global_step)
|
123 |
+
if hparams.get('aux_l1_loss'):
|
124 |
+
self.logger.add_image(f'rrdb_out_{batch_idx}', plot_img(rrdb_out[0]), self.global_step)
|
125 |
+
if self.global_step <= hparams['val_check_interval']:
|
126 |
+
self.logger.add_image(f'HR_{batch_idx}', plot_img(img_hr[0]), self.global_step)
|
127 |
+
self.logger.add_image(f'LR_{batch_idx}', plot_img(img_lr[0]), self.global_step)
|
128 |
+
self.logger.add_image(f'BL_{batch_idx}', plot_img(img_lr_up[0]), self.global_step)
|
129 |
+
metrics = {}
|
130 |
+
metrics.update({k: np.mean(ret[k]) for k in self.metric_keys})
|
131 |
+
pbar.set_postfix(**tensors_to_scalars(metrics))
|
132 |
+
if hparams['infer']:
|
133 |
+
print('Val results:', metrics)
|
134 |
+
else:
|
135 |
+
if not self.first_val:
|
136 |
+
self.log_metrics({f'val/{k}': v for k, v in metrics.items()}, training_step)
|
137 |
+
print('Val results:', metrics)
|
138 |
+
else:
|
139 |
+
print('Sanity val results:', metrics)
|
140 |
+
self.first_val = False
|
141 |
+
|
142 |
+
def build_test_my_dataloader(self, data_name):
|
143 |
+
return torch.utils.data.DataLoader(
|
144 |
+
self.dataset_cls(data_name), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
|
145 |
+
|
146 |
+
def benchmark(self, benchmark_name_list, metric_list):
|
147 |
+
from sam_diffsr.tools.caculate_iqa import eval_img_IQA
|
148 |
+
|
149 |
+
model = self.build_model()
|
150 |
+
optimizer = self.build_optimizer(model)
|
151 |
+
training_step = load_checkpoint(model, optimizer, hparams['work_dir'], hparams['val_steps'])
|
152 |
+
self.global_step = training_step
|
153 |
+
|
154 |
+
optimizer = None
|
155 |
+
|
156 |
+
for data_name in benchmark_name_list:
|
157 |
+
test_dataloader = self.build_test_my_dataloader(data_name)
|
158 |
+
|
159 |
+
self.results = {k: 0 for k in self.metric_keys}
|
160 |
+
self.n_samples = 0
|
161 |
+
self.gen_dir = f"{hparams['work_dir']}/results_{self.global_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
|
162 |
+
if hparams['test_save_png']:
|
163 |
+
subprocess.check_call(f'rm -rf {self.gen_dir}', shell=True)
|
164 |
+
os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True)
|
165 |
+
os.makedirs(f'{self.gen_dir}/SR', exist_ok=True)
|
166 |
+
|
167 |
+
self.model.sample_tqdm = False
|
168 |
+
torch.backends.cudnn.benchmark = False
|
169 |
+
if hparams['test_save_png']:
|
170 |
+
if hasattr(self.model.denoise_fn, 'make_generation_fast_'):
|
171 |
+
self.model.denoise_fn.make_generation_fast_()
|
172 |
+
os.makedirs(f'{self.gen_dir}/HR', exist_ok=True)
|
173 |
+
|
174 |
+
result_dict = {}
|
175 |
+
|
176 |
+
with torch.no_grad():
|
177 |
+
model.eval()
|
178 |
+
pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
|
179 |
+
for batch_idx, batch in pbar:
|
180 |
+
move_to_cuda(batch)
|
181 |
+
gen_dir = self.gen_dir
|
182 |
+
item_names = batch['item_name']
|
183 |
+
img_hr = batch['img_hr']
|
184 |
+
img_lr = batch['img_lr']
|
185 |
+
img_lr_up = batch['img_lr_up']
|
186 |
+
|
187 |
+
res = self.sample_and_test(batch)
|
188 |
+
if len(res) == 3:
|
189 |
+
img_sr, rrdb_out, ret = res
|
190 |
+
else:
|
191 |
+
img_sr, ret = res
|
192 |
+
rrdb_out = img_sr
|
193 |
+
|
194 |
+
img_lr_up = batch.get('img_lr_up', img_lr_up)
|
195 |
+
if img_sr is not None:
|
196 |
+
metrics = list(self.metric_keys)
|
197 |
+
result_dict[batch['item_name'][0]] = {}
|
198 |
+
for k in metrics:
|
199 |
+
self.results[k] += ret[k]
|
200 |
+
result_dict[batch['item_name'][0]][k] = ret[k]
|
201 |
+
self.n_samples += ret['n_samples']
|
202 |
+
|
203 |
+
print({k: round(self.results[k] / self.n_samples, 3) for k in self.results}, 'total:',
|
204 |
+
self.n_samples)
|
205 |
+
|
206 |
+
if hparams['test_save_png'] and img_sr is not None:
|
207 |
+
img_sr = self.tensor2img(img_sr)
|
208 |
+
img_hr = self.tensor2img(img_hr)
|
209 |
+
img_lr = self.tensor2img(img_lr)
|
210 |
+
img_lr_up = self.tensor2img(img_lr_up)
|
211 |
+
rrdb_out = self.tensor2img(rrdb_out)
|
212 |
+
for item_name, hr_p, hr_g, lr, lr_up, rrdb_o in zip(
|
213 |
+
item_names, img_sr, img_hr, img_lr, img_lr_up, rrdb_out):
|
214 |
+
item_name = os.path.splitext(item_name)[0]
|
215 |
+
hr_p = Image.fromarray(hr_p)
|
216 |
+
hr_g = Image.fromarray(hr_g)
|
217 |
+
hr_p.save(f"{gen_dir}/SR/{item_name}.png")
|
218 |
+
hr_g.save(f"{gen_dir}/HR/{item_name}.png")
|
219 |
+
|
220 |
+
exp_name = hparams['work_dir'].split('/')[-1]
|
221 |
+
sr_img_dir = f"{gen_dir}/SR/"
|
222 |
+
gt_img_dir = f"{gen_dir}/HR/"
|
223 |
+
excel_path = f"{hparams['work_dir']}/IQA-val-benchmark-{exp_name}.xlsx"
|
224 |
+
epoch = training_step
|
225 |
+
eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)
|
226 |
+
|
227 |
+
os.makedirs(f'{self.gen_dir}', exist_ok=True)
|
228 |
+
eval_json_path = os.path.join(self.gen_dir, 'eval.json')
|
229 |
+
avg_result = {k: round(self.results[k] / self.n_samples, 4) for k in self.results}
|
230 |
+
with open(eval_json_path, 'w+') as file:
|
231 |
+
json.dump(avg_result, file, sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False)
|
232 |
+
json.dump(result_dict, file, sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False)
|
233 |
+
|
234 |
+
def benchmark_loop(self, benchmark_name_list, metric_list, gt_path):
|
235 |
+
# infer and evaluation all save checkpoint
|
236 |
+
from sam_diffsr.tools.caculate_iqa import eval_img_IQA
|
237 |
+
|
238 |
+
model = self.build_model()
|
239 |
+
|
240 |
+
def get_checkpoint(model, checkpoint):
|
241 |
+
stat_dict = checkpoint['state_dict']['model']
|
242 |
+
|
243 |
+
new_state_dict = OrderedDict()
|
244 |
+
for k, v in stat_dict.items():
|
245 |
+
if k[:7] == 'module.':
|
246 |
+
k = k[7:] # strip the leading `module.` prefix (left by DataParallel wrappers)
|
247 |
+
new_state_dict[k] = v
|
248 |
+
|
249 |
+
model.load_state_dict(new_state_dict)
|
250 |
+
model.cuda()
|
251 |
+
training_step = checkpoint['global_step']
|
252 |
+
del checkpoint
|
253 |
+
torch.cuda.empty_cache()
|
254 |
+
|
255 |
+
return training_step
|
256 |
+
|
257 |
+
ckpt_paths = get_all_ckpts(hparams['work_dir'])
|
258 |
+
for ckpt_path in ckpt_paths:
|
259 |
+
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
260 |
+
training_step = get_checkpoint(model, checkpoint)
|
261 |
+
|
262 |
+
self.global_step = training_step
|
263 |
+
|
264 |
+
for data_name in benchmark_name_list:
|
265 |
+
test_dataloader = self.build_test_my_dataloader(data_name)
|
266 |
+
|
267 |
+
self.results = {k: 0 for k in self.metric_keys + self.metric_2_keys}
|
268 |
+
self.n_samples = 0
|
269 |
+
self.gen_dir = f"{hparams['work_dir']}/results_{training_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
|
270 |
+
|
271 |
+
os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True)
|
272 |
+
os.makedirs(f'{self.gen_dir}/SR', exist_ok=True)
|
273 |
+
|
274 |
+
self.model.sample_tqdm = False
|
275 |
+
torch.backends.cudnn.benchmark = False
|
276 |
+
|
277 |
+
with torch.no_grad():
|
278 |
+
model.eval()
|
279 |
+
pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
|
280 |
+
for batch_idx, batch in pbar:
|
281 |
+
move_to_cuda(batch)
|
282 |
+
gen_dir = self.gen_dir
|
283 |
+
item_names = batch['item_name']
|
284 |
+
|
285 |
+
res = self.sample_and_test(batch)
|
286 |
+
if len(res) == 3:
|
287 |
+
img_sr, rrdb_out, ret = res
|
288 |
+
else:
|
289 |
+
img_sr, ret = res
|
290 |
+
rrdb_out = img_sr
|
291 |
+
|
292 |
+
img_sr = self.tensor2img(img_sr)
|
293 |
+
|
294 |
+
for item_name, hr_p in zip(item_names, img_sr):
|
295 |
+
item_name = os.path.splitext(item_name)[0]
|
296 |
+
hr_p = Image.fromarray(hr_p)
|
297 |
+
hr_p.save(f"{gen_dir}/SR/{item_name}.png")
|
298 |
+
|
299 |
+
exp_name = hparams['work_dir'].split('/')[-1]
|
300 |
+
sr_img_dir = f"{gen_dir}/SR/"
|
301 |
+
gt_img_dir = f"{gt_path}/{data_name}/HR"
|
302 |
+
excel_path = f"{hparams['work_dir']}/IQA-val-benchmark_loop-{exp_name}.xlsx"
|
303 |
+
epoch = training_step
|
304 |
+
eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)
|
305 |
+
|
306 |
+
# utils_sr
|
307 |
+
def log_metrics(self, metrics, step):
|
308 |
+
metrics = self.metrics_to_scalars(metrics)
|
309 |
+
logger = self.logger
|
310 |
+
for k, v in metrics.items():
|
311 |
+
if isinstance(v, torch.Tensor):
|
312 |
+
v = v.item()
|
313 |
+
logger.add_scalar(k, v, step)
|
314 |
+
|
315 |
+
def metrics_to_scalars(self, metrics):
|
316 |
+
new_metrics = {}
|
317 |
+
for k, v in metrics.items():
|
318 |
+
if isinstance(v, torch.Tensor):
|
319 |
+
v = v.item()
|
320 |
+
|
321 |
+
if type(v) is dict:
|
322 |
+
v = self.metrics_to_scalars(v)
|
323 |
+
|
324 |
+
new_metrics[k] = v
|
325 |
+
|
326 |
+
return new_metrics
|
327 |
+
|
328 |
+
@staticmethod
|
329 |
+
def tensor2img(img):
|
330 |
+
img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1) * 127.5)
|
331 |
+
img = img.clip(min=0, max=255).astype(np.uint8)
|
332 |
+
return img
|
333 |
+
|
334 |
+
|
335 |
+
if __name__ == '__main__':
|
336 |
+
set_hparams()
|
337 |
+
|
338 |
+
pkg = ".".join(hparams["trainer_cls"].split(".")[:-1])
|
339 |
+
cls_name = hparams["trainer_cls"].split(".")[-1]
|
340 |
+
trainer = getattr(importlib.import_module(pkg), cls_name)()
|
341 |
+
if hparams['benchmark_loop']:
|
342 |
+
trainer.benchmark_loop(hparams['benchmark_name_list'], hparams['metric_list'], hparams['gt_img_path'])
|
343 |
+
elif hparams['benchmark']:
|
344 |
+
trainer.benchmark(hparams['benchmark_name_list'], hparams['metric_list'])
|
345 |
+
else:
|
346 |
+
trainer.train()
|
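Note: a quick standalone check of Trainer.tensor2img above, which maps model outputs in [-1, 1] (NCHW tensors) to uint8 HWC images; the method body is copied here and the input is random, for illustration only.

import numpy as np
import torch

def tensor2img(img):
    # Copied from Trainer.tensor2img: NCHW in [-1, 1] -> NHWC uint8 in [0, 255].
    img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1) * 127.5)
    return img.clip(min=0, max=255).astype(np.uint8)

x = torch.rand(2, 3, 8, 8) * 2 - 1    # pretend SR output in [-1, 1]
imgs = tensor2img(x)
print(imgs.shape, imgs.dtype)         # (2, 8, 8, 3) uint8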
sam_diffsr/tb_logs/events.out.tfevents.1709283169.wangchengchengdeMacBook-Pro.local.99018.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aba4ab7fc71e002fcd70117a9bb9ad042341fc80f7d51f5bd4bd9610c508a655
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709284054.wangchengchengdeMacBook-Pro.local.99188.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:13225cd295f9d05736be890cd9b70fdf332846531c35760d37b7aae5ca02e584
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709284076.wangchengchengdeMacBook-Pro.local.99198.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bf4eadf9f990294906e556c51d92636c7af10c6aec626003341eb0d7b3f5bc38
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709284101.wangchengchengdeMacBook-Pro.local.99211.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:687ac32fefdcbdb917bfa7c9197f8e64b050a506f9676d365f23484183da5629
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709284193.wangchengchengdeMacBook-Pro.local.99233.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3fd3acbfec3acc9d5d582e81cd6819b2e5512f45bbf85918b89e6ea371ffbdf2
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709284415.wangchengchengdeMacBook-Pro.local.99289.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5867fd2539cfbb896867c770c5c7c0f5d9687e29ea0322646fe820dae1cae08c
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709284460.wangchengchengdeMacBook-Pro.local.99308.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b7999747d297acf0c198f9b8739134ee8ba7d4bc4fae209e47fc4bbd09a7206c
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709284491.wangchengchengdeMacBook-Pro.local.99315.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:172e36eb89b0238ad5cdac335de91bc2721a1944c5e9807027a6c3eeea64f918
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709285127.wangchengchengdeMacBook-Pro.local.785.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f51161b5901b64dd694dd06243d6143281b4df0f839ec89601407aad680cfc34
|
3 |
+
size 88
|
sam_diffsr/tb_logs/events.out.tfevents.1709285146.wangchengchengdeMacBook-Pro.local.901.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:776fa889c3c898a1b157f45aed579d9976791c718b507508421100c03baeb401
|
3 |
+
size 88
|
sam_diffsr/tools/caculate_iqa.py
ADDED
@@ -0,0 +1,136 @@
1 |
+
import os
|
2 |
+
import ssl
|
3 |
+
from os.path import join
|
4 |
+
from pathlib import Path
|
5 |
+
from statistics import mean
|
6 |
+
|
7 |
+
parent_path = Path(__file__).absolute().parent.parent
|
8 |
+
parent_path = os.path.abspath(parent_path)
|
9 |
+
|
10 |
+
os.environ["CURL_CA_BUNDLE"] = ""
|
11 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
12 |
+
|
13 |
+
cache_path = os.path.join(parent_path, 'cache')
|
14 |
+
os.environ["HF_DATASETS_CACHE"] = cache_path
|
15 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
16 |
+
os.environ["torch_HOME"] = cache_path
|
17 |
+
|
18 |
+
import PIL
|
19 |
+
import numpy as np
|
20 |
+
import pandas as pd
|
21 |
+
import pyiqa
|
22 |
+
import torch
|
23 |
+
from PIL import Image
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
27 |
+
|
28 |
+
metric_dict = {
|
29 |
+
'psnr-Y': pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr'),
|
30 |
+
'ssim': pyiqa.create_metric('ssim', color_space='ycbcr'),
|
31 |
+
'fid': pyiqa.create_metric('fid'),
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
def load_img(path, target_size=None):
|
36 |
+
image = Image.open(path).convert("RGB")
|
37 |
+
if target_size:
|
38 |
+
h, w = target_size
|
39 |
+
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
40 |
+
image = np.array(image).astype(np.float32) / 255.0
|
41 |
+
image = image[None].transpose(0, 3, 1, 2)
|
42 |
+
image = torch.from_numpy(image)
|
43 |
+
return image
|
44 |
+
|
45 |
+
|
46 |
+
def eval_img_IQA(gt_dir, sr_dir, excel_path, metric_list, exp_name, data_name):
|
47 |
+
gt_img_list = os.listdir(gt_dir)
|
48 |
+
|
49 |
+
iqa_result = {}
|
50 |
+
|
51 |
+
for metric in metric_list:
|
52 |
+
iqa_metric = metric_dict[metric].to(device)
|
53 |
+
score_fr_list = []
|
54 |
+
|
55 |
+
if metric == 'fid':
|
56 |
+
score_fr = iqa_metric(sr_dir, gt_dir)
|
57 |
+
iqa_result[metric] = float(score_fr)
|
58 |
+
print(f'{metric}: {float(score_fr)}')
|
59 |
+
else:
|
60 |
+
for img_name in tqdm(gt_img_list):
|
61 |
+
base_name = img_name.split('.')[0]
|
62 |
+
sr_img_name = f'{base_name}.png'
|
63 |
+
gt_img_path = join(gt_dir, img_name)
|
64 |
+
sr_img_path = join(sr_dir, sr_img_name)
|
65 |
+
|
66 |
+
if not os.path.exists(sr_img_path):
|
67 |
+
print(f'File not exist: {sr_img_path}')
|
68 |
+
continue
|
69 |
+
|
70 |
+
gt_img = load_img(gt_img_path, target_size=None)
|
71 |
+
target_size = gt_img.shape[2:]
|
72 |
+
sr_img = load_img(sr_img_path, target_size=target_size)
|
73 |
+
|
74 |
+
score_fr = iqa_metric(sr_img, gt_img)
|
75 |
+
|
76 |
+
if score_fr.shape == (1,):
|
77 |
+
score_fr = score_fr[0]
|
78 |
+
if isinstance(score_fr, torch.Tensor):
|
79 |
+
score_fr = float(score_fr.cpu().numpy())
|
80 |
+
else:
|
81 |
+
score_fr = float(score_fr)
|
82 |
+
score_fr_list.append(score_fr)
|
83 |
+
|
84 |
+
mean_score = mean(score_fr_list)
|
85 |
+
iqa_result[metric] = float(mean_score)
|
86 |
+
print(f'{metric}: {mean_score}')
|
87 |
+
|
88 |
+
if os.path.exists(excel_path):
|
89 |
+
df = pd.read_excel(excel_path)
|
90 |
+
else:
|
91 |
+
df = pd.DataFrame(columns=['exp'])
|
92 |
+
|
93 |
+
new_index = len(df.index)
|
94 |
+
|
95 |
+
exp_name = int(exp_name)
|
96 |
+
if exp_name in df['exp'].to_list():
|
97 |
+
new_index = df[df['exp'] == exp_name].index.tolist()[0]
|
98 |
+
else:
|
99 |
+
df.loc[new_index, 'exp'] = exp_name
|
100 |
+
|
101 |
+
for index, metric in enumerate(metric_list):
|
102 |
+
df_metric = f'{data_name}-{metric}'
|
103 |
+
if df_metric not in df.columns.tolist():
|
104 |
+
df[df_metric] = ''
|
105 |
+
|
106 |
+
df.loc[new_index, df_metric] = iqa_result[metric]
|
107 |
+
|
108 |
+
df.sort_values(by='exp', inplace=True)
|
109 |
+
|
110 |
+
df.to_excel(excel_path, startcol=0, index=False)
|
111 |
+
|
112 |
+
|
113 |
+
def main():
|
114 |
+
epoch = 400000
|
115 |
+
add_name = ''
|
116 |
+
exp_root = '/home/ma-user/work/code/SRDiff-main/checkpoints'
|
117 |
+
|
118 |
+
model_type_list = ['diffsr_df2k4x_sam-pl_qs-zero']
|
119 |
+
|
120 |
+
metric_list = ['psnr-Y', 'ssim', 'fid']
|
121 |
+
benchmark_name_list = ['test_Set5', 'test_Set14', 'test_Urban100', 'test_Manga109', 'test_BSDS100']
|
122 |
+
|
123 |
+
# if benchmark:
|
124 |
+
for model_type in model_type_list:
|
125 |
+
excel_path = join(exp_root, model_type, f'IQA-val-{model_type}.xls')
|
126 |
+
for benchmark_name in benchmark_name_list:
|
127 |
+
exp_dir = join(exp_root, f'{model_type}/results_{epoch}_{add_name}/benchmark/{benchmark_name}')
|
128 |
+
gt_img_dir = join(exp_dir, 'HR')
|
129 |
+
sr_img_dir = join(exp_dir, 'SR')
|
130 |
+
|
131 |
+
data_name = benchmark_name[5:]
|
132 |
+
eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == '__main__':
|
136 |
+
main()
|
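Note: a hypothetical call into eval_img_IQA above; every path is a placeholder, and importing the module instantiates the pyiqa metrics up front, so the import itself is relatively heavy. The metric list mirrors the defaults used by the trainer.

from sam_diffsr.tools.caculate_iqa import eval_img_IQA

eval_img_IQA(
    gt_dir='results/benchmark/Set5/HR',     # ground-truth PNGs (placeholder path)
    sr_dir='results/benchmark/Set5/SR',     # super-resolved PNGs (placeholder path)
    excel_path='results/IQA-Set5.xlsx',     # per-step scores are accumulated here
    metric_list=['psnr-Y', 'ssim', 'fid'],
    exp_name=400000,                        # interpreted as the checkpoint step
    data_name='Set5',
)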
sam_diffsr/tools/visualize_sam_mask.py
ADDED
@@ -0,0 +1,20 @@
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
num = '0824'
|
9 |
+
|
10 |
+
sam_npy = '/home/ma-user/work/data/sr_sam/merge_RoPE/DF2K/DF2K_train_HR'
|
11 |
+
save_dir = '/home/ma-user/work/data/sr_sam/merge_RoPE/vis/DF2K/DF2K_train_HR'
|
12 |
+
|
13 |
+
os.makedirs(save_dir, exist_ok=True)
|
14 |
+
|
15 |
+
for file in tqdm(glob.glob(f'{sam_npy}/*.npy')):
|
16 |
+
name = os.path.basename(file).split('.')[0]
|
17 |
+
save_path = os.path.join(save_dir, f'{name}.png')
|
18 |
+
img = np.load(file)
|
19 |
+
plt.imshow(img)
|
20 |
+
plt.savefig(save_path)
|
sam_diffsr/utils_sr/__init__.py
ADDED
File without changes
|
sam_diffsr/utils_sr/dataset.py
ADDED
@@ -0,0 +1,50 @@
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
from .hparams import hparams
|
7 |
+
from .indexed_datasets import IndexedDataset
|
8 |
+
from .matlab_resize import imresize
|
9 |
+
|
10 |
+
|
11 |
+
class SRDataSet(Dataset):
|
12 |
+
def __init__(self, prefix='train'):
|
13 |
+
self.hparams = hparams
|
14 |
+
self.data_dir = hparams['binary_data_dir']
|
15 |
+
self.prefix = prefix
|
16 |
+
self.len = len(IndexedDataset(f'{self.data_dir}/{self.prefix}'))
|
17 |
+
self.to_tensor_norm = transforms.Compose([
|
18 |
+
transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
19 |
+
])
|
20 |
+
assert hparams['data_interp'] in ['bilinear', 'bicubic']
|
21 |
+
self.data_augmentation = hparams['data_augmentation']
|
22 |
+
self.indexed_ds = None
|
23 |
+
if self.prefix == 'valid':
|
24 |
+
self.len = hparams['eval_batch_size'] * hparams['valid_steps']
|
25 |
+
|
26 |
+
def _get_item(self, index):
|
27 |
+
if self.indexed_ds is None:
|
28 |
+
self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
|
29 |
+
return self.indexed_ds[index]
|
30 |
+
|
31 |
+
def __getitem__(self, index):
|
32 |
+
item = self._get_item(index)
|
33 |
+
hparams = self.hparams
|
34 |
+
img_hr = item['img']
|
35 |
+
img_hr = Image.fromarray(np.uint8(img_hr))
|
36 |
+
img_hr = self.pre_process(img_hr) # PIL
|
37 |
+
img_hr = np.asarray(img_hr) # np.uint8 [H, W, C]
|
38 |
+
img_lr = imresize(img_hr, 1 / hparams['sr_scale'], method=hparams['data_interp']) # np.uint8 [H, W, C]
|
39 |
+
img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
|
40 |
+
img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
|
41 |
+
return {
|
42 |
+
'img_hr': img_hr, 'img_lr': img_lr, 'img_lr_up': img_lr_up,
|
43 |
+
'item_name': item['item_name']
|
44 |
+
}
|
45 |
+
|
46 |
+
def pre_process(self, img_hr):
|
47 |
+
return img_hr
|
48 |
+
|
49 |
+
def __len__(self):
|
50 |
+
return self.len
|
sam_diffsr/utils_sr/hparams.py
ADDED
@@ -0,0 +1,157 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import yaml
|
6 |
+
|
7 |
+
global_print_hparams = True
|
8 |
+
hparams = {}
|
9 |
+
|
10 |
+
|
11 |
+
class Args:
|
12 |
+
def __init__(self, **kwargs):
|
13 |
+
for k, v in kwargs.items():
|
14 |
+
self.__setattr__(k, v)
|
15 |
+
|
16 |
+
|
17 |
+
def override_config(old_config: dict, new_config: dict):
|
18 |
+
for k, v in new_config.items():
|
19 |
+
if isinstance(v, dict) and k in old_config:
|
20 |
+
override_config(old_config[k], new_config[k])
|
21 |
+
else:
|
22 |
+
old_config[k] = v
|
23 |
+
|
24 |
+
|
25 |
+
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
|
26 |
+
parent_path = Path(__file__).absolute().parent.parent
|
27 |
+
fill_root = os.path.abspath(parent_path)
|
28 |
+
|
29 |
+
if config == '' and exp_name == '':
|
30 |
+
parser = argparse.ArgumentParser(description='')
|
31 |
+
parser.add_argument('--config', type=str, default=os.path.join(fill_root, 'configs/sam/sam_diffsr_df2k4x.yaml'),
|
32 |
+
help='location of the data corpus')
|
33 |
+
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
|
34 |
+
parser.add_argument('--work_dir', type=str, default='', help='work dir')
|
35 |
+
parser.add_argument('--gt_img_path', type=str, default='data/sr_diff/benchmark', help='gt_img_path')
|
36 |
+
parser.add_argument('-hp', '--hparams', type=str, default='',
|
37 |
+
help='location of the data corpus')
|
38 |
+
parser.add_argument('--infer', action='store_true', help='infer')
|
39 |
+
parser.add_argument('--benchmark', action='store_true', help='test benchmark')
|
40 |
+
parser.add_argument('--benchmark_loop', action='store_true', help='loop test benchmark for all checkpoint')
|
41 |
+
parser.add_argument('--benchmark_name_list', nargs='+',
|
42 |
+
default=['test_Set5', 'test_Set14', 'test_Urban100', 'test_Manga109', 'test_BSDS100'])
|
43 |
+
parser.add_argument('--metric_list', nargs='+', default=['psnr-Y', 'ssim', 'fid'])
|
44 |
+
parser.add_argument('--validate', action='store_true', help='validate')
|
45 |
+
parser.add_argument('--val_steps', type=int, default=None, help='validate steps')
|
46 |
+
parser.add_argument('--reset', action='store_true', help='reset hparams')
|
47 |
+
parser.add_argument('--debug', action='store_true', help='debug')
|
48 |
+
|
49 |
+
parser.add_argument('--img_dir', type=str, default='', help='infer input image dir')
|
50 |
+
parser.add_argument('--save_dir', type=str, default='', help='infer output image dir')
|
51 |
+
parser.add_argument('--ckpt_path', type=str, default='', help='infer ckpt path')
|
52 |
+
|
53 |
+
|
54 |
+
args, unknown = parser.parse_known_args()
|
55 |
+
print("| Unknow hparams: ", unknown)
|
56 |
+
else:
|
57 |
+
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
|
58 |
+
infer=False, validate=False, reset=False, debug=False)
|
59 |
+
global hparams
|
60 |
+
assert args.config != '' or args.exp_name != ''
|
61 |
+
if args.config != '':
|
62 |
+
assert os.path.exists(args.config)
|
63 |
+
|
64 |
+
config_chains = []
|
65 |
+
loaded_config = set()
|
66 |
+
|
67 |
+
def load_config(config_fn):
|
68 |
+
# deep first inheritance and avoid the second visit of one node
|
69 |
+
if not os.path.exists(config_fn):
|
70 |
+
return {}
|
71 |
+
with open(config_fn) as f:
|
72 |
+
hparams_ = yaml.safe_load(f)
|
73 |
+
loaded_config.add(config_fn)
|
74 |
+
if 'base_config' in hparams_:
|
75 |
+
ret_hparams = {}
|
76 |
+
if not isinstance(hparams_['base_config'], list):
|
77 |
+
hparams_['base_config'] = [hparams_['base_config']]
|
78 |
+
for c in hparams_['base_config']:
|
79 |
+
if c.startswith('.'):
|
80 |
+
c = f'{os.path.dirname(config_fn)}/{c}'
|
81 |
+
c = os.path.normpath(c)
|
82 |
+
if c not in loaded_config:
|
83 |
+
override_config(ret_hparams, load_config(c))
|
84 |
+
override_config(ret_hparams, hparams_)
|
85 |
+
else:
|
86 |
+
ret_hparams = hparams_
|
87 |
+
config_chains.append(config_fn)
|
88 |
+
return ret_hparams
|
89 |
+
|
90 |
+
saved_hparams = {}
|
91 |
+
args_work_dir = ''
|
92 |
+
if args.exp_name != '':
|
93 |
+
args_work_dir = os.path.join(args.work_dir, 'checkpoints', args.exp_name)
|
94 |
+
ckpt_config_path = f'{args_work_dir}/config.yaml'
|
95 |
+
if os.path.exists(ckpt_config_path):
|
96 |
+
with open(ckpt_config_path) as f:
|
97 |
+
saved_hparams_ = yaml.safe_load(f)
|
98 |
+
if saved_hparams_ is not None:
|
99 |
+
saved_hparams.update(saved_hparams_)
|
100 |
+
hparams_ = {}
|
101 |
+
if args.config != '':
|
102 |
+
hparams_.update(load_config(args.config))
|
103 |
+
if not args.reset:
|
104 |
+
hparams_.update(saved_hparams)
|
105 |
+
hparams_['work_dir'] = args_work_dir
|
106 |
+
|
107 |
+
# Support config overriding in command line. Support list type config overriding.
|
108 |
+
# Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
|
109 |
+
if args.hparams != "":
|
110 |
+
for new_hparam in args.hparams.split(","):
|
111 |
+
k, v = new_hparam.split("=")
|
112 |
+
v = v.strip("\'\" ")
|
113 |
+
config_node = hparams_
|
114 |
+
for k_ in k.split(".")[:-1]:
|
115 |
+
config_node = config_node[k_]
|
116 |
+
k = k.split(".")[-1]
|
117 |
+
if k not in config_node:
|
118 |
+
config_node[k] = v
|
119 |
+
|
120 |
+
elif v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
|
121 |
+
if type(config_node[k]) == list:
|
122 |
+
v = v.replace(" ", ",")
|
123 |
+
config_node[k] = eval(v)
|
124 |
+
else:
|
125 |
+
config_node[k] = type(config_node[k])(v)
|
126 |
+
if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
|
127 |
+
os.makedirs(hparams_['work_dir'], exist_ok=True)
|
128 |
+
with open(ckpt_config_path, 'w') as f:
|
129 |
+
yaml.safe_dump(hparams_, f)
|
130 |
+
|
131 |
+
hparams_['infer'] = args.infer
|
132 |
+
hparams_['debug'] = args.debug
|
133 |
+
hparams_['validate'] = args.validate
|
134 |
+
hparams_['exp_name'] = args.exp_name
|
135 |
+
hparams_['val_steps'] = args.val_steps
|
136 |
+
hparams_['benchmark'] = args.benchmark
|
137 |
+
hparams_['benchmark_loop'] = args.benchmark_loop
|
138 |
+
hparams_['benchmark_name_list'] = args.benchmark_name_list
|
139 |
+
hparams_['gt_img_path'] = args.gt_img_path
|
140 |
+
hparams_['metric_list'] = args.metric_list
|
141 |
+
|
142 |
+
hparams_['img_dir'] = args.img_dir
|
143 |
+
hparams_['save_dir'] = args.save_dir
|
144 |
+
hparams_['ckpt_path'] = args.ckpt_path
|
145 |
+
|
146 |
+
global global_print_hparams
|
147 |
+
if global_hparams:
|
148 |
+
hparams.clear()
|
149 |
+
hparams.update(hparams_)
|
150 |
+
if print_hparams and global_print_hparams and global_hparams:
|
151 |
+
print('| Hparams chains: ', config_chains)
|
152 |
+
print('| Hparams: ')
|
153 |
+
for i, (k, v) in enumerate(sorted(hparams_.items())):
|
154 |
+
print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
|
155 |
+
print("")
|
156 |
+
global_print_hparams = False
|
157 |
+
return hparams_
|
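
Illustrative sketch (not part of the commit): a standalone re-implementation of the "--hparams" override mini-grammar parsed above, so its behavior is easy to see; the config keys and values below are hypothetical.

def apply_overrides(cfg, override_str):
    # Dotted keys walk nested dicts; values are cast to the existing type;
    # "[0 1]" becomes a Python list, mirroring the loop in set_hparams.
    for item in override_str.split(","):
        k, v = item.split("=")
        v = v.strip("'\" ")
        node = cfg
        for part in k.split(".")[:-1]:
            node = node[part]
        leaf = k.split(".")[-1]
        if leaf not in node:
            node[leaf] = v
        elif v in ['True', 'False'] or type(node[leaf]) in [bool, list, dict]:
            node[leaf] = eval(v.replace(" ", ",") if isinstance(node[leaf], list) else v)
        else:
            node[leaf] = type(node[leaf])(v)
    return cfg

cfg = {'lr': 2e-4, 'rrdb': {'num_block': 8}, 'gpu_ids': [0]}
print(apply_overrides(cfg, "lr=1e-4,rrdb.num_block=16,gpu_ids=[0 1]"))
# -> {'lr': 0.0001, 'rrdb': {'num_block': 16}, 'gpu_ids': [0, 1]}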
sam_diffsr/utils_sr/indexed_datasets.py
ADDED
@@ -0,0 +1,72 @@
import pickle

import numpy as np


class IndexedDataset:
    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        index_data = np.load(f"{path}.idx", allow_pickle=True).item()
        self.byte_offsets = index_data['offsets']
        self.id2pos = index_data.get('id2pos', {})
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)

    def check_index(self, i):
        if i < 0 or i >= len(self.byte_offsets) - 1:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    def __getitem__(self, i):
        if self.id2pos is not None and len(self.id2pos) > 0:
            i = self.id2pos[i]
        self.check_index(i)
        self.data_file.seek(self.byte_offsets[i])
        b = self.data_file.read(self.byte_offsets[i + 1] - self.byte_offsets[i])
        item = pickle.loads(b)
        return item

    def __len__(self):
        return len(self.byte_offsets) - 1

    def __iter__(self):
        self.iter_i = 0
        return self

    def __next__(self):
        if self.iter_i == len(self):
            raise StopIteration
        else:
            item = self[self.iter_i]
            self.iter_i += 1
            return item


class IndexedDatasetBuilder:
    def __init__(self, path, append=False):
        self.path = path
        if append:
            self.data_file = open(f"{path}.data", 'ab')
            index_data = np.load(f"{path}.idx", allow_pickle=True).item()
            self.byte_offsets = index_data['offsets']
            self.id2pos = index_data.get('id2pos', {})
        else:
            self.data_file = open(f"{path}.data", 'wb')
            self.byte_offsets = [0]
            self.id2pos = {}

    def add_item(self, item, id=None):
        s = pickle.dumps(item)
        bytes = self.data_file.write(s)
        if id is not None:
            self.id2pos[id] = len(self.byte_offsets) - 1
        self.byte_offsets.append(self.byte_offsets[-1] + bytes)

    def finalize(self):
        self.data_file.close()
        np.save(open(f"{self.path}.idx", 'wb'),
                {'offsets': self.byte_offsets, 'id2pos': self.id2pos})
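
Illustrative usage (not part of the file; the path and items are hypothetical): the builder writes a `.data`/`.idx` pair that the reader can then index by position or by the id passed to `add_item`.

builder = IndexedDatasetBuilder('/tmp/demo_bin')               # writes /tmp/demo_bin.data
for idx in range(3):
    builder.add_item({'item_name': f'img_{idx}', 'value': idx}, id=f'img_{idx}')
builder.finalize()                                             # writes /tmp/demo_bin.idx

ds = IndexedDataset('/tmp/demo_bin')
print(len(ds), ds['img_1'])                                    # lookup via the id2pos mapping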
sam_diffsr/utils_sr/matlab_resize.py
ADDED
@@ -0,0 +1,181 @@
# https://github.com/fatheral/matlab_imresize
#
# MIT License
#
# Copyright (c) 2020 Alex
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from __future__ import print_function

import numpy as np
from math import ceil


def deriveSizeFromScale(img_shape, scale):
    output_shape = []
    for k in range(2):
        output_shape.append(int(ceil(scale[k] * img_shape[k])))
    return output_shape


def deriveScaleFromSize(img_shape_in, img_shape_out):
    scale = []
    for k in range(2):
        scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
    return scale


def triangle(x):
    x = np.array(x).astype(np.float64)
    lessthanzero = np.logical_and((x >= -1), x < 0)
    greaterthanzero = np.logical_and((x <= 1), x >= 0)
    f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero)
    return f


def cubic(x):
    x = np.array(x).astype(np.float64)
    absx = np.absolute(x)
    absx2 = np.multiply(absx, absx)
    absx3 = np.multiply(absx2, absx)
    f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2,
                                                                            (1 < absx) & (absx <= 2))
    return f


def contributions(in_length, out_length, scale, kernel, k_width):
    if scale < 1:
        h = lambda x: scale * kernel(scale * x)
        kernel_width = 1.0 * k_width / scale
    else:
        h = kernel
        kernel_width = k_width
    x = np.arange(1, out_length + 1).astype(np.float64)
    u = x / scale + 0.5 * (1 - 1 / scale)
    left = np.floor(u - kernel_width / 2)
    P = int(ceil(kernel_width)) + 2
    ind = np.expand_dims(left, axis=1) + np.arange(P) - 1  # -1 because indexing from 0
    indices = ind.astype(np.int32)
    weights = h(np.expand_dims(u, axis=1) - indices - 1)  # -1 because indexing from 0
    weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
    aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
    indices = aux[np.mod(indices, aux.size)]
    ind2store = np.nonzero(np.any(weights, axis=0))
    weights = weights[:, ind2store]
    indices = indices[:, ind2store]
    return weights, indices


def imresizemex(inimg, weights, indices, dim):
    in_shape = inimg.shape
    w_shape = weights.shape
    out_shape = list(in_shape)
    out_shape[dim] = w_shape[0]
    outimg = np.zeros(out_shape)
    if dim == 0:
        for i_img in range(in_shape[1]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[ind, i_img].astype(np.float64)
                outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
    elif dim == 1:
        for i_img in range(in_shape[0]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[i_img, ind].astype(np.float64)
                outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg


def imresizevec(inimg, weights, indices, dim):
    wshape = weights.shape
    if dim == 0:
        weights = weights.reshape((wshape[0], wshape[2], 1, 1))
        outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
    elif dim == 1:
        weights = weights.reshape((1, wshape[0], wshape[2], 1))
        outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg


def resizeAlongDim(A, dim, weights, indices, mode="vec"):
    if mode == "org":
        out = imresizemex(A, weights, indices, dim)
    else:
        out = imresizevec(A, weights, indices, dim)
    return out


def imresize(I, scale=None, method='bicubic', sizes=None, mode="vec"):
    if method == 'bicubic':
        kernel = cubic
    elif method == 'bilinear':
        kernel = triangle
    else:
        print('Error: Unidentified method supplied')

    kernel_width = 4.0
    # Fill scale and output_size
    if scale is not None:
        scale = float(scale)
        scale = [scale, scale]
        output_size = deriveSizeFromScale(I.shape, scale)
    elif sizes is not None:
        scale = deriveScaleFromSize(I.shape, sizes)
        output_size = list(sizes)
    else:
        print('Error: scalar_scale OR output_shape should be defined!')
        return
    scale_np = np.array(scale)
    order = np.argsort(scale_np)
    weights = []
    indices = []
    for k in range(2):
        w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
        weights.append(w)
        indices.append(ind)
    B = np.copy(I)
    flag2D = False
    if B.ndim == 2:
        B = np.expand_dims(B, axis=2)
        flag2D = True
    for k in range(2):
        dim = order[k]
        B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
    if flag2D:
        B = np.squeeze(B, axis=2)
    return B


def convertDouble2Byte(I):
    B = np.clip(I, 0.0, 1.0)
    B = 255 * B
    return np.around(B).astype(np.uint8)
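
Illustrative usage (not part of the file; the arrays are hypothetical): `imresize` accepts either a scalar scale factor or an explicit output size, reproducing MATLAB's bicubic resampling.

import numpy as np

hr = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)   # stand-in HR image
lr = imresize(hr, scale=1 / 4)                               # (32, 32, 3), MATLAB-style bicubic downscale
back = imresize(lr, sizes=(128, 128))                        # upscale to an explicit target size
print(lr.shape, back.shape)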
sam_diffsr/utils_sr/plt_img.py
ADDED
@@ -0,0 +1,109 @@
import math
import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.utils import make_grid


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to (min, max), image values will be normalized to [0, 1].

    For different tensor shapes, this function will have different behaviors:

    1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
        Use `make_grid` to stitch images in the batch dimension, and then
        convert it to numpy array.
    2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
        Directly change to numpy array.

    Note that the image channel in input tensors should be RGB order. This
    function will convert it to cv2 convention, i.e., (H x W x C) with BGR
    order.

    Args:
        tensor (Tensor | list[Tensor]): Input tensors.
        out_type (numpy type): Output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple): min and max values for clamp.

    Returns:
        (Tensor | list[Tensor]): 3D ndarray of shape (H x W x C) or 2D ndarray
        of shape (H x W).
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Squeeze two times so that:
        # 1. (1, 1, h, w) -> (h, w) or
        # 2. (1, 3, h, w) -> (3, h, w) or
        # 3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
        _tensor = _tensor.squeeze(0).squeeze(0)
        _tensor = _tensor.float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(
                _tensor, nrow=int(math.sqrt(_tensor.size(0))),
                normalize=False).numpy()
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    result = result[0] if len(result) == 1 else result
    return result


def plt_tensor_img(tensor, save_path=None):
    plt.imshow(tensor2img(tensor))
    plt.show()

    if save_path:
        plt.savefig(save_path)


def plt_tensor_img_one(tensor, t_dim=1):
    if isinstance(tensor, list):
        tensor = torch.cat(tensor, dim=t_dim)
    nums = tensor.shape[t_dim]

    mash = math.ceil(math.sqrt(nums))

    plt.figure(dpi=300)
    plt_range = min(nums, mash ** 2)
    for i in range(plt_range):
        plt.subplot(mash, mash, i + 1)
        if t_dim == 1:
            img = tensor2img(tensor[:, i, ...])
        elif t_dim == 0:
            img = tensor2img(tensor[i, ...])
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.show()


def plt_img(img, save_path=None):
    plt.imshow(img)
    plt.show()
    if save_path:
        plt.savefig(save_path)
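
Illustrative usage (not part of the file; the tensor is hypothetical): `tensor2img` returns an HWC uint8 array in BGR order, per its docstring, so a BGR-consuming writer such as cv2 is the intended consumer.

import torch

fake_sr = torch.rand(1, 3, 64, 64)          # e.g. a model output in [0, 1], NCHW
img = tensor2img(fake_sr, min_max=(0, 1))   # (64, 64, 3) uint8, BGR channel order
print(img.shape, img.dtype)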
sam_diffsr/utils_sr/sr_utils.py
ADDED
@@ -0,0 +1,171 @@
import torch
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
import numpy as np
from math import exp
import torch.nn as nn


class ImgMerger:
    def __init__(self, eval_fn):
        self.eval_fn = eval_fn
        self.loc2imgs = {}
        self.max_x = 0
        self.max_y = 0
        self.clear()

    def clear(self):
        self.loc2imgs = {}
        self.max_x = 0
        self.max_y = 0

    def push(self, imgs, loc, loc_bdr):
        """

        Args:
            imgs: each of img is [C, H, W] np.array, range: [0, 255]
            loc: string, e.g., 0_0, 0_1 ...
        """
        self.max_x, self.max_y = loc_bdr
        x, y = loc
        self.loc2imgs[f'{x},{y}'] = imgs
        if len(self.loc2imgs) == self.max_x * self.max_y:
            return self.compute()

    def compute(self):
        img_inputs = []
        for i in range(len(self.loc2imgs['0,0'])):
            img_full = []
            for x in range(self.max_x):
                imgx = []
                for y in range(self.max_y):
                    imgx.append(self.loc2imgs[f'{x},{y}'][i])
                img_full.append(np.concatenate(imgx, 2))
            img_inputs.append(np.concatenate(img_full, 1))
        self.clear()
        return self.eval_fn(*img_inputs)


##########
# SSIM
##########
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        img1 = img1 * 0.5 + 0.5
        img2 = img2 * 0.5 + 0.5
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)


class VGGFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True):
        super(VGGFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        if self.use_input_norm:
            mean = torch.Tensor([0.485 - 1, 0.456 - 1, 0.406 - 1]).view(1, 3, 1, 1)
            # mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
            std = torch.Tensor([0.229 * 2, 0.224 * 2, 0.225 * 2]).view(1, 3, 1, 1)
            # std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        # Assume input range is [0, 1]
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class PerceptualLoss(nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        loss_network = VGGFeatureExtractor()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.l1_loss = nn.L1Loss()

    def forward(self, high_resolution, fake_high_resolution):
        if next(self.loss_network.parameters()).device != high_resolution.device:
            self.loss_network.to(high_resolution.device)
        self.loss_network.eval()
        perception_loss = self.l1_loss(self.loss_network(high_resolution), self.loss_network(fake_high_resolution))
        return perception_loss
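
Illustrative usage (not part of the file; the tensors are hypothetical): the functional `ssim` expects NCHW inputs in [0, 1], while `PerceptualLoss` compares VGG-19 features of two batches and downloads torchvision's pretrained weights on first construction.

import torch

a = torch.rand(1, 3, 64, 64)
b = torch.rand(1, 3, 64, 64)
print(float(ssim(a, b)))          # windowed SSIM on [0, 1] inputs

ploss = PerceptualLoss()          # L1 distance between VGG-19 feature maps
print(float(ploss(a, b)))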
sam_diffsr/utils_sr/utils.py
ADDED
@@ -0,0 +1,269 @@
import glob
import os
import re
import subprocess
from collections import OrderedDict

import lpips
import numpy as np
import torch
import torch.distributed as dist
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from .matlab_resize import imresize


def reduce_tensors(metrics):
    new_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            dist.all_reduce(v)
            v = v / dist.get_world_size()
        if type(v) is dict:
            v = reduce_tensors(v)
        new_metrics[k] = v
    return new_metrics


def tensors_to_scalars(tensors):
    if isinstance(tensors, torch.Tensor):
        tensors = tensors.item()
        return tensors
    elif isinstance(tensors, dict):
        new_tensors = {}
        for k, v in tensors.items():
            v = tensors_to_scalars(v)
            new_tensors[k] = v
        return new_tensors
    elif isinstance(tensors, list):
        return [tensors_to_scalars(v) for v in tensors]
    else:
        return tensors


def tensors_to_np(tensors):
    if isinstance(tensors, dict):
        new_np = {}
        for k, v in tensors.items():
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()
            if type(v) is dict:
                v = tensors_to_np(v)
            new_np[k] = v
    elif isinstance(tensors, list):
        new_np = []
        for v in tensors:
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()
            if type(v) is dict:
                v = tensors_to_np(v)
            new_np.append(v)
    elif isinstance(tensors, torch.Tensor):
        v = tensors
        if isinstance(v, torch.Tensor):
            v = v.cpu().numpy()
        if type(v) is dict:
            v = tensors_to_np(v)
        new_np = v
    else:
        raise Exception(f'tensors_to_np does not support type {type(tensors)}.')
    return new_np


def move_to_cpu(tensors):
    ret = {}
    for k, v in tensors.items():
        if isinstance(v, torch.Tensor):
            v = v.cpu()
        if type(v) is dict:
            v = move_to_cpu(v)
        ret[k] = v
    return ret


def move_to_cuda(batch, gpu_id=0):
    # base case: object can be directly moved using `cuda` or `to`
    if callable(getattr(batch, 'cuda', None)):
        return batch.cuda(gpu_id, non_blocking=True)
    elif callable(getattr(batch, 'to', None)):
        return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
    elif isinstance(batch, list):
        for i, x in enumerate(batch):
            batch[i] = move_to_cuda(x, gpu_id)
        return batch
    elif isinstance(batch, tuple):
        batch = list(batch)
        for i, x in enumerate(batch):
            batch[i] = move_to_cuda(x, gpu_id)
        return tuple(batch)
    elif isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = move_to_cuda(v, gpu_id)
        return batch
    return batch


def get_last_checkpoint(work_dir, steps=None):
    checkpoint = None
    last_ckpt_path = None
    ckpt_paths = get_all_ckpts(work_dir, steps)
    if len(ckpt_paths) > 0:
        last_ckpt_path = ckpt_paths[0]
        checkpoint = torch.load(last_ckpt_path, map_location='cpu')
    return checkpoint, last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
    if steps is None:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0]))


def load_checkpoint(model, optimizer, work_dir, steps=None):
    checkpoint, last_ckpt_path = get_last_checkpoint(work_dir, steps)
    print(f'loading checkpoint from: {last_ckpt_path}')
    if checkpoint is not None:
        stat_dict = checkpoint['state_dict']['model']

        new_state_dict = OrderedDict()
        for k, v in stat_dict.items():
            if k[:7] == 'module.':
                k = k[7:]  # strip the `module.` prefix added by DataParallel
            new_state_dict[k] = v

        model.load_state_dict(new_state_dict)
        model.cuda()
        optimizer.load_state_dict(checkpoint['optimizer_states'][0])
        training_step = checkpoint['global_step']
        del checkpoint
        torch.cuda.empty_cache()
    else:
        training_step = 0
        model.cuda()
    return training_step


def save_checkpoint(model, optimizer, work_dir, global_step, num_ckpt_keep):
    ckpt_path = f'{work_dir}/model_ckpt_steps_{global_step}.ckpt'
    print(f'Step@{global_step}: saving model to {ckpt_path}')
    checkpoint = {'global_step': global_step}
    optimizer_states = []
    optimizer_states.append(optimizer.state_dict())
    checkpoint['optimizer_states'] = optimizer_states
    checkpoint['state_dict'] = {'model': model.state_dict()}
    torch.save(checkpoint, ckpt_path, _use_new_zipfile_serialization=False)
    for old_ckpt in get_all_ckpts(work_dir)[num_ckpt_keep:]:
        remove_file(old_ckpt)
        print(f'Delete ckpt: {os.path.basename(old_ckpt)}')


def remove_file(*fns):
    for f in fns:
        subprocess.check_call(f'rm -rf "{f}"', shell=True)


def plot_img(img):
    img = img.data.cpu().numpy()
    return np.clip(img, 0, 1)


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)
    if checkpoint is not None:
        state_dict = checkpoint["state_dict"]
        if len([k for k in state_dict.keys() if '.' in k]) > 0:
            state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
                          if k.startswith(f'{model_name}.')}
        else:
            state_dict = state_dict[model_name]
        if not strict:
            cur_model_state_dict = cur_model.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print("| Unmatched keys: ", key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        cur_model.load_state_dict(state_dict, strict=strict)
        print(f"| load '{model_name}' from '{ckpt_path}'.")
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            assert False, e_msg
        else:
            print(e_msg)


class Measure:
    def __init__(self, net='alex'):
        self.model = lpips.LPIPS(net=net)

    def measure(self, imgA, imgB, img_lr, sr_scale):
        """

        Args:
            imgA: [C, H, W] uint8 or torch.FloatTensor [-1,1]
            imgB: [C, H, W] uint8 or torch.FloatTensor [-1,1]
            img_lr: [C, H, W] uint8 or torch.FloatTensor [-1,1]
            sr_scale:

        Returns: dict of metrics

        """
        if isinstance(imgA, torch.Tensor):
            imgA = np.round((imgA.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
            imgB = np.round((imgB.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
            img_lr = np.round((img_lr.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
        imgA = imgA.transpose(1, 2, 0)
        imgA_lr = imresize(imgA, 1 / sr_scale)
        imgB = imgB.transpose(1, 2, 0)
        img_lr = img_lr.transpose(1, 2, 0)
        psnr = self.psnr(imgA, imgB)
        ssim = self.ssim(imgA, imgB)
        lpips = self.lpips(imgA, imgB)
        lr_psnr = self.psnr(imgA_lr, img_lr)
        res = {'psnr': psnr, 'ssim': ssim, 'lpips': lpips, 'lr_psnr': lr_psnr}
        return {k: float(v) for k, v in res.items()}

    def lpips(self, imgA, imgB, model=None):
        device = next(self.model.parameters()).device
        tA = t(imgA).to(device)
        tB = t(imgB).to(device)
        dist01 = self.model.forward(tA, tB).item()
        return dist01

    def ssim(self, imgA, imgB):
        score, diff = ssim(imgA, imgB, full=True, channel_axis=2, data_range=255)
        return score

    def psnr(self, imgA, imgB):
        return psnr(imgA, imgB, data_range=255)


def t(img):
    def to_4d(img):
        assert len(img.shape) == 3
        img_new = np.expand_dims(img, axis=0)
        assert len(img_new.shape) == 4
        return img_new

    def to_CHW(img):
        return np.transpose(img, [2, 0, 1])

    def to_tensor(img):
        return torch.Tensor(img)

    return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
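
Illustrative usage (not part of the file; the arrays are hypothetical): `Measure` takes CHW uint8 arrays (or [-1, 1] tensors) plus the LR input and scale, and returns PSNR, SSIM, LPIPS, and the consistency PSNR of the re-downscaled SR against the LR input.

import numpy as np

m = Measure(net='alex')                                      # downloads LPIPS weights on first use
sr = (np.random.rand(3, 128, 128) * 255).astype(np.uint8)    # stand-in SR output, CHW
hr = (np.random.rand(3, 128, 128) * 255).astype(np.uint8)    # stand-in ground truth, CHW
lr = (np.random.rand(3, 32, 32) * 255).astype(np.uint8)      # stand-in LR input, CHW
print(m.measure(sr, hr, lr, sr_scale=4))                     # {'psnr': ..., 'ssim': ..., 'lpips': ..., 'lr_psnr': ...}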
sam_diffsr/weight/model_ckpt_steps_400000.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab89ee4160be868422459918eb69880042dc12544b1bf7807aa479c7eb329e55
size 204945145