added code
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitignore +153 -0
- config/diffsketchedit.yaml +75 -0
- docs/figures/refine/ldm_generated_image0.png +3 -0
- docs/figures/refine/ldm_generated_image1.png +3 -0
- docs/figures/refine/ldm_generated_image2.png +3 -0
- docs/figures/refine/visual_best-rendered0.png +3 -0
- docs/figures/refine/visual_best-rendered1.png +3 -0
- docs/figures/refine/visual_best-rendered2.png +3 -0
- docs/figures/replace/ldm_generated_image0.png +3 -0
- docs/figures/replace/ldm_generated_image1.png +3 -0
- docs/figures/replace/ldm_generated_image2.png +3 -0
- docs/figures/replace/ldm_generated_image3.png +3 -0
- docs/figures/replace/visual_best-rendered0.png +3 -0
- docs/figures/replace/visual_best-rendered1.png +3 -0
- docs/figures/replace/visual_best-rendered2.png +3 -0
- docs/figures/replace/visual_best-rendered3.png +3 -0
- docs/figures/reweight/ldm_generated_image0.png +3 -0
- docs/figures/reweight/ldm_generated_image1.png +3 -0
- docs/figures/reweight/ldm_generated_image2.png +3 -0
- docs/figures/reweight/visual_best-rendered0.png +3 -0
- docs/figures/reweight/visual_best-rendered1.png +3 -0
- docs/figures/reweight/visual_best-rendered2.png +3 -0
- libs/__init__.py +9 -0
- libs/engine/__init__.py +7 -0
- libs/engine/config_processor.py +151 -0
- libs/engine/model_state.py +335 -0
- libs/metric/__init__.py +1 -0
- libs/metric/accuracy.py +25 -0
- libs/metric/clip_score/__init__.py +3 -0
- libs/metric/clip_score/openaiCLIP_loss.py +304 -0
- libs/metric/lpips_origin/__init__.py +3 -0
- libs/metric/lpips_origin/lpips.py +184 -0
- libs/metric/lpips_origin/pretrained_networks.py +196 -0
- libs/metric/lpips_origin/weights/v0.1/alex.pth +0 -0
- libs/metric/lpips_origin/weights/v0.1/squeeze.pth +0 -0
- libs/metric/lpips_origin/weights/v0.1/vgg.pth +0 -0
- libs/metric/piq/__init__.py +2 -0
- libs/metric/piq/functional/__init__.py +15 -0
- libs/metric/piq/functional/base.py +111 -0
- libs/metric/piq/functional/colour_conversion.py +136 -0
- libs/metric/piq/functional/filters.py +111 -0
- libs/metric/piq/functional/layers.py +33 -0
- libs/metric/piq/functional/resize.py +426 -0
- libs/metric/piq/perceptual.py +496 -0
- libs/metric/piq/utils/__init__.py +7 -0
- libs/metric/piq/utils/common.py +158 -0
- libs/metric/pytorch_fid/__init__.py +54 -0
- libs/metric/pytorch_fid/fid_score.py +322 -0
- libs/metric/pytorch_fid/inception.py +341 -0
- libs/modules/__init__.py +1 -0
.gitignore
ADDED
@@ -0,0 +1,153 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# .idea
.idea/
/idea/
*.ipr
*.iml
*.iws

# system
.DS_Store

# pytorch-lighting logs
lightning_logs/*

# Edit settings
.editorconfig

# local results
/workdir/
.workdir/

# dataset
/dataset/
!/dataset/placeholder.md
config/diffsketchedit.yaml
ADDED
@@ -0,0 +1,75 @@
seed: 1
image_size: 224
mask_object: False # if the target image contains background, it's better to mask it out
fix_scale: False # if the target image is not squared, it is recommended to fix the scale

# train
num_iter: 1000
batch_size: 1
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
lr_scheduler: False
lr_decay_rate: 0.1
decay_steps: [ 1000, 1500 ]
lr: 1
color_lr: 0.01
pruning_freq: 50
color_vars_threshold: 0.1
width_lr: 0.1
max_width: 50 # stroke width

# stroke attrs
num_paths: 96 # number of strokes
width: 1.0 # stroke width
control_points_per_seg: 4
num_segments: 1
optim_opacity: True # if True, the stroke opacity is optimized
optim_width: False # if True, the stroke width is optimized
optim_rgba: False # if True, the stroke RGBA is optimized
opacity_delta: 0 # stroke pruning

# init strokes
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
xdog_intersec: True # initialize along the edge, mix XDoG and attn up
softmax_temp: 0.5
cross_attn_res: 16
self_attn_res: 32
max_com: 20 # select the number of the self-attn maps
mean_comp: False # the average of the self-attn maps
comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
log_cross_attn: False # True if cross attn every step
u2net_path: "./checkpoint/u2net/u2net.pth"

# ldm
model_id: "sd14"
ldm_speed_up: False
enable_xformers: False
gradient_checkpoint: False
#token_ind: 1 # the index of CLIP prompt embedding, start from 1
use_ddim: True
num_inference_steps: 50
guidance_scale: 7.5 # sdxl default 5.0
# ASDS loss
sds:
  crop_size: 512
  augmentations: "affine"
  guidance_scale: 100
  grad_scale: 1e-5
  t_range: [ 0.05, 0.95 ]
  warmup: 0

clip:
  model_name: "RN101" # RN101, ViT-L/14
  feats_loss_type: "l2" # clip visual loss type, conv layers
  feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
  # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
  fc_loss_weight: 0.1 # clip visual loss, fc layer weight
  augmentations: "affine" # augmentation before clip visual computation
  num_aug: 4 # num of augmentation before clip visual computation
  vis_loss: 1 # 1 or 0 for use or disable clip visual loss
  text_visual_coeff: 0 # cosine similarity between text and img

perceptual:
  name: "lpips" # dists
  lpips_net: 'vgg'
  coeff: 0.2
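The config above is consumed through OmegaConf in libs/engine/config_processor.py, added later in this diff. As a minimal, illustrative sketch (the override values are arbitrary examples, not recommended settings), loading this file and merging a dotlist override looks like:

# Illustrative only: load the YAML and override fields from a dotlist,
# mirroring what libs/engine/config_processor.py does further down.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/diffsketchedit.yaml")
override = OmegaConf.from_dotlist(["num_paths=128", "sds.guidance_scale=50"])
cfg = OmegaConf.merge(cfg, override)          # later values win
print(cfg.num_paths, cfg.sds.guidance_scale)  # 128 50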
docs/figures/refine/ldm_generated_image0.png
ADDED (image stored with Git LFS)
docs/figures/refine/ldm_generated_image1.png
ADDED (image stored with Git LFS)
docs/figures/refine/ldm_generated_image2.png
ADDED (image stored with Git LFS)
docs/figures/refine/visual_best-rendered0.png
ADDED (image stored with Git LFS)
docs/figures/refine/visual_best-rendered1.png
ADDED (image stored with Git LFS)
docs/figures/refine/visual_best-rendered2.png
ADDED (image stored with Git LFS)
docs/figures/replace/ldm_generated_image0.png
ADDED (image stored with Git LFS)
docs/figures/replace/ldm_generated_image1.png
ADDED (image stored with Git LFS)
docs/figures/replace/ldm_generated_image2.png
ADDED (image stored with Git LFS)
docs/figures/replace/ldm_generated_image3.png
ADDED (image stored with Git LFS)
docs/figures/replace/visual_best-rendered0.png
ADDED (image stored with Git LFS)
docs/figures/replace/visual_best-rendered1.png
ADDED (image stored with Git LFS)
docs/figures/replace/visual_best-rendered2.png
ADDED (image stored with Git LFS)
docs/figures/replace/visual_best-rendered3.png
ADDED (image stored with Git LFS)
docs/figures/reweight/ldm_generated_image0.png
ADDED (image stored with Git LFS)
docs/figures/reweight/ldm_generated_image1.png
ADDED (image stored with Git LFS)
docs/figures/reweight/ldm_generated_image2.png
ADDED (image stored with Git LFS)
docs/figures/reweight/visual_best-rendered0.png
ADDED (image stored with Git LFS)
docs/figures/reweight/visual_best-rendered1.png
ADDED (image stored with Git LFS)
docs/figures/reweight/visual_best-rendered2.png
ADDED (image stored with Git LFS)
libs/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .utils import lazy

__getattr__, __dir__, __all__ = lazy.attach(
    __name__,
    submodules={'engine', 'metric', 'modules', 'solver', 'utils'},
    submod_attrs={}
)

__version__ = '0.0.1'
libs/engine/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .model_state import ModelState
from .config_processor import merge_and_update_config

__all__ = [
    'ModelState',
    'merge_and_update_config'
]
libs/engine/config_processor.py
ADDED
@@ -0,0 +1,151 @@
import os
from typing import Tuple
from functools import reduce

from argparse import Namespace
from omegaconf import DictConfig, OmegaConf


#################################################################################
#                           merge yaml and argparse                             #
#################################################################################

def register_resolver():
    OmegaConf.register_new_resolver(
        "add", lambda *numbers: sum(numbers)
    )
    OmegaConf.register_new_resolver(
        "multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers)
    )
    OmegaConf.register_new_resolver(
        "sub", lambda n1, n2: n1 - n2
    )


def _merge_args_and_config(
        cmd_args: Namespace,
        yaml_config: DictConfig,
        read_only: bool = False
) -> Tuple[DictConfig, DictConfig, DictConfig]:
    # convert cmd line args to OmegaConf
    cmd_args_dict = vars(cmd_args)
    cmd_args_list = []
    for k, v in cmd_args_dict.items():
        cmd_args_list.append(f"{k}={v}")
    cmd_args_conf = OmegaConf.from_cli(cmd_args_list)

    # The following overrides the previous configuration
    # cmd_args_list > configs
    args_ = OmegaConf.merge(yaml_config, cmd_args_conf)

    if read_only:
        OmegaConf.set_readonly(args_, True)

    return args_, cmd_args_conf, yaml_config


def merge_configs(args, method_cfg_path):
    """merge command line args (argparse) and config file (OmegaConf)"""
    yaml_config_path = os.path.join("./", "config", method_cfg_path)
    try:
        yaml_config = OmegaConf.load(yaml_config_path)
    except FileNotFoundError as e:
        print(f"error: {e}")
        print(f"input file path: `{method_cfg_path}`")
        print(f"config path: `{yaml_config_path}` not found.")
        raise FileNotFoundError(e)
    return _merge_args_and_config(args, yaml_config, read_only=False)


def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
    """update config file (OmegaConf) with dotlist"""
    if update_nodes is None:
        return source_args

    update_args_list = str(update_nodes).split()
    if len(update_args_list) < 1:
        return source_args

    # check update_args
    for item in update_args_list:
        item_key_ = str(item).split('=')[0]  # get key
        # item_val_ = str(item).split('=')[1]  # get value

        if strict:
            # Tests if a key is existing
            # assert OmegaConf.select(source_args, item_key_) is not None, f"{item_key_} is not existing."

            # Tests if a value is missing
            assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."

        # if keys is None, then add key and set the value
        if OmegaConf.select(source_args, item_key_) is None:
            source_args.item_key_ = item_key_

    # update original yaml params
    update_nodes = OmegaConf.from_dotlist(update_args_list)
    merged_args = OmegaConf.merge(source_args, update_nodes)

    # remove update_args
    if remove_update_nodes:
        OmegaConf.update(merged_args, 'update', '')
    return merged_args


def update_if_exist(source_args, update_nodes):
    """update config file (OmegaConf) with dotlist"""
    if update_nodes is None:
        return source_args

    upd_args_list = str(update_nodes).split()
    if len(upd_args_list) < 1:
        return source_args

    update_args_list = []
    for item in upd_args_list:
        item_key_ = str(item).split('=')[0]  # get key

        # if a key is existing
        # if OmegaConf.select(source_args, item_key_) is not None:
        #     update_args_list.append(item)

        update_args_list.append(item)

    # update source_args if key be selected
    if len(update_args_list) < 1:
        merged_args = source_args
    else:
        update_nodes = OmegaConf.from_dotlist(update_args_list)
        merged_args = OmegaConf.merge(source_args, update_nodes)

    return merged_args


def merge_and_update_config(args):
    register_resolver()

    # if yaml_config is existing, then merge command line args and yaml_config
    # if os.path.isfile(args.config) and args.config is not None:
    if args.config is not None and str(args.config).endswith('.yaml'):
        merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
    else:
        merged_args, cmd_args, yaml_config = args, args, None

    # update the yaml_config with the cmd '-update' flag
    update_nodes = args.update
    final_args = update_configs(merged_args, update_nodes)

    # to simplify log output, we empty this
    yaml_config_update = update_if_exist(yaml_config, update_nodes)
    cmd_args_update = update_if_exist(cmd_args, update_nodes)
    cmd_args_update.update = ""  # clear update params

    final_args.yaml_config = yaml_config_update
    final_args.cmd_args = cmd_args_update

    # update seed
    if final_args.seed < 0:
        import random
        final_args.seed = random.randint(0, 65535)

    return final_args
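A hedged usage sketch for merge_and_update_config. The driver script and its flags below are hypothetical (not part of this commit), but the function only needs an argparse Namespace carrying at least `config`, `update`, and `seed`:

# Hypothetical driver script: CLI args are merged over config/diffsketchedit.yaml,
# then an optional space-separated dotlist in -update is applied on top.
import argparse
from libs.engine import merge_and_update_config

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", default="diffsketchedit.yaml")  # resolved under ./config/
parser.add_argument("-update", type=str, default=None,
                    help="space-separated dotlist, e.g. 'num_paths=128 lr=0.5'")
parser.add_argument("--seed", type=int, default=-1)  # <0 draws a random seed
args = parser.parse_args()

cfg = merge_and_update_config(args)
print(cfg.seed, cfg.num_paths)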
libs/engine/model_state.py
ADDED
@@ -0,0 +1,335 @@
import os
from functools import partial
from typing import Union, List
from pathlib import Path
from datetime import datetime, timedelta

from omegaconf import DictConfig
from pprint import pprint
import torch
from accelerate.utils import LoggerType
from accelerate import (
    Accelerator,
    GradScalerKwargs,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs
)

from ..modules.ema import EMA
from ..utils.logging import get_logger


class ModelState:
    """
    Handling logger and `hugging face` accelerate training

    features:
        - Mixed Precision
        - Gradient Scaler
        - Gradient Accumulation
        - Optimizer
        - EMA
        - Logger (default: python print)
        - Monitor (default: wandb, tensorboard)
    """

    def __init__(
            self,
            args,
            log_path_suffix: str = None,
            ignore_log=False,  # whether to create log file or not
    ) -> None:
        self.args: DictConfig = args

        """check valid"""
        mixed_precision = self.args.get("mixed_precision")
        # Bug: omegaconf convert 'no' to false
        mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision
        split_batches = self.args.get("split_batches", False)
        gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1)
        assert gradient_accumulate_step >= 1, f"except gradient_accumulate_step >= 1, get {gradient_accumulate_step}"

        """create working space"""
        # rule: ['./config'. 'method_name', 'exp_name.yaml']
        # -> results_path: ./runs/{method_name}-{exp_name}, as a base folder
        # config_prefix, config_name = str(self.args.get("config")).split('/')
        # config_name_only = str(config_name).split(".")[0]

        config_name_only = str(self.args.get("config")).split(".")[0]
        results_folder = self.args.get("results_path", None)
        if results_folder is None:
            # self.results_path = Path("./workdir") / f"{config_prefix}-{config_name_only}"
            self.results_path = Path("./workdir")
        else:
            # self.results_path = Path(results_folder) / f"{config_prefix}-{config_name_only}"
            self.results_path = Path(os.path.join(results_folder, self.args.get("edit_type"), ))

        # update results_path: ./runs/{method_name}-{exp_name}/{log_path_suffix}
        # noting: can be understood as "results dir / methods / ablation study / your result"
        if log_path_suffix is not None:
            self.results_path = self.results_path / log_path_suffix

        kwargs_handlers = []
        """mixed precision training"""
        if args.mixed_precision == "no":
            scaler_handler = GradScalerKwargs(
                init_scale=args.init_scale,
                growth_factor=args.growth_factor,
                backoff_factor=args.backoff_factor,
                growth_interval=args.growth_interval,
                enabled=True
            )
            kwargs_handlers.append(scaler_handler)

        """distributed training"""
        ddp_handler = DistributedDataParallelKwargs(
            dim=0,
            broadcast_buffers=True,
            static_graph=False,
            bucket_cap_mb=25,
            find_unused_parameters=False,
            check_reduction=False,
            gradient_as_bucket_view=False
        )
        kwargs_handlers.append(ddp_handler)

        init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200))
        kwargs_handlers.append(init_handler)

        """init visualized tracker"""
        log_with = []
        self.args.visual = False
        if args.use_wandb:
            log_with.append(LoggerType.WANDB)
        if args.tensorboard:
            log_with.append(LoggerType.TENSORBOARD)

        """hugging face Accelerator"""
        self.accelerator = Accelerator(
            device_placement=True,
            split_batches=split_batches,
            mixed_precision=mixed_precision,
            gradient_accumulation_steps=args.gradient_accumulate_step,
            cpu=True if args.use_cpu else False,
            log_with=None if len(log_with) == 0 else log_with,
            project_dir=self.results_path / "vis",
            kwargs_handlers=kwargs_handlers,
        )

        """logs"""
        if self.accelerator.is_local_main_process:
            # for logging results in a folder periodically
            self.results_path.mkdir(parents=True, exist_ok=True)
            if not ignore_log:
                now_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
                # self.logger = get_logger(
                #     logs_dir=self.results_path.as_posix(),
                #     file_name=f"log.txt"
                # )

            print("==> command line args: ")
            print(args.cmd_args)
            print("==> yaml config args: ")
            print(args.yaml_config)

            print("\n***** Model State *****")
            if self.accelerator.distributed_type != "NO":
                print(f"-> Distributed Type: {self.accelerator.distributed_type}")
            print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}")
            print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp},"
                  f" Gradient Accumulate Step: {gradient_accumulate_step}")
            print(f"-> Weight dtype: {self.weight_dtype}")

            if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
                print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")

            if args.use_wandb:
                print(f"-> Init trackers: 'wandb' ")
                self.args.visual = True
                self.__init_tracker(project_name="my_project", tags=None, entity="")

            print(f"-> Working Space: '{self.results_path}'")

        """EMA"""
        self.use_ema = args.get('ema', False)
        self.ema_wrapper = self.__build_ema_wrapper()

        """glob step"""
        self.step = 0

        """log process"""
        self.accelerator.wait_for_everyone()
        print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')

        self.print("-> state initialization complete \n")

    def __init_tracker(self, project_name, tags, entity):
        self.accelerator.init_trackers(
            project_name=project_name,
            config=dict(self.args),
            init_kwargs={
                "wandb": {
                    "notes": "accelerate trainer pipeline",
                    "tags": [
                        f"total batch_size: {self.actual_batch_size}"
                    ],
                    "entity": entity,
                }}
        )

    def __build_ema_wrapper(self):
        if self.use_ema:
            self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, "
                       f"update_after_step: {self.args.ema_update_after_step}, "
                       f"update_every: {self.args.ema_update_every}")
            ema_wrapper = partial(
                EMA, beta=self.args.ema_decay,
                update_after_step=self.args.ema_update_after_step,
                update_every=self.args.ema_update_every
            )
        else:
            ema_wrapper = None

        return ema_wrapper

    @property
    def device(self):
        return self.accelerator.device

    @property
    def weight_dtype(self):
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        return weight_dtype

    @property
    def actual_batch_size(self):
        if self.accelerator.split_batches is False:
            actual_batch_size = self.args.batch_size * self.accelerator.num_processes * self.accelerator.gradient_accumulation_steps
        else:
            assert self.actual_batch_size % self.accelerator.num_processes == 0
            actual_batch_size = self.args.batch_size
        return actual_batch_size

    @property
    def n_gpus(self):
        return self.accelerator.num_processes

    @property
    def no_decay_params_names(self):
        no_decay = [
            "bn", "LayerNorm", "GroupNorm",
        ]
        return no_decay

    def no_decay_params(self, model, weight_decay):
        """optimization tricks"""
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": 0.0,
            },
        ]
        return optimizer_grouped_parameters

    def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
        """return parameters if `requires_grad` is True

        Args:
            model: pytorch models
            verbose: log optimized parameters

        Examples:
            >>> self.params_optimized = self.optimized_params(uvit, verbose=True)
            >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr)

        Returns:
            a list of parameters
        """
        params_optimized = []
        for key, value in model.named_parameters():
            if value.requires_grad:
                params_optimized.append(value)
                if verbose:
                    self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
        return params_optimized

    def save_everything(self, fpath: str):
        """Saving and loading the model, optimizer, RNG generators, and the GradScaler."""
        if not self.accelerator.is_main_process:
            return
        self.accelerator.save_state(fpath)

    def load_save_everything(self, fpath: str):
        """Loading the model, optimizer, RNG generators, and the GradScaler."""
        self.accelerator.load_state(fpath)

    def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, self.results_path / f'model-{milestone}.pt')

    def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, root)

    def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
        ckpt = torch.load(path, map_location=self.accelerator.device)

        unwrapped_model = self.accelerator.unwrap_model(model)
        if rm_module_prefix:
            unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
        else:
            unwrapped_model.load_state_dict(ckpt)
        return unwrapped_model

    def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
        ckpt = torch.load(path, map_location=self.accelerator.device)
        self.print(f"pretrained_dict len: {len(ckpt)}")
        unwrapped_model = self.accelerator.unwrap_model(model)
        model_dict = unwrapped_model.state_dict()
        pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        unwrapped_model.load_state_dict(model_dict, strict=False)
        self.print(f"selected pretrained_dict: {len(model_dict)}")
        return unwrapped_model

    def print(self, *args, **kwargs):
        """Use in replacement of `print()` to only print once per server."""
        self.accelerator.print(*args, **kwargs)

    def pretty_print(self, msg):
        if self.accelerator.is_local_main_process:
            pprint(dict(msg))

    def close_tracker(self):
        self.accelerator.end_training()

    def free_memory(self):
        self.accelerator.clear()

    def close(self, msg: str = "Training complete."):
        """Use in end of training."""
        self.free_memory()

        if torch.cuda.is_available():
            self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
        if self.args.visual:
            self.close_tracker()
        self.print(msg)
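A rough, hypothetical sketch of how a pipeline might drive ModelState. It assumes `cfg` is the DictConfig produced by merge_and_update_config above and that it carries every field the constructor reads (mixed_precision, use_wandb, tensorboard, use_cpu, batch_size, gradient_accumulate_step, ema, edit_type, ...); none of the names below beyond ModelState's own API come from this commit:

# Hypothetical sketch: construct the state, prepare a toy model with Accelerate,
# run one optimization step, then checkpoint and shut down.
import torch
from libs.engine import ModelState, merge_and_update_config

cfg = merge_and_update_config(args)  # `args`: the argparse Namespace from the previous sketch
state = ModelState(cfg, log_path_suffix="demo")

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(state.optimized_params(model), lr=1e-3)
model, optimizer = state.accelerator.prepare(model, optimizer)

x = torch.randn(8, 16, device=state.device)
loss = model(x).pow(2).mean()
state.accelerator.backward(loss)   # respects mixed precision / grad accumulation
optimizer.step()

state.save("final", {"model": state.accelerator.unwrap_model(model).state_dict()})
state.close()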
libs/metric/__init__.py
ADDED
@@ -0,0 +1 @@
libs/metric/accuracy.py
ADDED
@@ -0,0 +1,25 @@
def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k.

    Args
        output: logits or probs (num of batch, num of classes)
        target: (num of batch, 1) or (num of batch, )
        topk: list of returned k

    refer: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    maxK = max(topk)  # get k in top-k
    batch_size = target.size(0)

    _, pred = output.topk(k=maxK, dim=1, largest=True, sorted=True)  # pred: [num of batch, k]
    pred = pred.t()  # pred: [k, num of batch]

    # [1, num of batch] -> [k, num_of_batch] : bool
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res  # np.shape(res): [k, 1]
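A quick illustration of accuracy() on random logits (illustrative only, not part of the commit):

import torch
from libs.metric.accuracy import accuracy

logits = torch.randn(32, 10)          # (batch, num_classes)
labels = torch.randint(0, 10, (32,))  # (batch,)
top1, top5 = accuracy(logits, labels, topk=(1, 5))
print(f"top-1: {top1.item():.2f}%, top-5: {top5.item():.2f}%")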
libs/metric/clip_score/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .openaiCLIP_loss import CLIPScoreWrapper

__all__ = ['CLIPScoreWrapper']
libs/metric/clip_score/openaiCLIP_loss.py
ADDED
@@ -0,0 +1,304 @@
from typing import Union, List, Tuple
from collections import OrderedDict
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms


class CLIPScoreWrapper(nn.Module):

    def __init__(self,
                 clip_model_name: str,
                 download_root: str = None,
                 device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
                 jit: bool = False,
                 # additional params
                 visual_score: bool = False,
                 feats_loss_type: str = None,
                 feats_loss_weights: List[float] = None,
                 fc_loss_weight: float = None,
                 context_length: int = 77):
        super().__init__()

        import clip  # local import

        # check model info
        self.clip_model_name = clip_model_name
        self.device = device
        self.available_models = clip.available_models()
        assert clip_model_name in self.available_models, f"A model backbone: {clip_model_name} that does not exist"

        # load CLIP
        self.model, self.preprocess = clip.load(clip_model_name, device, jit=jit, download_root=download_root)
        self.model.eval()

        # load tokenize
        self.tokenize_fn = partial(clip.tokenize, context_length=context_length)

        # load CLIP visual
        self.visual_encoder = VisualEncoderWrapper(self.model, clip_model_name).to(device)
        self.visual_encoder.eval()

        # check loss weights
        self.visual_score = visual_score
        if visual_score:
            assert feats_loss_type in ["l1", "l2", "cosine"], f"{feats_loss_type} is not exist."
            if clip_model_name.startswith("ViT"): assert len(feats_loss_weights) == 12
            if clip_model_name.startswith("RN"): assert len(feats_loss_weights) == 5

            # load visual loss wrapper
            self.visual_loss_fn = CLIPVisualLossWrapper(self.visual_encoder, feats_loss_type,
                                                        feats_loss_weights,
                                                        fc_loss_weight)

    @property
    def input_resolution(self):
        return self.model.visual.input_resolution  # default: 224

    @property
    def resize(self):  # Resize only
        return transforms.Compose([self.preprocess.transforms[0]])

    @property
    def normalize(self):
        return transforms.Compose([
            self.preprocess.transforms[0],   # Resize
            self.preprocess.transforms[1],   # CenterCrop
            self.preprocess.transforms[-1],  # Normalize
        ])

    @property
    def norm_(self):  # Normalize only
        return transforms.Compose([self.preprocess.transforms[-1]])

    def encode_image_layer_wise(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        semantic_vec, feature_maps = self.visual_encoder(x)
        return semantic_vec, feature_maps

    def encode_text(self, text: Union[str, List[str]], norm: bool = True) -> torch.Tensor:
        tokens = self.tokenize_fn(text).to(self.device)
        text_features = self.model.encode_text(tokens)
        if norm:
            text_features = text_features.mean(axis=0, keepdim=True)
            text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features_norm
        return text_features

    def encode_image(self, image: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.model.encode_image(image)
        if norm:
            image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
            return image_features_norm
        return image_features

    @torch.no_grad()
    def predict(self,
                image: torch.Tensor,
                text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
        image_features = self.model.encode_image(image)
        text_tokenize = self.tokenize_fn(text).to(self.device)
        text_features = self.model.encode_text(text_tokenize)
        logits_per_image, logits_per_text = self.model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
        return image_features, text_features, probs

    def compute_text_visual_distance(
            self, image: torch.Tensor, text: Union[str, List[str]]
    ) -> torch.Tensor:
        image_features = self.model.encode_image(image)
        text_tokenize = self.tokenize_fn(text).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokenize)

        image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
        # loss = - (image_features_norm @ text_features_norm.T)
        loss = 1 - torch.cosine_similarity(image_features_norm, text_features_norm, dim=1)
        return loss.mean()

    def directional_text_visual_distance(self, src_text, src_img, tar_text, tar_img):
        src_image_features = self.model.encode_image(src_img).detach()
        tar_image_features = self.model.encode_image(tar_img)
        src_text_tokenize = self.tokenize_fn(src_text).to(self.device)
        tar_text_tokenize = self.tokenize_fn(tar_text).to(self.device)
        with torch.no_grad():
            src_text_features = self.model.encode_text(src_text_tokenize)
            tar_text_features = self.model.encode_text(tar_text_tokenize)

        delta_image_features = tar_image_features - src_image_features
        delta_text_features = tar_text_features - src_text_features

        # # avold zero divisor
        # delta_image_features_norm = delta_image_features / delta_image_features.norm(dim=-1, keepdim=True)
        # delta_text_features_norm = delta_text_features / delta_text_features.norm(dim=-1, keepdim=True)

        loss = 1 - torch.cosine_similarity(delta_image_features, delta_text_features, dim=1, eps=1e-3)
        return loss.mean()

    def compute_visual_distance(
            self, x: torch.Tensor, y: torch.Tensor, clip_norm: bool = True,
    ) -> Tuple[torch.Tensor, List]:
        # return a fc loss and the list of feat loss
        assert self.visual_score is True
        assert x.shape == y.shape
        assert x.shape[-1] == self.input_resolution and x.shape[-2] == self.input_resolution
        assert y.shape[-1] == self.input_resolution and y.shape[-2] == self.input_resolution

        if clip_norm:
            return self.visual_loss_fn(self.normalize(x), self.normalize(y))
        else:
            return self.visual_loss_fn(x, y)


class VisualEncoderWrapper(nn.Module):
    """
    semantic features and layer by layer feature maps are obtained from CLIP visual encoder.
    """

    def __init__(self, clip_model: nn.Module, clip_model_name: str):
        super().__init__()
        self.clip_model = clip_model
        self.clip_model_name = clip_model_name

        if clip_model_name.startswith("ViT"):
            self.feature_maps = OrderedDict()
            for i in range(12):  # 12 ResBlocks in ViT visual transformer
                self.clip_model.visual.transformer.resblocks[i].register_forward_hook(
                    self.make_hook(i)
                )

        if clip_model_name.startswith("RN"):
            layers = list(self.clip_model.visual.children())
            init_layers = torch.nn.Sequential(*layers)[:8]
            self.layer1 = layers[8]
            self.layer2 = layers[9]
            self.layer3 = layers[10]
            self.layer4 = layers[11]
            self.att_pool2d = layers[12]

    def make_hook(self, name):
        def hook(module, input, output):
            if len(output.shape) == 3:
                # LND -> NLD (B, 77, 768)
                self.feature_maps[name] = output.permute(1, 0, 2)
            else:
                self.feature_maps[name] = output

        return hook

    def _forward_vit(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
        fc_feature = self.clip_model.encode_image(x).float()
        feature_maps = [self.feature_maps[k] for k in range(12)]

        # fc_feature len: 1 ,feature_maps len: 12
        return fc_feature, feature_maps

    def _forward_resnet(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
        def stem(m, x):
            for conv, bn, relu in [(m.conv1, m.bn1, m.relu1), (m.conv2, m.bn2, m.relu2), (m.conv3, m.bn3, m.relu3)]:
                x = torch.relu(bn(conv(x)))
            x = m.avgpool(x)
            return x

        x = x.type(self.clip_model.visual.conv1.weight.dtype)
        x = stem(self.clip_model.visual, x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        y = self.att_pool2d(x4)

        # fc_features len: 1 ,feature_maps len: 5
        return y, [x, x1, x2, x3, x4]

    def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        if self.clip_model_name.startswith("ViT"):
            fc_feat, visual_feat_maps = self._forward_vit(x)
        if self.clip_model_name.startswith("RN"):
            fc_feat, visual_feat_maps = self._forward_resnet(x)

        return fc_feat, visual_feat_maps


class CLIPVisualLossWrapper(nn.Module):
    """
    Visual Feature Loss + FC loss
    """

    def __init__(
            self,
            visual_encoder: nn.Module,
            feats_loss_type: str = None,
            feats_loss_weights: List[float] = None,
            fc_loss_weight: float = None,
    ):
        super().__init__()
        self.visual_encoder = visual_encoder
        self.feats_loss_weights = feats_loss_weights
        self.fc_loss_weight = fc_loss_weight

        self.layer_criterion = layer_wise_distance(feats_loss_type)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x_fc_feature, x_feat_maps = self.visual_encoder(x)
        y_fc_feature, y_feat_maps = self.visual_encoder(y)

        # visual feature loss
        if sum(self.feats_loss_weights) == 0:
            feats_loss_list = [torch.tensor(0, device=x.device)]
        else:
            feats_loss = self.layer_criterion(x_feat_maps, y_feat_maps, self.visual_encoder.clip_model_name)
            feats_loss_list = []
            for layer, w in enumerate(self.feats_loss_weights):
                if w:
                    feats_loss_list.append(feats_loss[layer] * w)

        # visual fc loss, default: cosine similarity
        if self.fc_loss_weight == 0:
            fc_loss = torch.tensor(0, device=x.device)
        else:
            fc_loss = (1 - torch.cosine_similarity(x_fc_feature, y_fc_feature, dim=1)).mean()
            fc_loss = fc_loss * self.fc_loss_weight

        return fc_loss, feats_loss_list


#################################################################################
#                              layer wise metric                                #
#################################################################################

def layer_wise_distance(metric_name: str):
    return {
        "l1": l1_layer_wise,
        "l2": l2_layer_wise,
        "cosine": cosine_layer_wise
    }.get(metric_name.lower())


def l2_layer_wise(x_features, y_features, clip_model_name):
    return [
        torch.square(x_conv - y_conv).mean()
        for x_conv, y_conv in zip(x_features, y_features)
    ]


def l1_layer_wise(x_features, y_features, clip_model_name):
    return [
        torch.abs(x_conv - y_conv).mean()
        for x_conv, y_conv in zip(x_features, y_features)
    ]


def cosine_layer_wise(x_features, y_features, clip_model_name):
    if clip_model_name.startswith("RN"):
        return [
            (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
            for x_conv, y_conv in zip(x_features, y_features)
        ]
    return [
        (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
        for x_conv, y_conv in zip(x_features, y_features)
    ]
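A hedged usage sketch for CLIPScoreWrapper, mirroring the `clip:` block of config/diffsketchedit.yaml. It assumes the OpenAI `clip` package is installed and the RN101 checkpoint can be downloaded; the prompt and the random tensors are placeholders for a rendered sketch and a diffusion sample:

import torch
from libs.metric.clip_score import CLIPScoreWrapper

clip_wrapper = CLIPScoreWrapper(
    clip_model_name="RN101",
    visual_score=True,
    feats_loss_type="l2",
    feats_loss_weights=[0, 0, 1.0, 1.0, 0],
    fc_loss_weight=0.1,
)

device = clip_wrapper.device
sketch = torch.rand(1, 3, 224, 224, device=device)  # rendered sketch in [0, 1]
target = torch.rand(1, 3, 224, 224, device=device)  # diffusion sample in [0, 1]

# layer-wise conv-feature losses plus the fc (semantic) loss
fc_loss, feat_losses = clip_wrapper.compute_visual_distance(sketch, target)
# cosine distance between the sketch and a text prompt
text_loss = clip_wrapper.compute_text_visual_distance(
    clip_wrapper.normalize(sketch), "a sketch of a cat")
print(fc_loss.item(), sum(feat_losses).item(), text_loss.item())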
libs/metric/lpips_origin/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .lpips import LPIPS

__all__ = ['LPIPS']
libs/metric/lpips_origin/lpips.py
ADDED
@@ -0,0 +1,184 @@
from __future__ import absolute_import

import os

import torch
import torch.nn as nn

from . import pretrained_networks as pretrained_torch_models


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)


def upsample(x):
    return nn.Upsample(size=x.shape[2:], mode='bilinear', align_corners=False)(x)


def normalize_tensor(in_feat, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
    return in_feat / (norm_factor + eps)


# Learned perceptual metric
class LPIPS(nn.Module):

    def __init__(self,
                 pretrained=True,
                 net='alex',
                 version='0.1',
                 lpips=True,
                 spatial=False,
                 pnet_rand=False,
                 pnet_tune=False,
                 use_dropout=True,
                 model_path=None,
                 eval_mode=True,
                 verbose=True):
        """ Initializes a perceptual loss torch.nn.Module

        Parameters (default listed first)
        ---------------------------------
        lpips : bool
            [True] use linear layers on top of base/trunk network
            [False] means no linear layers; each layer is averaged together
        pretrained : bool
            This flag controls the linear layers, which are only in effect when lpips=True above
            [True] means linear layers are calibrated with human perceptual judgments
            [False] means linear layers are randomly initialized
        pnet_rand : bool
            [False] means trunk loaded with ImageNet classification weights
            [True] means randomly initialized trunk
        net : str
            ['alex','vgg','squeeze'] are the base/trunk networks available
        version : str
            ['v0.1'] is the default and latest
            ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
        model_path : 'str'
            [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1

        The following parameters should only be changed if training the network:

        eval_mode : bool
            [True] is for test mode (default)
            [False] is for training mode
        pnet_tune
            [False] keep base/trunk frozen
            [True] tune the base/trunk network
        use_dropout : bool
            [True] to use dropout when training linear layers
            [False] for no dropout when training linear layers
        """
        super(LPIPS, self).__init__()
        if verbose:
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
                  ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = pretrained_torch_models.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = pretrained_torch_models.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = pretrained_torch_models.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if self.pnet_type == 'squeeze':  # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins += [self.lin5, self.lin6]
            self.lins = nn.ModuleList(self.lins)

            if pretrained:
                if model_path is None:
                    model_path = os.path.join(
                        os.path.dirname(os.path.abspath(__file__)),
                        f"weights/v{version}/{net}.pth"
                    )
                if verbose:
                    print('Loading model from: %s' % model_path)
                self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, return_per_layer=False, normalize=False):
        if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, 1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # Noting: v0.0 - original release had a bug, where input was not scaled
        if self.version == '0.1':
            in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1))
        else:
            in0_input, in1_input = in0, in1

        # model forward
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        feats0, feats1, diffs = {}, {}, {}
        for kk in range(self.L):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        if self.lpips:
            if self.spatial:
                res = [upsample(self.lins[kk](diffs[kk])) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if self.spatial:
                res = [upsample(diffs[kk].sum(dim=1, keepdim=True)) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]

        loss = sum(res)

        if return_per_layer:
            return loss, res
        else:
            return loss


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
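A minimal sketch of using this LPIPS metric as shipped here, assuming the bundled weights/v0.1/vgg.pth load correctly and that pretrained_networks.vgg16 is available (its definition is cut off in this truncated view). Inputs are expected in [-1, 1]; pass normalize=True for [0, 1] tensors:

import torch
from libs.metric.lpips_origin import LPIPS

lpips_fn = LPIPS(net='vgg').eval()       # matches perceptual.lpips_net in the config
img0 = torch.rand(1, 3, 224, 224)        # values in [0, 1]
img1 = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    d = lpips_fn(img0, img1, normalize=True)  # spatially averaged distance, shape (1, 1, 1, 1)
print(d.item())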
libs/metric/lpips_origin/pretrained_networks.py
ADDED
@@ -0,0 +1,196 @@
1 |
+
from collections import namedtuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchvision.models as tv_models
|
5 |
+
|
6 |
+
|
7 |
+
class squeezenet(torch.nn.Module):
|
8 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
9 |
+
super(squeezenet, self).__init__()
|
10 |
+
weights = tv_models.SqueezeNet1_1_Weights.IMAGENET1K_V1 if pretrained else None
pretrained_features = tv_models.squeezenet1_1(weights=weights).features
|
11 |
+
self.slice1 = torch.nn.Sequential()
|
12 |
+
self.slice2 = torch.nn.Sequential()
|
13 |
+
self.slice3 = torch.nn.Sequential()
|
14 |
+
self.slice4 = torch.nn.Sequential()
|
15 |
+
self.slice5 = torch.nn.Sequential()
|
16 |
+
self.slice6 = torch.nn.Sequential()
|
17 |
+
self.slice7 = torch.nn.Sequential()
|
18 |
+
self.N_slices = 7
|
19 |
+
for x in range(2):
|
20 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
21 |
+
for x in range(2, 5):
|
22 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
23 |
+
for x in range(5, 8):
|
24 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
25 |
+
for x in range(8, 10):
|
26 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
27 |
+
for x in range(10, 11):
|
28 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
29 |
+
for x in range(11, 12):
|
30 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
31 |
+
for x in range(12, 13):
|
32 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
33 |
+
if not requires_grad:
|
34 |
+
for param in self.parameters():
|
35 |
+
param.requires_grad = False
|
36 |
+
|
37 |
+
def forward(self, X):
|
38 |
+
h = self.slice1(X)
|
39 |
+
h_relu1 = h
|
40 |
+
h = self.slice2(h)
|
41 |
+
h_relu2 = h
|
42 |
+
h = self.slice3(h)
|
43 |
+
h_relu3 = h
|
44 |
+
h = self.slice4(h)
|
45 |
+
h_relu4 = h
|
46 |
+
h = self.slice5(h)
|
47 |
+
h_relu5 = h
|
48 |
+
h = self.slice6(h)
|
49 |
+
h_relu6 = h
|
50 |
+
h = self.slice7(h)
|
51 |
+
h_relu7 = h
|
52 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
|
53 |
+
out = squeeze_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
|
54 |
+
|
55 |
+
return out
|
56 |
+
|
57 |
+
|
58 |
+
class alexnet(torch.nn.Module):
|
59 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
60 |
+
super(alexnet, self).__init__()
|
61 |
+
weights = tv_models.AlexNet_Weights.IMAGENET1K_V1 if pretrained else None
|
62 |
+
alexnet_pretrained_features = tv_models.alexnet(weights=weights).features
|
63 |
+
self.slice1 = torch.nn.Sequential()
|
64 |
+
self.slice2 = torch.nn.Sequential()
|
65 |
+
self.slice3 = torch.nn.Sequential()
|
66 |
+
self.slice4 = torch.nn.Sequential()
|
67 |
+
self.slice5 = torch.nn.Sequential()
|
68 |
+
self.N_slices = 5
|
69 |
+
for x in range(2):
|
70 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
71 |
+
for x in range(2, 5):
|
72 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
73 |
+
for x in range(5, 8):
|
74 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
75 |
+
for x in range(8, 10):
|
76 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
77 |
+
for x in range(10, 12):
|
78 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
79 |
+
|
80 |
+
if not requires_grad:
|
81 |
+
for param in self.parameters():
|
82 |
+
param.requires_grad = False
|
83 |
+
|
84 |
+
def forward(self, X):
|
85 |
+
h = self.slice1(X)
|
86 |
+
h_relu1 = h
|
87 |
+
h = self.slice2(h)
|
88 |
+
h_relu2 = h
|
89 |
+
h = self.slice3(h)
|
90 |
+
h_relu3 = h
|
91 |
+
h = self.slice4(h)
|
92 |
+
h_relu4 = h
|
93 |
+
h = self.slice5(h)
|
94 |
+
h_relu5 = h
|
95 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
96 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
97 |
+
|
98 |
+
return out
|
99 |
+
|
100 |
+
|
101 |
+
class vgg16(torch.nn.Module):
|
102 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
103 |
+
super(vgg16, self).__init__()
|
104 |
+
weights = tv_models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
|
105 |
+
vgg_pretrained_features = tv_models.vgg16(weights=weights).features
|
106 |
+
self.slice1 = torch.nn.Sequential()
|
107 |
+
self.slice2 = torch.nn.Sequential()
|
108 |
+
self.slice3 = torch.nn.Sequential()
|
109 |
+
self.slice4 = torch.nn.Sequential()
|
110 |
+
self.slice5 = torch.nn.Sequential()
|
111 |
+
self.N_slices = 5
|
112 |
+
for x in range(4):
|
113 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
114 |
+
for x in range(4, 9):
|
115 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
116 |
+
for x in range(9, 16):
|
117 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
118 |
+
for x in range(16, 23):
|
119 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
120 |
+
for x in range(23, 30):
|
121 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
122 |
+
|
123 |
+
if not requires_grad:
|
124 |
+
for param in self.parameters():
|
125 |
+
param.requires_grad = False
|
126 |
+
|
127 |
+
def forward(self, X):
|
128 |
+
h = self.slice1(X)
|
129 |
+
h_relu1_2 = h
|
130 |
+
h = self.slice2(h)
|
131 |
+
h_relu2_2 = h
|
132 |
+
h = self.slice3(h)
|
133 |
+
h_relu3_3 = h
|
134 |
+
h = self.slice4(h)
|
135 |
+
h_relu4_3 = h
|
136 |
+
h = self.slice5(h)
|
137 |
+
h_relu5_3 = h
|
138 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
139 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
140 |
+
|
141 |
+
return out
|
142 |
+
|
143 |
+
|
144 |
+
class resnet(torch.nn.Module):
|
145 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
146 |
+
super(resnet, self).__init__()
|
147 |
+
|
148 |
+
if num == 18:
|
149 |
+
weights = tv_models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
|
150 |
+
self.net = tv_models.resnet18(weights=weights)
|
151 |
+
elif num == 34:
|
152 |
+
weights = tv_models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
|
153 |
+
self.net = tv_models.resnet34(weights=weights)
|
154 |
+
elif num == 50:
|
155 |
+
weights = tv_models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
|
156 |
+
self.net = tv_models.resnet50(weights=weights)
|
157 |
+
elif num == 101:
|
158 |
+
weights = tv_models.ResNet101_Weights.IMAGENET1K_V2 if pretrained else None
|
159 |
+
self.net = tv_models.resnet101(weights=weights)
|
160 |
+
elif num == 152:
|
161 |
+
weights = tv_models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None
|
162 |
+
self.net = tv_models.resnet152(weights=weights)
|
163 |
+
self.N_slices = 5
|
164 |
+
|
165 |
+
if not requires_grad:
|
166 |
+
for param in self.net.parameters():
|
167 |
+
param.requires_grad = False
|
168 |
+
|
169 |
+
self.conv1 = self.net.conv1
|
170 |
+
self.bn1 = self.net.bn1
|
171 |
+
self.relu = self.net.relu
|
172 |
+
self.maxpool = self.net.maxpool
|
173 |
+
self.layer1 = self.net.layer1
|
174 |
+
self.layer2 = self.net.layer2
|
175 |
+
self.layer3 = self.net.layer3
|
176 |
+
self.layer4 = self.net.layer4
|
177 |
+
|
178 |
+
def forward(self, X):
|
179 |
+
h = self.conv1(X)
|
180 |
+
h = self.bn1(h)
|
181 |
+
h = self.relu(h)
|
182 |
+
h_relu1 = h
|
183 |
+
h = self.maxpool(h)
|
184 |
+
h = self.layer1(h)
|
185 |
+
h_conv2 = h
|
186 |
+
h = self.layer2(h)
|
187 |
+
h_conv3 = h
|
188 |
+
h = self.layer3(h)
|
189 |
+
h_conv4 = h
|
190 |
+
h = self.layer4(h)
|
191 |
+
h_conv5 = h
|
192 |
+
|
193 |
+
outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
|
194 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
195 |
+
|
196 |
+
return out
|
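For reference, a minimal sketch of how the backbone wrappers above are typically used; the import path assumes the repository root is on PYTHONPATH, and the shapes are illustrative.

import torch
from libs.metric.lpips_origin.pretrained_networks import vgg16

backbone = vgg16(requires_grad=False, pretrained=True)
x = torch.rand(1, 3, 224, 224)
feats = backbone(x)                      # namedtuple with fields relu1_2 ... relu5_3
print([tuple(f.shape) for f in feats])   # channel depth grows while spatial size shrinks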
libs/metric/lpips_origin/weights/v0.1/alex.pth
ADDED
Binary file (6.01 kB)
libs/metric/lpips_origin/weights/v0.1/squeeze.pth
ADDED
Binary file (10.8 kB)
libs/metric/lpips_origin/weights/v0.1/vgg.pth
ADDED
Binary file (7.29 kB)
libs/metric/piq/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
# install: pip install piq
|
2 |
+
# repo: https://github.com/photosynthesis-team/piq
|
libs/metric/piq/functional/__init__.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
from .base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex, crop_patches
|
2 |
+
from .colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm
|
3 |
+
from .filters import haar_filter, hann_filter, scharr_filter, prewitt_filter, gaussian_filter
|
4 |
+
from .filters import binomial_filter1d, average_filter2d
|
5 |
+
from .layers import L2Pool2d
|
6 |
+
from .resize import imresize
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex', 'crop_patches',
|
10 |
+
'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm',
|
11 |
+
'haar_filter', 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter',
|
12 |
+
'binomial_filter1d', 'average_filter2d',
|
13 |
+
'L2Pool2d',
|
14 |
+
'imresize',
|
15 |
+
]
|
libs/metric/piq/functional/base.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
r"""General purpose functions"""
|
2 |
+
from typing import Tuple, Union, Optional
|
3 |
+
import torch
|
4 |
+
from ..utils import _parse_version
|
5 |
+
|
6 |
+
|
7 |
+
def ifftshift(x: torch.Tensor) -> torch.Tensor:
|
8 |
+
r""" Similar to np.fft.ifftshift but applies to PyTorch Tensors"""
|
9 |
+
shift = [-(ax // 2) for ax in x.size()]
|
10 |
+
return torch.roll(x, shift, tuple(range(len(shift))))
|
11 |
+
|
12 |
+
|
13 |
+
def get_meshgrid(size: Tuple[int, int], device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
14 |
+
r"""Return coordinate grid matrices centered at zero point.
|
15 |
+
Args:
|
16 |
+
size: Shape of meshgrid to create
|
17 |
+
device: device to use for creation
|
18 |
+
dtype: dtype to use for creation
|
19 |
+
Returns:
|
20 |
+
Meshgrid of size on device with dtype values.
|
21 |
+
"""
|
22 |
+
if size[0] % 2:
|
23 |
+
# Odd
|
24 |
+
x = torch.arange(-(size[0] - 1) / 2, size[0] / 2, device=device, dtype=dtype) / (size[0] - 1)
|
25 |
+
else:
|
26 |
+
# Even
|
27 |
+
x = torch.arange(- size[0] / 2, size[0] / 2, device=device, dtype=dtype) / size[0]
|
28 |
+
|
29 |
+
if size[1] % 2:
|
30 |
+
# Odd
|
31 |
+
y = torch.arange(-(size[1] - 1) / 2, size[1] / 2, device=device, dtype=dtype) / (size[1] - 1)
|
32 |
+
else:
|
33 |
+
# Even
|
34 |
+
y = torch.arange(- size[1] / 2, size[1] / 2, device=device, dtype=dtype) / size[1]
|
35 |
+
# Use indexing param depending on torch version
|
36 |
+
recommended_torch_version = _parse_version("1.10.0")
|
37 |
+
torch_version = _parse_version(torch.__version__)
|
38 |
+
if len(torch_version) > 0 and torch_version >= recommended_torch_version:
|
39 |
+
return torch.meshgrid(x, y, indexing='ij')
|
40 |
+
return torch.meshgrid(x, y)
|
41 |
+
|
42 |
+
|
43 |
+
def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor:
|
44 |
+
r""" Compute similarity_map between two tensors using Dice-like equation.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
map_x: Tensor with map to be compared
|
48 |
+
map_y: Tensor with map to be compared
|
49 |
+
constant: Used for numerical stability
|
50 |
+
alpha: Masking coefficient. Subtracts `alpha` * map_x * map_y from both the numerator and the denominator
|
51 |
+
"""
|
52 |
+
return (2.0 * map_x * map_y - alpha * map_x * map_y + constant) / \
|
53 |
+
(map_x ** 2 + map_y ** 2 - alpha * map_x * map_y + constant)
|
54 |
+
|
55 |
+
|
56 |
+
def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
|
57 |
+
r""" Compute gradient map for a given tensor and stack of kernels.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
x: Tensor with shape (N, C, H, W).
|
61 |
+
kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W)
|
62 |
+
Returns:
|
63 |
+
Gradients of x per-channel with shape (N, C, H, W)
|
64 |
+
"""
|
65 |
+
padding = kernels.size(-1) // 2
|
66 |
+
grads = torch.nn.functional.conv2d(x, kernels, padding=padding)
|
67 |
+
|
68 |
+
return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True))
|
69 |
+
|
70 |
+
|
71 |
+
def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor:
|
72 |
+
r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values.
|
73 |
+
Complex numbers are represented by modulus and argument: r * \exp(i * \phi).
|
74 |
+
|
75 |
+
It is likely to become redundant with the introduction of torch.ComplexTensor.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2).
|
79 |
+
exp: Exponent
|
80 |
+
Returns:
|
81 |
+
Complex tensor with shape (N, C, H, W, 2).
|
82 |
+
"""
|
83 |
+
if base.dim() == 4:
|
84 |
+
x_complex_r = base.abs()
|
85 |
+
x_complex_phi = torch.atan2(torch.zeros_like(base), base)
|
86 |
+
elif base.dim() == 5 and base.size(-1) == 2:
|
87 |
+
x_complex_r = base.pow(2).sum(dim=-1).sqrt()
|
88 |
+
x_complex_phi = torch.atan2(base[..., 1], base[..., 0])
|
89 |
+
else:
|
90 |
+
raise ValueError(f'Expected real or complex tensor, got {base.size()}')
|
91 |
+
|
92 |
+
x_complex_pow_r = x_complex_r ** exp
|
93 |
+
x_complex_pow_phi = x_complex_phi * exp
|
94 |
+
x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi)
|
95 |
+
x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi)
|
96 |
+
return torch.stack((x_real_pow, x_imag_pow), dim=-1)
|
97 |
+
|
98 |
+
|
99 |
+
def crop_patches(x: torch.Tensor, size=64, stride=32) -> torch.Tensor:
|
100 |
+
r"""Crop tensor with images into small patches
|
101 |
+
Args:
|
102 |
+
x: Tensor with shape (N, C, H, W), expected to be images-like entities
|
103 |
+
size: Size of a square patch
|
104 |
+
stride: Step between patches
|
105 |
+
"""
|
106 |
+
assert (x.shape[2] >= size) and (x.shape[3] >= size), \
|
107 |
+
f"Images must be bigger than patch size. Got ({x.shape[2], x.shape[3]}) and ({size}, {size})"
|
108 |
+
channels = x.shape[1]
|
109 |
+
patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride)
|
110 |
+
patches = patches.reshape(-1, channels, size, size)
|
111 |
+
return patches
|
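A short sketch of crop_patches from the helpers above (import path as in this commit; shapes illustrative):

import torch
from libs.metric.piq.functional import crop_patches

x = torch.rand(2, 3, 128, 128)
patches = crop_patches(x, size=64, stride=32)
print(patches.shape)   # torch.Size([18, 3, 64, 64]): 3 x 3 patches per image, 2 images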
libs/metric/piq/functional/colour_conversion.py
ADDED
@@ -0,0 +1,136 @@
1 |
+
r"""Colour space conversion functions"""
|
2 |
+
from typing import Union, Dict
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def rgb2lmn(x: torch.Tensor) -> torch.Tensor:
|
7 |
+
r"""Convert a batch of RGB images to a batch of LMN images
|
8 |
+
|
9 |
+
Args:
|
10 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
Batch of images with shape (N, 3, H, W). LMN colour space.
|
14 |
+
"""
|
15 |
+
weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27],
|
16 |
+
[0.30, 0.04, -0.35],
|
17 |
+
[0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
|
18 |
+
x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2)
|
19 |
+
return x_lmn
|
20 |
+
|
21 |
+
|
22 |
+
def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
|
23 |
+
r"""Convert a batch of RGB images to a batch of XYZ images
|
24 |
+
|
25 |
+
Args:
|
26 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
Batch of images with shape (N, 3, H, W). XYZ colour space.
|
30 |
+
"""
|
31 |
+
mask_below = (x <= 0.04045).type(x.dtype)
|
32 |
+
mask_above = (x > 0.04045).type(x.dtype)
|
33 |
+
|
34 |
+
tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above
|
35 |
+
|
36 |
+
weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
|
37 |
+
[0.2126729, 0.7151522, 0.0721750],
|
38 |
+
[0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device)
|
39 |
+
|
40 |
+
x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2)
|
41 |
+
return x_xyz
|
42 |
+
|
43 |
+
|
44 |
+
def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
|
45 |
+
r"""Convert a batch of XYZ images to a batch of LAB images
|
46 |
+
|
47 |
+
Args:
|
48 |
+
x: Batch of images with shape (N, 3, H, W). XYZ colour space.
|
49 |
+
illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant.
|
50 |
+
observer: {“2”, “10”}, optional. The aperture angle of the observer.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
Batch of images with shape (N, 3, H, W). LAB colour space.
|
54 |
+
"""
|
55 |
+
epsilon = 0.008856
|
56 |
+
kappa = 903.3
|
57 |
+
illuminants: Dict[str, Dict] = \
|
58 |
+
{"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
|
59 |
+
'10': (1.111420406956693, 1, 0.3519978321919493)},
|
60 |
+
"D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
|
61 |
+
'10': (0.9672062750333777, 1, 0.8142801513128616)},
|
62 |
+
"D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
|
63 |
+
'10': (0.9579665682254781, 1, 0.9092525159847462)},
|
64 |
+
"D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
|
65 |
+
'10': (0.94809667673716, 1, 1.0730513595166162)},
|
66 |
+
"D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
|
67 |
+
'10': (0.9441713925645873, 1, 1.2064272211720228)},
|
68 |
+
"E": {'2': (1.0, 1.0, 1.0),
|
69 |
+
'10': (1.0, 1.0, 1.0)}}
|
70 |
+
|
71 |
+
illuminants_to_use = torch.tensor(illuminants[illuminant][observer],
|
72 |
+
dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
|
73 |
+
|
74 |
+
tmp = x / illuminants_to_use
|
75 |
+
|
76 |
+
mask_below = (tmp <= epsilon).type(x.dtype)
|
77 |
+
mask_above = (tmp > epsilon).type(x.dtype)
|
78 |
+
tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. * mask_below
|
79 |
+
|
80 |
+
weights_xyz_to_lab = torch.tensor([[0, 116., 0],
|
81 |
+
[500., -500., 0],
|
82 |
+
[0, 200., -200.]], dtype=x.dtype, device=x.device)
|
83 |
+
bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
|
84 |
+
|
85 |
+
x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab
|
86 |
+
return x_lab
|
87 |
+
|
88 |
+
|
89 |
+
def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor:
|
90 |
+
r"""Convert a batch of RGB images to a batch of LAB images
|
91 |
+
|
92 |
+
Args:
|
93 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
94 |
+
data_range: dynamic range of the input image.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
Batch of images with shape (N, 3, H, W). LAB colour space.
|
98 |
+
"""
|
99 |
+
return xyz2lab(rgb2xyz(x / float(data_range)))
|
100 |
+
|
101 |
+
|
102 |
+
def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
|
103 |
+
r"""Convert a batch of RGB images to a batch of YIQ images
|
104 |
+
|
105 |
+
Args:
|
106 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
Batch of images with shape (N, 3, H, W). YIQ colour space.
|
110 |
+
"""
|
111 |
+
yiq_weights = torch.tensor([
|
112 |
+
[0.299, 0.587, 0.114],
|
113 |
+
[0.5959, -0.2746, -0.3213],
|
114 |
+
[0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t()
|
115 |
+
x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2)
|
116 |
+
return x_yiq
|
117 |
+
|
118 |
+
|
119 |
+
def rgb2lhm(x: torch.Tensor) -> torch.Tensor:
|
120 |
+
r"""Convert a batch of RGB images to a batch of LHM images
|
121 |
+
|
122 |
+
Args:
|
123 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
Batch of images with shape (N, 3, H, W). LHM colour space.
|
127 |
+
|
128 |
+
Reference:
|
129 |
+
https://arxiv.org/pdf/1608.07433.pdf
|
130 |
+
"""
|
131 |
+
lhm_weights = torch.tensor([
|
132 |
+
[0.2989, 0.587, 0.114],
|
133 |
+
[0.3, 0.04, -0.35],
|
134 |
+
[0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
|
135 |
+
x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2)
|
136 |
+
return x_lhm
|
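A minimal sketch of the colour conversions above (import path as in this commit; values illustrative):

import torch
from libs.metric.piq.functional import rgb2lab, rgb2yiq

rgb_255 = torch.rand(4, 3, 64, 64) * 255.0
lab = rgb2lab(rgb_255)            # data_range defaults to 255, so [0, 255] inputs are expected
yiq = rgb2yiq(rgb_255 / 255.0)    # rgb2yiq applies the weight matrix directly to the given values
print(lab.shape, yiq.shape)       # both (4, 3, 64, 64)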
libs/metric/piq/functional/filters.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
r"""Filters for gradient computation, bluring, etc."""
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
|
7 |
+
def haar_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
8 |
+
r"""Creates Haar kernel
|
9 |
+
|
10 |
+
Args:
|
11 |
+
kernel_size: size of the kernel
|
12 |
+
device: target device for kernel generation
|
13 |
+
dtype: target data type for kernel generation
|
14 |
+
Returns:
|
15 |
+
kernel: Tensor with shape (1, kernel_size, kernel_size)
|
16 |
+
"""
|
17 |
+
kernel = torch.ones((kernel_size, kernel_size), device=device, dtype=dtype) / kernel_size
|
18 |
+
kernel[kernel_size // 2:, :] = - kernel[kernel_size // 2:, :]
|
19 |
+
return kernel.unsqueeze(0)
|
20 |
+
|
21 |
+
|
22 |
+
def hann_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
23 |
+
r"""Creates Hann kernel
|
24 |
+
Args:
|
25 |
+
kernel_size: size of the kernel
|
26 |
+
device: target device for kernel generation
|
27 |
+
dtype: target data type for kernel generation
|
28 |
+
Returns:
|
29 |
+
kernel: Tensor with shape (1, kernel_size, kernel_size)
|
30 |
+
"""
|
31 |
+
# Take bigger window and drop borders
|
32 |
+
window = torch.hann_window(kernel_size + 2, periodic=False, device=device, dtype=dtype)[1:-1]
|
33 |
+
kernel = window[:, None] * window[None, :]
|
34 |
+
# Normalize and reshape kernel
|
35 |
+
return kernel.view(1, kernel_size, kernel_size) / kernel.sum()
|
36 |
+
|
37 |
+
|
38 |
+
def gaussian_filter(kernel_size: int, sigma: float, device: Optional[str] = None,
|
39 |
+
dtype: Optional[type] = None) -> torch.Tensor:
|
40 |
+
r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
|
41 |
+
Args:
|
42 |
+
kernel_size: Size of the kernel
|
43 |
+
sigma: Std of the distribution
|
44 |
+
device: target device for kernel generation
|
45 |
+
dtype: target data type for kernel generation
|
46 |
+
Returns:
|
47 |
+
gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
|
48 |
+
"""
|
49 |
+
coords = torch.arange(kernel_size, dtype=dtype, device=device)
|
50 |
+
coords -= (kernel_size - 1) / 2.
|
51 |
+
|
52 |
+
g = coords ** 2
|
53 |
+
g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
|
54 |
+
|
55 |
+
g /= g.sum()
|
56 |
+
return g.unsqueeze(0)
|
57 |
+
|
58 |
+
|
59 |
+
# Gradient operator kernels
|
60 |
+
def scharr_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
61 |
+
r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction
|
62 |
+
|
63 |
+
Args:
|
64 |
+
device: target device for kernel generation
|
65 |
+
dtype: target data type for kernel generation
|
66 |
+
Returns:
|
67 |
+
kernel: Tensor with shape (1, 3, 3)
|
68 |
+
"""
|
69 |
+
return torch.tensor([[[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]], device=device, dtype=dtype) / 16
|
70 |
+
|
71 |
+
|
72 |
+
def prewitt_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
73 |
+
r"""Utility function that returns a normalized 3x3 Prewitt kernel in X direction
|
74 |
+
|
75 |
+
Args:
|
76 |
+
device: target device for kernel generation
|
77 |
+
dtype: target data type for kernel generation
|
78 |
+
Returns:
|
79 |
+
kernel: Tensor with shape (1, 3, 3)"""
|
80 |
+
return torch.tensor([[[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]], device=device, dtype=dtype) / 3
|
81 |
+
|
82 |
+
|
83 |
+
def binomial_filter1d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
84 |
+
r"""Creates 1D normalized binomial filter
|
85 |
+
|
86 |
+
Args:
|
87 |
+
kernel_size (int): kernel size
|
88 |
+
device: target device for kernel generation
|
89 |
+
dtype: target data type for kernel generation
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
Binomial kernel with shape (1, 1, kernel_size)
|
93 |
+
"""
|
94 |
+
kernel = np.poly1d([0.5, 0.5]) ** (kernel_size - 1)
|
95 |
+
return torch.tensor(kernel.c, dtype=dtype, device=device).view(1, 1, kernel_size)
|
96 |
+
|
97 |
+
|
98 |
+
def average_filter2d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
99 |
+
r"""Creates 2D normalized average filter
|
100 |
+
|
101 |
+
Args:
|
102 |
+
kernel_size (int): kernel size
|
103 |
+
device: target device for kernel generation
|
104 |
+
dtype: target data type for kernel generation
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
kernel: Tensor with shape (1, kernel_size, kernel_size)
|
108 |
+
"""
|
109 |
+
window = torch.ones(kernel_size, dtype=dtype, device=device) / kernel_size
|
110 |
+
kernel = window[:, None] * window[None, :]
|
111 |
+
return kernel.unsqueeze(0)
|
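A minimal sketch applying gaussian_filter as a per-channel blur (import path as in this commit; kernel size and sigma are illustrative). A floating-point dtype is passed explicitly because torch.arange would otherwise return an integer tensor that breaks the centring step.

import torch
import torch.nn.functional as F
from libs.metric.piq.functional import gaussian_filter

kernel = gaussian_filter(kernel_size=5, sigma=1.5, dtype=torch.float32)   # shape (1, 5, 5)
x = torch.rand(1, 3, 32, 32)
blurred = F.conv2d(x, kernel.repeat(3, 1, 1, 1), padding=2, groups=3)     # per-channel blur
print(blurred.shape)   # (1, 3, 32, 32)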
libs/metric/piq/functional/layers.py
ADDED
@@ -0,0 +1,33 @@
1 |
+
r"""Custom layers used in metrics computations"""
|
2 |
+
import torch
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from .filters import hann_filter
|
6 |
+
|
7 |
+
|
8 |
+
class L2Pool2d(torch.nn.Module):
|
9 |
+
r"""Applies L2 pooling with Hann window of size 3x3
|
10 |
+
Args:
|
11 |
+
x: Tensor with shape (N, C, H, W)"""
|
12 |
+
EPS = 1e-12
|
13 |
+
|
14 |
+
def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None:
|
15 |
+
super().__init__()
|
16 |
+
self.kernel_size = kernel_size
|
17 |
+
self.stride = stride
|
18 |
+
self.padding = padding
|
19 |
+
|
20 |
+
self.kernel: Optional[torch.Tensor] = None
|
21 |
+
|
22 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
23 |
+
if self.kernel is None:
|
24 |
+
C = x.size(1)
|
25 |
+
self.kernel = hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x)
|
26 |
+
|
27 |
+
out = torch.nn.functional.conv2d(
|
28 |
+
x ** 2, self.kernel,
|
29 |
+
stride=self.stride,
|
30 |
+
padding=self.padding,
|
31 |
+
groups=x.shape[1]
|
32 |
+
)
|
33 |
+
return (out + self.EPS).sqrt()
|
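A minimal sketch of L2Pool2d (import path as in this commit; the input size is illustrative):

import torch
from libs.metric.piq.functional import L2Pool2d

pool = L2Pool2d(kernel_size=3, stride=2, padding=1)
x = torch.rand(1, 3, 64, 64)
y = pool(x)            # (1, 3, 32, 32): sqrt of a Hann-weighted average of squared values
print(y.shape)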
libs/metric/piq/functional/resize.py
ADDED
@@ -0,0 +1,426 @@
1 |
+
"""
|
2 |
+
A standalone PyTorch implementation for fast and efficient bicubic resampling.
|
3 |
+
The resulting values are the same to MATLAB function imresize('bicubic').
|
4 |
+
## Author: Sanghyun Son
|
5 |
+
## Email: [email protected] (primary), [email protected] (secondary)
|
6 |
+
## Version: 1.2.0
|
7 |
+
## Last update: July 9th, 2020 (KST)
|
8 |
+
Dependency: torch
|
9 |
+
Example::
|
10 |
+
>>> import torch
|
11 |
+
>>> import core
|
12 |
+
>>> x = torch.arange(16).float().view(1, 1, 4, 4)
|
13 |
+
>>> y = core.imresize(x, sizes=(3, 3))
|
14 |
+
>>> print(y)
|
15 |
+
tensor([[[[ 0.7506, 2.1004, 3.4503],
|
16 |
+
[ 6.1505, 7.5000, 8.8499],
|
17 |
+
[11.5497, 12.8996, 14.2494]]]])
|
18 |
+
"""
|
19 |
+
|
20 |
+
import math
|
21 |
+
import typing
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from torch.nn import functional as F
|
25 |
+
|
26 |
+
__all__ = ['imresize']
|
27 |
+
|
28 |
+
_I = typing.Optional[int]
|
29 |
+
_D = typing.Optional[torch.dtype]
|
30 |
+
|
31 |
+
|
32 |
+
def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
|
33 |
+
range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
|
34 |
+
cont = range_around_0.to(dtype=x.dtype)
|
35 |
+
return cont
|
36 |
+
|
37 |
+
|
38 |
+
def linear_contribution(x: torch.Tensor) -> torch.Tensor:
|
39 |
+
ax = x.abs()
|
40 |
+
range_01 = ax.le(1)
|
41 |
+
cont = (1 - ax) * range_01.to(dtype=x.dtype)
|
42 |
+
return cont
|
43 |
+
|
44 |
+
|
45 |
+
def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
|
46 |
+
ax = x.abs()
|
47 |
+
ax2 = ax * ax
|
48 |
+
ax3 = ax * ax2
|
49 |
+
|
50 |
+
range_01 = ax.le(1)
|
51 |
+
range_12 = torch.logical_and(ax.gt(1), ax.le(2))
|
52 |
+
|
53 |
+
cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
|
54 |
+
cont_01 = cont_01 * range_01.to(dtype=x.dtype)
|
55 |
+
|
56 |
+
cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
|
57 |
+
cont_12 = cont_12 * range_12.to(dtype=x.dtype)
|
58 |
+
|
59 |
+
cont = cont_01 + cont_12
|
60 |
+
return cont
|
61 |
+
|
62 |
+
|
63 |
+
def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
|
64 |
+
range_3sigma = (x.abs() <= 3 * sigma + 1)
|
65 |
+
# Normalization will be done after
|
66 |
+
cont = torch.exp(-x.pow(2) / (2 * sigma ** 2))
|
67 |
+
cont = cont * range_3sigma.to(dtype=x.dtype)
|
68 |
+
return cont
|
69 |
+
|
70 |
+
|
71 |
+
def discrete_kernel(
|
72 |
+
kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
|
73 |
+
'''
|
74 |
+
For downsampling with integer scale only.
|
75 |
+
'''
|
76 |
+
downsampling_factor = int(1 / scale)
|
77 |
+
if kernel == 'cubic':
|
78 |
+
kernel_size_orig = 4
|
79 |
+
else:
|
80 |
+
raise ValueError('Only the cubic kernel is supported by discrete_kernel!')
|
81 |
+
|
82 |
+
if antialiasing:
|
83 |
+
kernel_size = kernel_size_orig * downsampling_factor
|
84 |
+
else:
|
85 |
+
kernel_size = kernel_size_orig
|
86 |
+
|
87 |
+
if downsampling_factor % 2 == 0:
|
88 |
+
a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
|
89 |
+
else:
|
90 |
+
kernel_size -= 1
|
91 |
+
a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
|
92 |
+
|
93 |
+
with torch.no_grad():
|
94 |
+
r = torch.linspace(-a, a, steps=kernel_size)
|
95 |
+
k = cubic_contribution(r).view(-1, 1)
|
96 |
+
k = torch.matmul(k, k.t())
|
97 |
+
k /= k.sum()
|
98 |
+
|
99 |
+
return k
|
100 |
+
|
101 |
+
|
102 |
+
def reflect_padding(
|
103 |
+
x: torch.Tensor,
|
104 |
+
dim: int,
|
105 |
+
pad_pre: int,
|
106 |
+
pad_post: int) -> torch.Tensor:
|
107 |
+
'''
|
108 |
+
Apply reflect padding to the given Tensor.
|
109 |
+
Note that it is slightly different from the PyTorch functional.pad,
|
110 |
+
where boundary elements are used only once.
|
111 |
+
Instead, we follow the MATLAB implementation
|
112 |
+
which uses boundary elements twice.
|
113 |
+
For example,
|
114 |
+
[a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
|
115 |
+
while our implementation yields [a, a, b, c, d, d].
|
116 |
+
'''
|
117 |
+
b, c, h, w = x.size()
|
118 |
+
if dim == 2 or dim == -2:
|
119 |
+
padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
|
120 |
+
padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
|
121 |
+
for p in range(pad_pre):
|
122 |
+
padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
|
123 |
+
for p in range(pad_post):
|
124 |
+
padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
|
125 |
+
else:
|
126 |
+
padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
|
127 |
+
padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
|
128 |
+
for p in range(pad_pre):
|
129 |
+
padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
|
130 |
+
for p in range(pad_post):
|
131 |
+
padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
|
132 |
+
|
133 |
+
return padding_buffer
|
134 |
+
|
135 |
+
|
136 |
+
def padding(
|
137 |
+
x: torch.Tensor,
|
138 |
+
dim: int,
|
139 |
+
pad_pre: int,
|
140 |
+
pad_post: int,
|
141 |
+
padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
|
142 |
+
if padding_type is None:
|
143 |
+
return x
|
144 |
+
elif padding_type == 'reflect':
|
145 |
+
x_pad = reflect_padding(x, dim, pad_pre, pad_post)
|
146 |
+
else:
|
147 |
+
raise ValueError('{} padding is not supported!'.format(padding_type))
|
148 |
+
|
149 |
+
return x_pad
|
150 |
+
|
151 |
+
|
152 |
+
def get_padding(
|
153 |
+
base: torch.Tensor,
|
154 |
+
kernel_size: int,
|
155 |
+
x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
|
156 |
+
base = base.long()
|
157 |
+
r_min = base.min()
|
158 |
+
r_max = base.max() + kernel_size - 1
|
159 |
+
|
160 |
+
if r_min <= 0:
|
161 |
+
pad_pre = -r_min
|
162 |
+
pad_pre = pad_pre.item()
|
163 |
+
base += pad_pre
|
164 |
+
else:
|
165 |
+
pad_pre = 0
|
166 |
+
|
167 |
+
if r_max >= x_size:
|
168 |
+
pad_post = r_max - x_size + 1
|
169 |
+
pad_post = pad_post.item()
|
170 |
+
else:
|
171 |
+
pad_post = 0
|
172 |
+
|
173 |
+
return pad_pre, pad_post, base
|
174 |
+
|
175 |
+
|
176 |
+
def get_weight(
|
177 |
+
dist: torch.Tensor,
|
178 |
+
kernel_size: int,
|
179 |
+
kernel: str = 'cubic',
|
180 |
+
sigma: float = 2.0,
|
181 |
+
antialiasing_factor: float = 1) -> torch.Tensor:
|
182 |
+
buffer_pos = dist.new_zeros(kernel_size, len(dist))
|
183 |
+
for idx, buffer_sub in enumerate(buffer_pos):
|
184 |
+
buffer_sub.copy_(dist - idx)
|
185 |
+
|
186 |
+
# Expand (downsampling) / Shrink (upsampling) the receptive field.
|
187 |
+
buffer_pos *= antialiasing_factor
|
188 |
+
if kernel == 'cubic':
|
189 |
+
weight = cubic_contribution(buffer_pos)
|
190 |
+
elif kernel == 'gaussian':
|
191 |
+
weight = gaussian_contribution(buffer_pos, sigma=sigma)
|
192 |
+
else:
|
193 |
+
raise ValueError('{} kernel is not supported!'.format(kernel))
|
194 |
+
|
195 |
+
weight /= weight.sum(dim=0, keepdim=True)
|
196 |
+
return weight
|
197 |
+
|
198 |
+
|
199 |
+
def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
|
200 |
+
# Resize height
|
201 |
+
if dim == 2 or dim == -2:
|
202 |
+
k = (kernel_size, 1)
|
203 |
+
h_out = x.size(-2) - kernel_size + 1
|
204 |
+
w_out = x.size(-1)
|
205 |
+
# Resize width
|
206 |
+
else:
|
207 |
+
k = (1, kernel_size)
|
208 |
+
h_out = x.size(-2)
|
209 |
+
w_out = x.size(-1) - kernel_size + 1
|
210 |
+
|
211 |
+
unfold = F.unfold(x, k)
|
212 |
+
unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
|
213 |
+
return unfold
|
214 |
+
|
215 |
+
|
216 |
+
def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
|
217 |
+
if x.dim() == 4:
|
218 |
+
b, c, h, w = x.size()
|
219 |
+
elif x.dim() == 3:
|
220 |
+
c, h, w = x.size()
|
221 |
+
b = None
|
222 |
+
elif x.dim() == 2:
|
223 |
+
h, w = x.size()
|
224 |
+
b = c = None
|
225 |
+
else:
|
226 |
+
raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
|
227 |
+
|
228 |
+
x = x.view(-1, 1, h, w)
|
229 |
+
return x, b, c, h, w
|
230 |
+
|
231 |
+
|
232 |
+
def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
|
233 |
+
rh = x.size(-2)
|
234 |
+
rw = x.size(-1)
|
235 |
+
# Back to the original dimension
|
236 |
+
if b is not None:
|
237 |
+
x = x.view(b, c, rh, rw) # 4-dim
|
238 |
+
else:
|
239 |
+
if c is not None:
|
240 |
+
x = x.view(c, rh, rw) # 3-dim
|
241 |
+
else:
|
242 |
+
x = x.view(rh, rw) # 2-dim
|
243 |
+
|
244 |
+
return x
|
245 |
+
|
246 |
+
|
247 |
+
def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
|
248 |
+
if x.dtype != torch.float32 and x.dtype != torch.float64:
|
249 |
+
dtype = x.dtype
|
250 |
+
x = x.float()
|
251 |
+
else:
|
252 |
+
dtype = None
|
253 |
+
|
254 |
+
return x, dtype
|
255 |
+
|
256 |
+
|
257 |
+
def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
|
258 |
+
if dtype is not None:
|
259 |
+
if not dtype.is_floating_point:
|
260 |
+
x = x.round()
|
261 |
+
# To prevent over/underflow when converting types
|
262 |
+
if dtype is torch.uint8:
|
263 |
+
x = x.clamp(0, 255)
|
264 |
+
|
265 |
+
x = x.to(dtype=dtype)
|
266 |
+
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
def resize_1d(
|
271 |
+
x: torch.Tensor,
|
272 |
+
dim: int,
|
273 |
+
size: int,
|
274 |
+
scale: float,
|
275 |
+
kernel: str = 'cubic',
|
276 |
+
sigma: float = 2.0,
|
277 |
+
padding_type: str = 'reflect',
|
278 |
+
antialiasing: bool = True) -> torch.Tensor:
|
279 |
+
'''
|
280 |
+
Args:
|
281 |
+
x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
|
282 |
+
dim (int):
|
283 |
+
scale (float):
|
284 |
+
size (int):
|
285 |
+
Return:
|
286 |
+
'''
|
287 |
+
# Identity case
|
288 |
+
if scale == 1:
|
289 |
+
return x
|
290 |
+
|
291 |
+
# Default bicubic kernel with antialiasing (only when downsampling)
|
292 |
+
if kernel == 'cubic':
|
293 |
+
kernel_size = 4
|
294 |
+
else:
|
295 |
+
kernel_size = math.floor(6 * sigma)
|
296 |
+
|
297 |
+
if antialiasing and (scale < 1):
|
298 |
+
antialiasing_factor = scale
|
299 |
+
kernel_size = math.ceil(kernel_size / antialiasing_factor)
|
300 |
+
else:
|
301 |
+
antialiasing_factor = 1
|
302 |
+
|
303 |
+
# We allow a margin on both sides
|
304 |
+
kernel_size += 2
|
305 |
+
|
306 |
+
# Weights only depend on the shape of input and output,
|
307 |
+
# so we do not calculate gradients here.
|
308 |
+
with torch.no_grad():
|
309 |
+
pos = torch.linspace(
|
310 |
+
0, size - 1, steps=size, dtype=x.dtype, device=x.device,
|
311 |
+
)
|
312 |
+
pos = (pos + 0.5) / scale - 0.5
|
313 |
+
base = pos.floor() - (kernel_size // 2) + 1
|
314 |
+
dist = pos - base
|
315 |
+
weight = get_weight(
|
316 |
+
dist,
|
317 |
+
kernel_size,
|
318 |
+
kernel=kernel,
|
319 |
+
sigma=sigma,
|
320 |
+
antialiasing_factor=antialiasing_factor,
|
321 |
+
)
|
322 |
+
pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
|
323 |
+
|
324 |
+
# To backpropagate through x
|
325 |
+
x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
|
326 |
+
unfold = reshape_tensor(x_pad, dim, kernel_size)
|
327 |
+
# Subsampling first
|
328 |
+
if dim == 2 or dim == -2:
|
329 |
+
sample = unfold[..., base, :]
|
330 |
+
weight = weight.view(1, kernel_size, sample.size(2), 1)
|
331 |
+
else:
|
332 |
+
sample = unfold[..., base]
|
333 |
+
weight = weight.view(1, kernel_size, 1, sample.size(3))
|
334 |
+
|
335 |
+
# Apply the kernel
|
336 |
+
x = sample * weight
|
337 |
+
x = x.sum(dim=1, keepdim=True)
|
338 |
+
return x
|
339 |
+
|
340 |
+
|
341 |
+
def downsampling_2d(
|
342 |
+
x: torch.Tensor,
|
343 |
+
k: torch.Tensor,
|
344 |
+
scale: int,
|
345 |
+
padding_type: str = 'reflect') -> torch.Tensor:
|
346 |
+
c = x.size(1)
|
347 |
+
k_h = k.size(-2)
|
348 |
+
k_w = k.size(-1)
|
349 |
+
|
350 |
+
k = k.to(dtype=x.dtype, device=x.device)
|
351 |
+
k = k.view(1, 1, k_h, k_w)
|
352 |
+
k = k.repeat(c, c, 1, 1)
|
353 |
+
e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
|
354 |
+
e = e.view(c, c, 1, 1)
|
355 |
+
k = k * e
|
356 |
+
|
357 |
+
pad_h = (k_h - scale) // 2
|
358 |
+
pad_w = (k_w - scale) // 2
|
359 |
+
x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
|
360 |
+
x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
|
361 |
+
y = F.conv2d(x, k, padding=0, stride=scale)
|
362 |
+
return y
|
363 |
+
|
364 |
+
|
365 |
+
def imresize(
|
366 |
+
x: torch.Tensor,
|
367 |
+
scale: typing.Optional[float] = None,
|
368 |
+
sizes: typing.Optional[typing.Tuple[int, int]] = None,
|
369 |
+
kernel: typing.Union[str, torch.Tensor] = 'cubic',
|
370 |
+
sigma: float = 2,
|
371 |
+
rotation_degree: float = 0,
|
372 |
+
padding_type: str = 'reflect',
|
373 |
+
antialiasing: bool = True) -> torch.Tensor:
|
374 |
+
"""
|
375 |
+
Args:
|
376 |
+
x (torch.Tensor):
|
377 |
+
scale (float):
|
378 |
+
sizes (tuple(int, int)):
|
379 |
+
kernel (str, default='cubic'):
|
380 |
+
sigma (float, default=2):
|
381 |
+
rotation_degree (float, default=0):
|
382 |
+
padding_type (str, default='reflect'):
|
383 |
+
antialiasing (bool, default=True):
|
384 |
+
Return:
|
385 |
+
torch.Tensor:
|
386 |
+
"""
|
387 |
+
if scale is None and sizes is None:
|
388 |
+
raise ValueError('One of scale or sizes must be specified!')
|
389 |
+
if scale is not None and sizes is not None:
|
390 |
+
raise ValueError('Please specify scale or sizes to avoid conflict!')
|
391 |
+
|
392 |
+
x, b, c, h, w = reshape_input(x)
|
393 |
+
|
394 |
+
if sizes is None and scale is not None:
|
395 |
+
'''
|
396 |
+
# Check if we can apply the convolution algorithm
|
397 |
+
scale_inv = 1 / scale
|
398 |
+
if isinstance(kernel, str) and scale_inv.is_integer():
|
399 |
+
kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
|
400 |
+
elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
|
401 |
+
raise ValueError(
|
402 |
+
'An integer downsampling factor '
|
403 |
+
'should be used with a predefined kernel!'
|
404 |
+
)
|
405 |
+
'''
|
406 |
+
# Determine output size
|
407 |
+
sizes = (math.ceil(h * scale), math.ceil(w * scale))
|
408 |
+
scales = (scale, scale)
|
409 |
+
|
410 |
+
if scale is None and sizes is not None:
|
411 |
+
scales = (sizes[0] / h, sizes[1] / w)
|
412 |
+
|
413 |
+
x, dtype = cast_input(x)
|
414 |
+
|
415 |
+
if isinstance(kernel, str) and sizes is not None:
|
416 |
+
# Core resizing module
|
417 |
+
x = resize_1d(x, -2, size=sizes[0], scale=scales[0], kernel=kernel, sigma=sigma, padding_type=padding_type,
|
418 |
+
antialiasing=antialiasing)
|
419 |
+
x = resize_1d(x, -1, size=sizes[1], scale=scales[1], kernel=kernel, sigma=sigma, padding_type=padding_type,
|
420 |
+
antialiasing=antialiasing)
|
421 |
+
elif isinstance(kernel, torch.Tensor) and scale is not None:
|
422 |
+
x = downsampling_2d(x, kernel, scale=int(1 / scale))
|
423 |
+
|
424 |
+
x = reshape_output(x, b, c)
|
425 |
+
x = cast_output(x, dtype)
|
426 |
+
return x
|
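A minimal sketch of imresize from the module above (import path as in this commit; sizes are illustrative):

import torch
from libs.metric.piq.functional import imresize

x = torch.rand(1, 3, 128, 128)
y_half = imresize(x, scale=0.5)           # antialiased bicubic downsampling -> (1, 3, 64, 64)
y_hw = imresize(x, sizes=(100, 150))      # resize to an explicit (H, W) -> (1, 3, 100, 150)
print(y_half.shape, y_hw.shape)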
libs/metric/piq/perceptual.py
ADDED
@@ -0,0 +1,496 @@
1 |
+
"""
|
2 |
+
Implementation of Content loss, Style loss, LPIPS and DISTS metrics
|
3 |
+
References:
|
4 |
+
.. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias
|
5 |
+
(2016). A Neural Algorithm of Artistic Style
|
6 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
7 |
+
https://arxiv.org/abs/1508.06576
|
8 |
+
.. [2] Zhang, Richard and Isola, Phillip and Efros, et al.
|
9 |
+
(2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
10 |
+
2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
11 |
+
https://arxiv.org/abs/1801.03924
|
12 |
+
"""
|
13 |
+
from typing import List, Union, Collection
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from torch.nn.modules.loss import _Loss
|
18 |
+
from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights
|
19 |
+
|
20 |
+
from .utils import _validate_input, _reduce
|
21 |
+
from .functional import similarity_map, L2Pool2d
|
22 |
+
|
23 |
+
# Map VGG names to corresponding number in torchvision layer
|
24 |
+
VGG16_LAYERS = {
|
25 |
+
"conv1_1": '0', "relu1_1": '1',
|
26 |
+
"conv1_2": '2', "relu1_2": '3',
|
27 |
+
"pool1": '4',
|
28 |
+
"conv2_1": '5', "relu2_1": '6',
|
29 |
+
"conv2_2": '7', "relu2_2": '8',
|
30 |
+
"pool2": '9',
|
31 |
+
"conv3_1": '10', "relu3_1": '11',
|
32 |
+
"conv3_2": '12', "relu3_2": '13',
|
33 |
+
"conv3_3": '14', "relu3_3": '15',
|
34 |
+
"pool3": '16',
|
35 |
+
"conv4_1": '17', "relu4_1": '18',
|
36 |
+
"conv4_2": '19', "relu4_2": '20',
|
37 |
+
"conv4_3": '21', "relu4_3": '22',
|
38 |
+
"pool4": '23',
|
39 |
+
"conv5_1": '24', "relu5_1": '25',
|
40 |
+
"conv5_2": '26', "relu5_2": '27',
|
41 |
+
"conv5_3": '28', "relu5_3": '29',
|
42 |
+
"pool5": '30',
|
43 |
+
}
|
44 |
+
|
45 |
+
VGG19_LAYERS = {
|
46 |
+
"conv1_1": '0', "relu1_1": '1',
|
47 |
+
"conv1_2": '2', "relu1_2": '3',
|
48 |
+
"pool1": '4',
|
49 |
+
"conv2_1": '5', "relu2_1": '6',
|
50 |
+
"conv2_2": '7', "relu2_2": '8',
|
51 |
+
"pool2": '9',
|
52 |
+
"conv3_1": '10', "relu3_1": '11',
|
53 |
+
"conv3_2": '12', "relu3_2": '13',
|
54 |
+
"conv3_3": '14', "relu3_3": '15',
|
55 |
+
"conv3_4": '16', "relu3_4": '17',
|
56 |
+
"pool3": '18',
|
57 |
+
"conv4_1": '19', "relu4_1": '20',
|
58 |
+
"conv4_2": '21', "relu4_2": '22',
|
59 |
+
"conv4_3": '23', "relu4_3": '24',
|
60 |
+
"conv4_4": '25', "relu4_4": '26',
|
61 |
+
"pool4": '27',
|
62 |
+
"conv5_1": '28', "relu5_1": '29',
|
63 |
+
"conv5_2": '30', "relu5_2": '31',
|
64 |
+
"conv5_3": '32', "relu5_3": '33',
|
65 |
+
"conv5_4": '34', "relu5_4": '35',
|
66 |
+
"pool5": '36',
|
67 |
+
}
|
68 |
+
|
69 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
70 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
71 |
+
|
72 |
+
# Constant used in feature normalization to avoid zero division
|
73 |
+
EPS = 1e-10
|
74 |
+
|
75 |
+
|
76 |
+
class ContentLoss(_Loss):
|
77 |
+
r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks.
|
78 |
+
Uses pretrained VGG models from torchvision.
|
79 |
+
Expects input to be in range [0, 1] or normalized with ImageNet statistics into range [-1, 1]
|
80 |
+
|
81 |
+
Args:
|
82 |
+
feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
|
83 |
+
layers: List of strings with layer names. Default: ``'relu3_3'``
|
84 |
+
weights: List of float weights to balance different layers
|
85 |
+
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
|
86 |
+
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
|
87 |
+
reduction: Specifies the reduction type:
|
88 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
89 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
90 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
91 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
92 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
93 |
+
normalize_features: If true, unit-normalize each feature in channel dimension before scaling
|
94 |
+
and computing distance. See references for details.
|
95 |
+
|
96 |
+
Examples:
|
97 |
+
>>> loss = ContentLoss()
|
98 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
99 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
100 |
+
>>> output = loss(x, y)
|
101 |
+
>>> output.backward()
|
102 |
+
|
103 |
+
References:
|
104 |
+
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
|
105 |
+
A Neural Algorithm of Artistic Style
|
106 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
107 |
+
https://arxiv.org/abs/1508.06576
|
108 |
+
|
109 |
+
Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
|
110 |
+
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
111 |
+
IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
112 |
+
https://arxiv.org/abs/1801.03924
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",),
|
116 |
+
weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False,
|
117 |
+
distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
|
118 |
+
std: List[float] = IMAGENET_STD, normalize_features: bool = False,
|
119 |
+
allow_layers_weights_mismatch: bool = False) -> None:
|
120 |
+
|
121 |
+
assert allow_layers_weights_mismatch or len(layers) == len(weights), \
|
122 |
+
f'Lengths of provided layers and weights mismatch ({len(weights)} weights and {len(layers)} layers), ' \
|
123 |
+
f'which will cause incorrect results. Please provide weight for each layer.'
|
124 |
+
|
125 |
+
super().__init__()
|
126 |
+
|
127 |
+
if callable(feature_extractor):
|
128 |
+
self.model = feature_extractor
|
129 |
+
self.layers = layers
|
130 |
+
else:
|
131 |
+
if feature_extractor == "vgg16":
|
132 |
+
# self.model = vgg16(pretrained=True, progress=False).features
|
133 |
+
self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
|
134 |
+
self.layers = [VGG16_LAYERS[l] for l in layers]
|
135 |
+
elif feature_extractor == "vgg19":
|
136 |
+
# self.model = vgg19(pretrained=True, progress=False).features
|
137 |
+
self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features
|
138 |
+
self.layers = [VGG19_LAYERS[l] for l in layers]
|
139 |
+
else:
|
140 |
+
raise ValueError("Unknown feature extractor")
|
141 |
+
|
142 |
+
if replace_pooling:
|
143 |
+
self.model = self.replace_pooling(self.model)
|
144 |
+
|
145 |
+
# Disable gradients
|
146 |
+
for param in self.model.parameters():
|
147 |
+
param.requires_grad_(False)
|
148 |
+
|
149 |
+
self.distance = {
|
150 |
+
"mse": nn.MSELoss,
|
151 |
+
"mae": nn.L1Loss,
|
152 |
+
}[distance](reduction='none')
|
153 |
+
|
154 |
+
self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights]
|
155 |
+
|
156 |
+
mean = torch.tensor(mean)
|
157 |
+
std = torch.tensor(std)
|
158 |
+
self.mean = mean.view(1, -1, 1, 1)
|
159 |
+
self.std = std.view(1, -1, 1, 1)
|
160 |
+
|
161 |
+
self.normalize_features = normalize_features
|
162 |
+
self.reduction = reduction
|
163 |
+
|
164 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
165 |
+
r"""Computation of Content loss between feature representations of prediction :math:`x` and
|
166 |
+
target :math:`y` tensors.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
x: An input tensor. Shape :math:`(N, C, H, W)`.
|
170 |
+
y: A target tensor. Shape :math:`(N, C, H, W)`.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
Content loss between feature representations
|
174 |
+
"""
|
175 |
+
_validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))
|
176 |
+
|
177 |
+
self.model.to(x)
|
178 |
+
x_features = self.get_features(x)
|
179 |
+
y_features = self.get_features(y)
|
180 |
+
|
181 |
+
distances = self.compute_distance(x_features, y_features)
|
182 |
+
|
183 |
+
# Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
|
184 |
+
loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)
|
185 |
+
|
186 |
+
return _reduce(loss, self.reduction)
|
187 |
+
|
188 |
+
def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
|
189 |
+
r"""Take L2 or L1 distance between feature maps depending on ``distance``.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
x_features: Features of the input tensor.
|
193 |
+
y_features: Features of the target tensor.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Distance between feature maps
|
197 |
+
"""
|
198 |
+
return [self.distance(x, y) for x, y in zip(x_features, y_features)]
|
199 |
+
|
200 |
+
def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
|
201 |
+
r"""
|
202 |
+
Args:
|
203 |
+
x: Tensor. Shape :math:`(N, C, H, W)`.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
List of features extracted from intermediate layers
|
207 |
+
"""
|
208 |
+
# Normalize input
|
209 |
+
x = (x - self.mean.to(x)) / self.std.to(x)
|
210 |
+
|
211 |
+
features = []
|
212 |
+
for name, module in self.model._modules.items():
|
213 |
+
x = module(x)
|
214 |
+
if name in self.layers:
|
215 |
+
features.append(self.normalize(x) if self.normalize_features else x)
|
216 |
+
|
217 |
+
return features
|
218 |
+
|
219 |
+
@staticmethod
|
220 |
+
def normalize(x: torch.Tensor) -> torch.Tensor:
|
221 |
+
r"""Normalize feature maps in channel direction to unit length.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
x: Tensor. Shape :math:`(N, C, H, W)`.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
Normalized input
|
228 |
+
"""
|
229 |
+
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
|
230 |
+
return x / (norm_factor + EPS)
|
231 |
+
|
232 |
+
def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
|
233 |
+
r"""Turn All MaxPool layers into AveragePool
|
234 |
+
|
235 |
+
Args:
|
236 |
+
module: Module to change MaxPool int AveragePool
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
Module with AveragePool instead MaxPool
|
240 |
+
|
241 |
+
"""
|
242 |
+
module_output = module
|
243 |
+
if isinstance(module, torch.nn.MaxPool2d):
|
244 |
+
module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
245 |
+
|
246 |
+
for name, child in module.named_children():
|
247 |
+
module_output.add_module(name, self.replace_pooling(child))
|
248 |
+
return module_output
|
249 |
+
|
250 |
+
|
251 |
+
class StyleLoss(ContentLoss):
|
252 |
+
r"""Creates Style loss that can be used for image style transfer or as a measure in
|
253 |
+
image to image tasks. Computes distance between Gram matrices of feature maps.
|
254 |
+
Uses pretrained VGG models from torchvision.
|
255 |
+
|
256 |
+
By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
|
257 |
+
If no normalisation is required, change `mean` and `std` values accordingly.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
|
261 |
+
layers: List of strings with layer names. Default: ``'relu3_3'``
|
262 |
+
weights: List of float weight to balance different layers
|
263 |
+
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
|
264 |
+
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
|
265 |
+
reduction: Specifies the reduction type:
|
266 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
267 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
268 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
269 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
270 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
271 |
+
normalize_features: If true, unit-normalize each feature in channel dimension before scaling
|
272 |
+
and computing distance. See references for details.
|
273 |
+
|
274 |
+
Examples:
|
275 |
+
>>> loss = StyleLoss()
|
276 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
277 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
278 |
+
>>> output = loss(x, y)
|
279 |
+
>>> output.backward()
|
280 |
+
|
281 |
+
References:
|
282 |
+
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
|
283 |
+
A Neural Algorithm of Artistic Style
|
284 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
285 |
+
https://arxiv.org/abs/1508.06576
|
286 |
+
|
287 |
+
Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
|
288 |
+
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
289 |
+
IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
290 |
+
https://arxiv.org/abs/1801.03924
|
291 |
+
"""
|
292 |
+
|
293 |
+
def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor):
|
294 |
+
r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
x_features: Features of the input tensor.
|
298 |
+
y_features: Features of the target tensor.
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
Distance between Gram matrices
|
302 |
+
"""
|
303 |
+
x_gram = [self.gram_matrix(x) for x in x_features]
|
304 |
+
y_gram = [self.gram_matrix(x) for x in y_features]
|
305 |
+
return [self.distance(x, y) for x, y in zip(x_gram, y_gram)]
|
306 |
+
|
307 |
+
@staticmethod
|
308 |
+
def gram_matrix(x: torch.Tensor) -> torch.Tensor:
|
309 |
+
r"""Compute Gram matrix for batch of features.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
x: Tensor. Shape :math:`(N, C, H, W)`.
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
Gram matrix for given input
|
316 |
+
"""
|
317 |
+
B, C, H, W = x.size()
|
318 |
+
gram = []
|
319 |
+
for i in range(B):
|
320 |
+
features = x[i].view(C, H * W)
|
321 |
+
|
322 |
+
# Add fake channel dimension
|
323 |
+
gram.append(torch.mm(features, features.t()).unsqueeze(0))
|
324 |
+
|
325 |
+
return torch.stack(gram)
|
326 |
+
|
327 |
+
|
328 |
+
class LPIPS(ContentLoss):
|
329 |
+
r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.
|
330 |
+
|
331 |
+
By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
|
332 |
+
If no normalisation is required, change `mean` and `std` values accordingly.
|
333 |
+
|
334 |
+
Args:
|
335 |
+
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
|
336 |
+
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
|
337 |
+
reduction: Specifies the reduction type:
|
338 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
339 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
340 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
341 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
342 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
343 |
+
|
344 |
+
Examples:
|
345 |
+
>>> loss = LPIPS()
|
346 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
347 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
348 |
+
>>> output = loss(x, y)
|
349 |
+
>>> output.backward()
|
350 |
+
|
351 |
+
References:
|
352 |
+
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
|
353 |
+
A Neural Algorithm of Artistic Style
|
354 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
355 |
+
https://arxiv.org/abs/1508.06576
|
356 |
+
|
357 |
+
Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
|
358 |
+
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
359 |
+
IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
360 |
+
https://arxiv.org/abs/1801.03924
|
361 |
+
https://github.com/richzhang/PerceptualSimilarity
|
362 |
+
"""
|
363 |
+
_weights_url = "https://github.com/photosynthesis-team/" + \
|
364 |
+
"photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"
|
365 |
+
|
366 |
+
def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
|
367 |
+
mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD, ) -> None:
|
368 |
+
lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
|
369 |
+
lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
|
370 |
+
super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights,
|
371 |
+
replace_pooling=replace_pooling, distance=distance,
|
372 |
+
reduction=reduction, mean=mean, std=std,
|
373 |
+
normalize_features=True)
|
374 |
+
|
375 |
+
|
376 |
+
class DISTS(ContentLoss):
|
377 |
+
r"""Deep Image Structure and Texture Similarity metric.
|
378 |
+
|
379 |
+
By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
|
380 |
+
If no normalisation is required, change `mean` and `std` values accordingly.
|
381 |
+
|
382 |
+
Args:
|
383 |
+
reduction: Specifies the reduction type:
|
384 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
385 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
386 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
387 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
388 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
389 |
+
|
390 |
+
Examples:
|
391 |
+
>>> loss = DISTS()
|
392 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
393 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
394 |
+
>>> output = loss(x, y)
|
395 |
+
>>> output.backward()
|
396 |
+
|
397 |
+
References:
|
398 |
+
Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
|
399 |
+
Image Quality Assessment: Unifying Structure and Texture Similarity.
|
400 |
+
https://arxiv.org/abs/2004.07728
|
401 |
+
https://github.com/dingkeyan93/DISTS
|
402 |
+
"""
|
403 |
+
_weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"
|
404 |
+
|
405 |
+
def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
|
406 |
+
std: List[float] = IMAGENET_STD) -> None:
|
407 |
+
dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
|
408 |
+
channels = [3, 64, 128, 256, 512, 512]
|
409 |
+
|
410 |
+
weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
|
411 |
+
dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
|
412 |
+
dists_weights.extend(torch.split(weights['beta'], channels, dim=1))
|
413 |
+
|
414 |
+
super().__init__("vgg16", layers=dists_layers, weights=dists_weights,
|
415 |
+
replace_pooling=True, reduction=reduction, mean=mean, std=std,
|
416 |
+
normalize_features=False, allow_layers_weights_mismatch=True)
|
417 |
+
|
418 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
419 |
+
r"""
|
420 |
+
|
421 |
+
Args:
|
422 |
+
x: An input tensor. Shape :math:`(N, C, H, W)`.
|
423 |
+
y: A target tensor. Shape :math:`(N, C, H, W)`.
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS`` in range [0, 1].
|
427 |
+
"""
|
428 |
+
_, _, H, W = x.shape
|
429 |
+
|
430 |
+
if min(H, W) > 256:
|
431 |
+
x = torch.nn.functional.interpolate(
|
432 |
+
x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
|
433 |
+
y = torch.nn.functional.interpolate(
|
434 |
+
y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
|
435 |
+
|
436 |
+
loss = super().forward(x, y)
|
437 |
+
return 1 - loss
|
438 |
+
|
439 |
+
def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor) -> List[torch.Tensor]:
|
440 |
+
r"""Compute structure similarity between feature maps
|
441 |
+
|
442 |
+
Args:
|
443 |
+
x_features: Features of the input tensor.
|
444 |
+
y_features: Features of the target tensor.
|
445 |
+
|
446 |
+
Returns:
|
447 |
+
Structural similarity distance between feature maps
|
448 |
+
"""
|
449 |
+
structure_distance, texture_distance = [], []
|
450 |
+
# Small constant for numerical stability
|
451 |
+
EPS = 1e-6
|
452 |
+
|
453 |
+
for x, y in zip(x_features, y_features):
|
454 |
+
x_mean = x.mean([2, 3], keepdim=True)
|
455 |
+
y_mean = y.mean([2, 3], keepdim=True)
|
456 |
+
structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS))
|
457 |
+
|
458 |
+
x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
|
459 |
+
y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
|
460 |
+
xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
|
461 |
+
texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))
|
462 |
+
|
463 |
+
return structure_distance + texture_distance
|
464 |
+
|
465 |
+
def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
|
466 |
+
r"""
|
467 |
+
|
468 |
+
Args:
|
469 |
+
x: Input tensor
|
470 |
+
|
471 |
+
Returns:
|
472 |
+
List of features extracted from input tensor
|
473 |
+
"""
|
474 |
+
features = super().get_features(x)
|
475 |
+
|
476 |
+
# Add input tensor as an additional feature
|
477 |
+
features.insert(0, x)
|
478 |
+
return features
|
479 |
+
|
480 |
+
def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
|
481 |
+
r"""Turn All MaxPool layers into L2Pool
|
482 |
+
|
483 |
+
Args:
|
484 |
+
module: Module to change MaxPool into L2Pool
|
485 |
+
|
486 |
+
Returns:
|
487 |
+
Module with L2Pool instead of MaxPool
|
488 |
+
"""
|
489 |
+
module_output = module
|
490 |
+
if isinstance(module, torch.nn.MaxPool2d):
|
491 |
+
module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)
|
492 |
+
|
493 |
+
for name, child in module.named_children():
|
494 |
+
module_output.add_module(name, self.replace_pooling(child))
|
495 |
+
|
496 |
+
return module_output
|
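For orientation, here is a minimal usage sketch of the perceptual losses added above. It is not part of the diff: the import path `libs.metric.piq.perceptual` is assumed from the repository layout, and the LPIPS/DISTS pretrained weights are fetched from the URLs hard-coded in the classes.

```python
# Hedged sketch: exercises ContentLoss / StyleLoss / LPIPS / DISTS on random images in [0, 1].
import torch
from libs.metric.piq.perceptual import ContentLoss, StyleLoss, LPIPS, DISTS

x = torch.rand(4, 3, 256, 256, requires_grad=True)  # prediction
y = torch.rand(4, 3, 256, 256)                      # target

content = ContentLoss()(x, y)   # VGG feature distance with the default layers
style = StyleLoss()(x, y)       # Gram-matrix distance on the same features
lpips = LPIPS()(x, y)           # downloads LPIPS VGG16 weights on first use
dists = DISTS()(x, y)           # returns 1 - DISTS similarity

total = content + style + lpips + dists
total.backward()                # all four losses are differentiable w.r.t. x
```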
libs/metric/piq/utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .common import _validate_input, _reduce, _parse_version

__all__ = [
    "_validate_input",
    "_reduce",
    "_parse_version",
]
libs/metric/piq/utils/common.py
ADDED
@@ -0,0 +1,158 @@
import torch
import re
import warnings

from typing import Tuple, List, Optional, Union, Dict, Any

SEMVER_VERSION_PATTERN = re.compile(
    r"""
    ^
    (?P<major>0|[1-9]\d*)
    \.
    (?P<minor>0|[1-9]\d*)
    \.
    (?P<patch>0|[1-9]\d*)
    (?:-(?P<prerelease>
        (?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)
        (?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*
    ))?
    (?:\+(?P<build>
        [0-9a-zA-Z-]+
        (?:\.[0-9a-zA-Z-]+)*
    ))?
    $
    """,
    re.VERBOSE,
)


PEP_440_VERSION_PATTERN = r"""
    v?
    (?:
        (?:(?P<epoch>[0-9]+)!)?                           # epoch
        (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
        (?P<pre>                                          # pre-release
            [-_\.]?
            (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
            [-_\.]?
            (?P<pre_n>[0-9]+)?
        )?
        (?P<post>                                         # post release
            (?:-(?P<post_n1>[0-9]+))
            |
            (?:
                [-_\.]?
                (?P<post_l>post|rev|r)
                [-_\.]?
                (?P<post_n2>[0-9]+)?
            )
        )?
        (?P<dev>                                          # dev release
            [-_\.]?
            (?P<dev_l>dev)
            [-_\.]?
            (?P<dev_n>[0-9]+)?
        )?
    )
    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
"""


def _validate_input(
        tensors: List[torch.Tensor],
        dim_range: Tuple[int, int] = (0, -1),
        data_range: Tuple[float, float] = (0., -1.),
        # size_dim_range: Tuple[float, float] = (0., -1.),
        size_range: Optional[Tuple[int, int]] = None,
) -> None:
    r"""Check that input(-s) satisfies the requirements
    Args:
        tensors: Tensors to check
        dim_range: Allowed number of dimensions. (min, max)
        data_range: Allowed range of values in tensors. (min, max)
        size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
    """

    if not __debug__:
        return

    x = tensors[0]

    for t in tensors:
        assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
        assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'

        if size_range is None:
            assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
        else:
            assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
                f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'

        if dim_range[0] == dim_range[1]:
            assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
        elif dim_range[0] < dim_range[1]:
            assert dim_range[0] <= t.dim() <= dim_range[1], \
                f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'

        if data_range[0] < data_range[1]:
            assert data_range[0] <= t.min(), \
                f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
            assert t.max() <= data_range[1], \
                f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'


def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    r"""Reduce input in batch dimension if needed.

    Args:
        x: Tensor with shape (N, *).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
    """
    if reduction == 'none':
        return x
    elif reduction == 'mean':
        return x.mean(dim=0)
    elif reduction == 'sum':
        return x.sum(dim=0)
    else:
        raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")


def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
    """Parses valid Python versions according to Semver and PEP 440 specifications.
    For more on Semver check: https://semver.org/
    For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.

    Implementation is inspired by:
    - https://github.com/python-semver
    - https://github.com/pypa/packaging

    Args:
        version: unparsed information about the library of interest.

    Returns:
        parsed information about the library of interest.
    """
    if isinstance(version, bytes):
        version = version.decode("UTF-8")
    elif not isinstance(version, str) and not isinstance(version, bytes):
        raise TypeError(f"not expecting type {type(version)}")

    # Semver processing
    match = SEMVER_VERSION_PATTERN.match(version)
    if match:
        matched_version_parts: Dict[str, Any] = match.groupdict()
        release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
        return release

    # PEP 440 processing
    regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
    match = regex.search(version)

    if match is None:
        warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
        return tuple()

    release = tuple(int(i) for i in match.group("release").split("."))
    return release
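A short illustration of how these helpers behave (illustrative only, not part of the diff; the import goes through the package `__init__` shown earlier):

```python
import torch
from libs.metric.piq.utils import _validate_input, _reduce, _parse_version

x = torch.rand(2, 3, 64, 64)
y = torch.rand(2, 3, 64, 64)
_validate_input([x, y], dim_range=(4, 4), data_range=(0., 1.))  # passes silently

scores = torch.tensor([0.3, 0.7])
print(_reduce(scores, reduction='mean'))  # tensor(0.5000)

print(_parse_version('1.13.1'))           # (1, 13, 1) via the SemVer branch
print(_parse_version('2.0.0+cu118'))      # (2, 0, 0); build metadata is ignored
print(_parse_version('not-a-version'))    # () plus a warning
```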
libs/metric/pytorch_fid/__init__.py
ADDED
@@ -0,0 +1,54 @@
__version__ = '0.3.0'

import torch
from einops import rearrange, repeat

from .inception import InceptionV3
from .fid_score import calculate_frechet_distance


class PytorchFIDFactory(torch.nn.Module):
    """Convenience wrapper that computes FID between two batches of images.

    Args:
        channels: Number of image channels; single-channel inputs are repeated to 3 channels.
        inception_block_idx: Dimensionality of the Inception features to use; must be a key of
            ``InceptionV3.BLOCK_INDEX_BY_DIM`` (e.g. 2048 for the final average pooling block).

    Examples:
        >>> fid_factory = PytorchFIDFactory()
        >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
        >>> print(fid_score)
    """

    def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
        super().__init__()
        self.channels = channels

        # load models
        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
        self.inception_v3 = InceptionV3([block_idx])

    @torch.no_grad()
    def calculate_activation_statistics(self, samples):
        features = self.inception_v3(samples)[0]
        features = rearrange(features, '... 1 1 -> ...')

        mu = torch.mean(features, dim=0).cpu()
        sigma = torch.cov(features).cpu()
        return mu, sigma

    def score(self, real_samples, fake_samples):
        if self.channels == 1:
            real_samples, fake_samples = map(
                lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
            )

        min_batch = min(real_samples.shape[0], fake_samples.shape[0])
        real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))

        m1, s1 = self.calculate_activation_statistics(real_samples)
        m2, s2 = self.calculate_activation_statistics(fake_samples)

        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
        return fid_value
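The factory above can be exercised end to end with random tensors; the resulting number is meaningless for random data, but the snippet shows the expected call signature (a sketch, not part of the diff):

```python
import torch
from libs.metric.pytorch_fid import PytorchFIDFactory

fid_factory = PytorchFIDFactory(channels=3, inception_block_idx=2048)

real = torch.rand(16, 3, 299, 299)  # batches of images in [0, 1]
fake = torch.rand(16, 3, 299, 299)

# InceptionV3 FID weights are downloaded on first use.
fid_score = fid_factory.score(real_samples=real, fake_samples=fake)
print(fid_score)
```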
libs/metric/pytorch_fid/fid_score.py
ADDED
@@ -0,0 +1,322 @@
"""Calculates the Frechet Inception Distance (FID) to evaluate GANs

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.

See --help to see further details.

Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow

Copyright 2018 Institute of Bioinformatics, JKU Linz

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import numpy as np
import torch
import torchvision.transforms as TF
from PIL import Image
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a mock version of it
    def tqdm(x):
        return x

from .inception import InceptionV3

parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=50,
                    help='Batch size to use')
parser.add_argument('--num-workers', type=int,
                    help=('Number of processes to use for data loading. '
                          'Defaults to `min(8, num_cpus)`'))
parser.add_argument('--device', type=str, default=None,
                    help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
parser.add_argument('--save-stats', action='store_true',
                    help=('Generate an npz archive from a directory of samples. '
                          'The first path is used as input and the second as output.'))
parser.add_argument('path', type=str, nargs=2,
                    help=('Paths to the generated images or '
                          'to .npz statistic files'))

IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
                    'tif', 'tiff', 'webp'}


class ImagePathDataset(torch.utils.data.Dataset):
    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        img = Image.open(path).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img


def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
                    num_workers=1):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=num_workers)

    pred_arr = np.empty((len(files), dims))

    start_idx = 0

    for batch in tqdm(dataloader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx:start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.

    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * tr_covmean)


def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
                                    device='cpu', num_workers=1):
    """Calculation of the statistics used by the FID.
    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : The images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size
                     depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(files, model, batch_size, dims, device, num_workers)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def compute_statistics_of_path(path, model, batch_size, dims, device,
                               num_workers=1):
    if path.endswith('.npz'):
        with np.load(path) as f:
            m, s = f['mu'][:], f['sigma'][:]
    else:
        path = pathlib.Path(path)
        files = sorted([file for ext in IMAGE_EXTENSIONS
                        for file in path.glob('*.{}'.format(ext))])
        m, s = calculate_activation_statistics(files, model, batch_size,
                                               dims, device, num_workers)

    return m, s


def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
    """Calculates the FID of two paths"""
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
                                        dims, device, num_workers)
    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
                                        dims, device, num_workers)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value


def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
    """Saves FID statistics of one path to an npz file"""
    if not os.path.exists(paths[0]):
        raise RuntimeError('Invalid path: %s' % paths[0])

    if os.path.exists(paths[1]):
        raise RuntimeError('Existing output file: %s' % paths[1])

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    print(f"Saving statistics for {paths[0]}")

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
                                        dims, device, num_workers)

    np.savez_compressed(paths[1], mu=m1, sigma=s1)


def main():
    args = parser.parse_args()

    if args.device is None:
        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
    else:
        device = torch.device(args.device)

    if args.num_workers is None:
        try:
            num_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity is not available under Windows, use
            # os.cpu_count instead (which may not return the *available* number
            # of CPUs).
            num_cpus = os.cpu_count()

        num_workers = min(num_cpus, 8) if num_cpus is not None else 0
    else:
        num_workers = args.num_workers

    if args.save_stats:
        save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
        return

    fid_value = calculate_fid_given_paths(args.path,
                                          args.batch_size,
                                          device,
                                          args.dims,
                                          num_workers)
    print('FID: ', fid_value)


if __name__ == '__main__':
    main()
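Besides the command-line entry point (roughly `python -m libs.metric.pytorch_fid.fid_score DIR1 DIR2`, assuming the package is importable), the same computation is available programmatically. The directory names in this sketch are placeholders:

```python
import torch
from libs.metric.pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Each path is either a directory of images or a precomputed .npz statistics file.
fid = calculate_fid_given_paths(['path/to/real_images', 'path/to/generated_images'],
                                batch_size=50, device=device, dims=2048, num_workers=2)
print('FID:', fid)
```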
libs/metric/pytorch_fid/inception.py
ADDED
@@ -0,0 +1,341 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3   # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX,),
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(weights='DEFAULT')

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`"""
    try:
        version = tuple(map(int, torchvision.__version__.split('.')[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    # Skips default weight initialization if supported by torchvision
    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
    if version >= (0, 6):
        kwargs['init_weights'] = False

    # Backwards compatibility: `weights` argument was handled by `pretrained`
    # argument prior to version 0.13.
    if version < (0, 13) and 'weights' in kwargs:
        if kwargs['weights'] == 'DEFAULT':
            kwargs['pretrained'] = True
        elif kwargs['weights'] is None:
            kwargs['pretrained'] = False
        else:
            raise ValueError(
                'weights=={} not supported in torchvision {}'.format(
                    kwargs['weights'], torchvision.__version__
                )
            )
        del kwargs['weights']

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = _inception_v3(num_classes=1008,
                              aux_logits=False,
                              weights=None)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""
    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
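A minimal sketch of using this wrapper as a feature extractor (the ported FID weights are downloaded on first instantiation; not part of the diff itself):

```python
import torch
from libs.metric.pytorch_fid.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]  # final average pooling block
model = InceptionV3([block_idx]).eval()

images = torch.rand(8, 3, 299, 299)  # values expected in [0, 1]
with torch.no_grad():
    features = model(images)[0]      # shape (8, 2048, 1, 1)

print(features.squeeze(-1).squeeze(-1).shape)  # torch.Size([8, 2048])
```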
libs/modules/__init__.py
ADDED
@@ -0,0 +1 @@