MarkMoHR committed
Commit 7aefe45 · 1 Parent(s): 762579f

added code

This view is limited to 50 files because it contains too many changes; see the raw diff for the remaining changes.
Files changed (50)
  1. .gitignore +153 -0
  2. config/diffsketchedit.yaml +75 -0
  3. docs/figures/refine/ldm_generated_image0.png +3 -0
  4. docs/figures/refine/ldm_generated_image1.png +3 -0
  5. docs/figures/refine/ldm_generated_image2.png +3 -0
  6. docs/figures/refine/visual_best-rendered0.png +3 -0
  7. docs/figures/refine/visual_best-rendered1.png +3 -0
  8. docs/figures/refine/visual_best-rendered2.png +3 -0
  9. docs/figures/replace/ldm_generated_image0.png +3 -0
  10. docs/figures/replace/ldm_generated_image1.png +3 -0
  11. docs/figures/replace/ldm_generated_image2.png +3 -0
  12. docs/figures/replace/ldm_generated_image3.png +3 -0
  13. docs/figures/replace/visual_best-rendered0.png +3 -0
  14. docs/figures/replace/visual_best-rendered1.png +3 -0
  15. docs/figures/replace/visual_best-rendered2.png +3 -0
  16. docs/figures/replace/visual_best-rendered3.png +3 -0
  17. docs/figures/reweight/ldm_generated_image0.png +3 -0
  18. docs/figures/reweight/ldm_generated_image1.png +3 -0
  19. docs/figures/reweight/ldm_generated_image2.png +3 -0
  20. docs/figures/reweight/visual_best-rendered0.png +3 -0
  21. docs/figures/reweight/visual_best-rendered1.png +3 -0
  22. docs/figures/reweight/visual_best-rendered2.png +3 -0
  23. libs/__init__.py +9 -0
  24. libs/engine/__init__.py +7 -0
  25. libs/engine/config_processor.py +151 -0
  26. libs/engine/model_state.py +335 -0
  27. libs/metric/__init__.py +1 -0
  28. libs/metric/accuracy.py +25 -0
  29. libs/metric/clip_score/__init__.py +3 -0
  30. libs/metric/clip_score/openaiCLIP_loss.py +304 -0
  31. libs/metric/lpips_origin/__init__.py +3 -0
  32. libs/metric/lpips_origin/lpips.py +184 -0
  33. libs/metric/lpips_origin/pretrained_networks.py +196 -0
  34. libs/metric/lpips_origin/weights/v0.1/alex.pth +0 -0
  35. libs/metric/lpips_origin/weights/v0.1/squeeze.pth +0 -0
  36. libs/metric/lpips_origin/weights/v0.1/vgg.pth +0 -0
  37. libs/metric/piq/__init__.py +2 -0
  38. libs/metric/piq/functional/__init__.py +15 -0
  39. libs/metric/piq/functional/base.py +111 -0
  40. libs/metric/piq/functional/colour_conversion.py +136 -0
  41. libs/metric/piq/functional/filters.py +111 -0
  42. libs/metric/piq/functional/layers.py +33 -0
  43. libs/metric/piq/functional/resize.py +426 -0
  44. libs/metric/piq/perceptual.py +496 -0
  45. libs/metric/piq/utils/__init__.py +7 -0
  46. libs/metric/piq/utils/common.py +158 -0
  47. libs/metric/pytorch_fid/__init__.py +54 -0
  48. libs/metric/pytorch_fid/fid_score.py +322 -0
  49. libs/metric/pytorch_fid/inception.py +341 -0
  50. libs/modules/__init__.py +1 -0
.gitignore ADDED
@@ -0,0 +1,153 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # .idea
132
+ .idea/
133
+ /idea/
134
+ *.ipr
135
+ *.iml
136
+ *.iws
137
+
138
+ # system
139
+ .DS_Store
140
+
141
+ # pytorch-lightning logs
142
+ lightning_logs/*
143
+
144
+ # Edit settings
145
+ .editorconfig
146
+
147
+ # local results
148
+ /workdir/
149
+ .workdir/
150
+
151
+ # dataset
152
+ /dataset/
153
+ !/dataset/placeholder.md
config/diffsketchedit.yaml ADDED
@@ -0,0 +1,75 @@
1
+ seed: 1
2
+ image_size: 224
3
+ mask_object: False # if the target image contains background, it's better to mask it out
4
+ fix_scale: False # if the target image is not square, it is recommended to fix the scale
5
+
6
+ # train
7
+ num_iter: 1000
8
+ batch_size: 1
9
+ num_stages: 1 # training stages: train x strokes, freeze them, then train another x strokes, and so on
10
+ lr_scheduler: False
11
+ lr_decay_rate: 0.1
12
+ decay_steps: [ 1000, 1500 ]
13
+ lr: 1
14
+ color_lr: 0.01
15
+ pruning_freq: 50
16
+ color_vars_threshold: 0.1
17
+ width_lr: 0.1
18
+ max_width: 50 # stroke width
19
+
20
+ # stroke attrs
21
+ num_paths: 96 # number of strokes
22
+ width: 1.0 # stroke width
23
+ control_points_per_seg: 4
24
+ num_segments: 1
25
+ optim_opacity: True # if True, the stroke opacity is optimized
26
+ optim_width: False # if True, the stroke width is optimized
27
+ optim_rgba: False # if True, the stroke RGBA is optimized
28
+ opacity_delta: 0 # stroke pruning
29
+
30
+ # init strokes
31
+ attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
32
+ xdog_intersec: True # initialize strokes along edges by combining the XDoG edge map with the attention map
33
+ softmax_temp: 0.5
34
+ cross_attn_res: 16
35
+ self_attn_res: 32
36
+ max_com: 20 # number of self-attn map components to select
37
+ mean_comp: False # if True, use the average of the self-attn maps
38
+ comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
39
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
40
+ log_cross_attn: False # if True, log the cross-attention maps at every step
41
+ u2net_path: "./checkpoint/u2net/u2net.pth"
42
+
43
+ # ldm
44
+ model_id: "sd14"
45
+ ldm_speed_up: False
46
+ enable_xformers: False
47
+ gradient_checkpoint: False
48
+ #token_ind: 1 # the index of CLIP prompt embedding, start from 1
49
+ use_ddim: True
50
+ num_inference_steps: 50
51
+ guidance_scale: 7.5 # sdxl default 5.0
52
+ # ASDS loss
53
+ sds:
54
+ crop_size: 512
55
+ augmentations: "affine"
56
+ guidance_scale: 100
57
+ grad_scale: 1e-5
58
+ t_range: [ 0.05, 0.95 ]
59
+ warmup: 0
60
+
61
+ clip:
62
+ model_name: "RN101" # RN101, ViT-L/14
63
+ feats_loss_type: "l2" # clip visual loss type, conv layers
64
+ feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
65
+ # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
66
+ fc_loss_weight: 0.1 # clip visual loss, fc layer weight
67
+ augmentations: "affine" # augmentation before clip visual computation
68
+ num_aug: 4 # num of augmentation before clip visual computation
69
+ vis_loss: 1 # 1 to enable the CLIP visual loss, 0 to disable it
70
+ text_visual_coeff: 0 # cosine similarity between text and img
71
+
72
+ perceptual:
73
+ name: "lpips" # dists
74
+ lpips_net: 'vgg'
75
+ coeff: 0.2
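
The block above is a plain OmegaConf YAML, so its nested keys can be loaded and overridden directly. A minimal sketch (not part of the commit), assuming the file is saved as config/diffsketchedit.yaml:

from omegaconf import OmegaConf

cfg = OmegaConf.load("config/diffsketchedit.yaml")
print(cfg.num_paths)            # 96
print(cfg.sds.guidance_scale)   # 100
print(cfg.clip.model_name)      # "RN101"

# dotlist-style overrides, in the spirit of the repo's '-update' flag
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["num_paths=128", "sds.grad_scale=1e-4"]))
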
docs/figures/refine/ldm_generated_image0.png ADDED

Git LFS Details

  • SHA256: 75a0f634d343e08b4c9fe1486b1ff8e2ff330322ff2c8aa6d67f37704425e844
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
docs/figures/refine/ldm_generated_image1.png ADDED

Git LFS Details

  • SHA256: 94d8ccb932a0c088e8014dc89652b65b5325ebfb10b217c6be7c305cd50527a3
  • Pointer size: 131 Bytes
  • Size of remote file: 332 kB
docs/figures/refine/ldm_generated_image2.png ADDED

Git LFS Details

  • SHA256: b7770774aea4ecd71b3daebd81d551ef5ae0bfc457e99f9ace53dad266115dc3
  • Pointer size: 131 Bytes
  • Size of remote file: 334 kB
docs/figures/refine/visual_best-rendered0.png ADDED

Git LFS Details

  • SHA256: 6c61fe29bd5d1bb07b7f519587324f8dd50f83dbac2eb6256cd4716a6d36b63c
  • Pointer size: 130 Bytes
  • Size of remote file: 28.2 kB
docs/figures/refine/visual_best-rendered1.png ADDED

Git LFS Details

  • SHA256: c5efbe511ae39e108ae1f1d98431d6eb9c5650ae17706d1270bb839797ac0fb9
  • Pointer size: 130 Bytes
  • Size of remote file: 29.2 kB
docs/figures/refine/visual_best-rendered2.png ADDED

Git LFS Details

  • SHA256: 80871b6669d46533a38c17164ef5a1992ce4cfa3fb1ddb6422b1a451c2b2e0d7
  • Pointer size: 130 Bytes
  • Size of remote file: 29.5 kB
docs/figures/replace/ldm_generated_image0.png ADDED

Git LFS Details

  • SHA256: 12cf20fbcb849f388e004a95d873a141f18cb7645e006438d8e789cd5dd83a4c
  • Pointer size: 131 Bytes
  • Size of remote file: 462 kB
docs/figures/replace/ldm_generated_image1.png ADDED

Git LFS Details

  • SHA256: a35c9877e8790665b7ab05f49621241ebdf8b1f3ab0cb3079b5f932cc5337a0d
  • Pointer size: 131 Bytes
  • Size of remote file: 461 kB
docs/figures/replace/ldm_generated_image2.png ADDED

Git LFS Details

  • SHA256: 464ba1a05a0835904bec99233100e88efea7e9cc751c95817e27cec7e9bd4991
  • Pointer size: 131 Bytes
  • Size of remote file: 460 kB
docs/figures/replace/ldm_generated_image3.png ADDED

Git LFS Details

  • SHA256: 2f56a7cd6a71efe5231de48eb74699a91890778179111b0acff7cae0fda073f3
  • Pointer size: 131 Bytes
  • Size of remote file: 485 kB
docs/figures/replace/visual_best-rendered0.png ADDED

Git LFS Details

  • SHA256: bac660e47ecfc7f8b3d30181df9729afdaa55af72ae662cd67fab736c5869f09
  • Pointer size: 130 Bytes
  • Size of remote file: 44 kB
docs/figures/replace/visual_best-rendered1.png ADDED

Git LFS Details

  • SHA256: c34d8ab9e3f3c9563ef019b3be48262940c940a75b138635b43f0c54e452de17
  • Pointer size: 130 Bytes
  • Size of remote file: 50.3 kB
docs/figures/replace/visual_best-rendered2.png ADDED

Git LFS Details

  • SHA256: 61c57adbce024e3ae95c214e449aba74144bb8c374075c2ff4a37d0083c2f70e
  • Pointer size: 130 Bytes
  • Size of remote file: 52.7 kB
docs/figures/replace/visual_best-rendered3.png ADDED

Git LFS Details

  • SHA256: b030cdd81791c3a3bf350d60523a2ee47af4ab1dc0028e22ddf4ec85d3df1a12
  • Pointer size: 130 Bytes
  • Size of remote file: 59.1 kB
docs/figures/reweight/ldm_generated_image0.png ADDED

Git LFS Details

  • SHA256: af8c459a5f27d40ce4db6e3a4bcea442eac50d7d768010cee46db38e384e5af5
  • Pointer size: 131 Bytes
  • Size of remote file: 467 kB
docs/figures/reweight/ldm_generated_image1.png ADDED

Git LFS Details

  • SHA256: 3bf5adb91c7881eb2ec46dc9d54f980e6d5b1a0465b84fe0715313fcb9cb2312
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
docs/figures/reweight/ldm_generated_image2.png ADDED

Git LFS Details

  • SHA256: 3516da234863316239902f9058bef26c2c81bc48cbd49ac4060519c9d63797f1
  • Pointer size: 131 Bytes
  • Size of remote file: 509 kB
docs/figures/reweight/visual_best-rendered0.png ADDED

Git LFS Details

  • SHA256: 47648174430fd29712507baaf9bfe9afa9c810258531c7d724c9bfbee82d642b
  • Pointer size: 130 Bytes
  • Size of remote file: 32.1 kB
docs/figures/reweight/visual_best-rendered1.png ADDED

Git LFS Details

  • SHA256: 83ec8e4b3a5df5229a2503c3b728f8bf1cf7dc7aea09b77570f908ace62951cc
  • Pointer size: 130 Bytes
  • Size of remote file: 30.7 kB
docs/figures/reweight/visual_best-rendered2.png ADDED

Git LFS Details

  • SHA256: 3acd23ddb8c6e77f6133cd91ebd0e624fa72101f0c964e11796041ebe5f31e38
  • Pointer size: 130 Bytes
  • Size of remote file: 35.4 kB
libs/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .utils import lazy
2
+
3
+ __getattr__, __dir__, __all__ = lazy.attach(
4
+ __name__,
5
+ submodules={'engine', 'metric', 'modules', 'solver', 'utils'},
6
+ submod_attrs={}
7
+ )
8
+
9
+ __version__ = '0.0.1'
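
Assuming libs.utils.lazy follows the usual lazy-attach pattern (the utils module itself is not among the 50 files shown), submodules are only imported on first attribute access, e.g.:

import libs

print(libs.__version__)               # '0.0.1'
state_cls = libs.engine.ModelState    # libs.engine is imported here, not at `import libs`
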
libs/engine/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .model_state import ModelState
2
+ from .config_processor import merge_and_update_config
3
+
4
+ __all__ = [
5
+ 'ModelState',
6
+ 'merge_and_update_config'
7
+ ]
libs/engine/config_processor.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ from typing import Tuple
3
+ from functools import reduce
4
+
5
+ from argparse import Namespace
6
+ from omegaconf import DictConfig, OmegaConf
7
+
8
+
9
+ #################################################################################
10
+ # merge yaml and argparse #
11
+ #################################################################################
12
+
13
+ def register_resolver():
14
+ OmegaConf.register_new_resolver(
15
+ "add", lambda *numbers: sum(numbers)
16
+ )
17
+ OmegaConf.register_new_resolver(
18
+ "multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers)
19
+ )
20
+ OmegaConf.register_new_resolver(
21
+ "sub", lambda n1, n2: n1 - n2
22
+ )
23
+
24
+
25
+ def _merge_args_and_config(
26
+ cmd_args: Namespace,
27
+ yaml_config: DictConfig,
28
+ read_only: bool = False
29
+ ) -> Tuple[DictConfig, DictConfig, DictConfig]:
30
+ # convert cmd line args to OmegaConf
31
+ cmd_args_dict = vars(cmd_args)
32
+ cmd_args_list = []
33
+ for k, v in cmd_args_dict.items():
34
+ cmd_args_list.append(f"{k}={v}")
35
+ cmd_args_conf = OmegaConf.from_cli(cmd_args_list)
36
+
37
+ # The following overrides the previous configuration
38
+ # cmd_args_list > configs
39
+ args_ = OmegaConf.merge(yaml_config, cmd_args_conf)
40
+
41
+ if read_only:
42
+ OmegaConf.set_readonly(args_, True)
43
+
44
+ return args_, cmd_args_conf, yaml_config
45
+
46
+
47
+ def merge_configs(args, method_cfg_path):
48
+ """merge command line args (argparse) and config file (OmegaConf)"""
49
+ yaml_config_path = os.path.join("./", "config", method_cfg_path)
50
+ try:
51
+ yaml_config = OmegaConf.load(yaml_config_path)
52
+ except FileNotFoundError as e:
53
+ print(f"error: {e}")
54
+ print(f"input file path: `{method_cfg_path}`")
55
+ print(f"config path: `{yaml_config_path}` not found.")
56
+ raise FileNotFoundError(e)
57
+ return _merge_args_and_config(args, yaml_config, read_only=False)
58
+
59
+
60
+ def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
61
+ """update config file (OmegaConf) with dotlist"""
62
+ if update_nodes is None:
63
+ return source_args
64
+
65
+ update_args_list = str(update_nodes).split()
66
+ if len(update_args_list) < 1:
67
+ return source_args
68
+
69
+ # check update_args
70
+ for item in update_args_list:
71
+ item_key_ = str(item).split('=')[0] # get key
72
+ # item_val_ = str(item).split('=')[1] # get value
73
+
74
+ if strict:
75
+ # Tests if a key is existing
76
+ # assert OmegaConf.select(source_args, item_key_) is not None, f"{item_key_} is not existing."
77
+
78
+ # Tests if a value is missing
79
+ assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."
80
+
81
+ # if keys is None, then add key and set the value
82
+ if OmegaConf.select(source_args, item_key_) is None:
83
+ source_args.item_key_ = item_key_
84
+
85
+ # update original yaml params
86
+ update_nodes = OmegaConf.from_dotlist(update_args_list)
87
+ merged_args = OmegaConf.merge(source_args, update_nodes)
88
+
89
+ # remove update_args
90
+ if remove_update_nodes:
91
+ OmegaConf.update(merged_args, 'update', '')
92
+ return merged_args
93
+
94
+
95
+ def update_if_exist(source_args, update_nodes):
96
+ """update config file (OmegaConf) with dotlist"""
97
+ if update_nodes is None:
98
+ return source_args
99
+
100
+ upd_args_list = str(update_nodes).split()
101
+ if len(upd_args_list) < 1:
102
+ return source_args
103
+
104
+ update_args_list = []
105
+ for item in upd_args_list:
106
+ item_key_ = str(item).split('=')[0] # get key
107
+
108
+ # if a key is existing
109
+ # if OmegaConf.select(source_args, item_key_) is not None:
110
+ # update_args_list.append(item)
111
+
112
+ update_args_list.append(item)
113
+
114
+ # update source_args if key be selected
115
+ if len(update_args_list) < 1:
116
+ merged_args = source_args
117
+ else:
118
+ update_nodes = OmegaConf.from_dotlist(update_args_list)
119
+ merged_args = OmegaConf.merge(source_args, update_nodes)
120
+
121
+ return merged_args
122
+
123
+
124
+ def merge_and_update_config(args):
125
+ register_resolver()
126
+
127
+ # if yaml_config is existing, then merge command line args and yaml_config
128
+ # if os.path.isfile(args.config) and args.config is not None:
129
+ if args.config is not None and str(args.config).endswith('.yaml'):
130
+ merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
131
+ else:
132
+ merged_args, cmd_args, yaml_config = args, args, None
133
+
134
+ # update the yaml_config with the cmd '-update' flag
135
+ update_nodes = args.update
136
+ final_args = update_configs(merged_args, update_nodes)
137
+
138
+ # to simplify log output, we empty this
139
+ yaml_config_update = update_if_exist(yaml_config, update_nodes)
140
+ cmd_args_update = update_if_exist(cmd_args, update_nodes)
141
+ cmd_args_update.update = "" # clear update params
142
+
143
+ final_args.yaml_config = yaml_config_update
144
+ final_args.cmd_args = cmd_args_update
145
+
146
+ # update seed
147
+ if final_args.seed < 0:
148
+ import random
149
+ final_args.seed = random.randint(0, 65535)
150
+
151
+ return final_args
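
A hypothetical caller sketch (the entry script is not among the 50 files shown and the flag names are assumptions, but the attribute names match what merge_and_update_config reads: args.config, args.update and args.seed):

import argparse
from libs.engine import merge_and_update_config

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str, default="diffsketchedit.yaml")
parser.add_argument("-update", "--update", type=str, default=None,
                    help="dotlist overrides, e.g. 'seed=42 num_paths=128'")
parser.add_argument("--seed", type=int, default=-1)   # a negative seed triggers the random fallback
args = parser.parse_args()

cfg = merge_and_update_config(args)   # merged DictConfig, with cfg.cmd_args / cfg.yaml_config attached
print(cfg.seed, cfg.num_paths)
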
libs/engine/model_state.py ADDED
@@ -0,0 +1,335 @@
1
+ import os
2
+ from functools import partial
3
+ from typing import Union, List
4
+ from pathlib import Path
5
+ from datetime import datetime, timedelta
6
+
7
+ from omegaconf import DictConfig
8
+ from pprint import pprint
9
+ import torch
10
+ from accelerate.utils import LoggerType
11
+ from accelerate import (
12
+ Accelerator,
13
+ GradScalerKwargs,
14
+ DistributedDataParallelKwargs,
15
+ InitProcessGroupKwargs
16
+ )
17
+
18
+ from ..modules.ema import EMA
19
+ from ..utils.logging import get_logger
20
+
21
+
22
+ class ModelState:
23
+ """
24
+ Handling logger and `hugging face` accelerate training
25
+
26
+ features:
27
+ - Mixed Precision
28
+ - Gradient Scaler
29
+ - Gradient Accumulation
30
+ - Optimizer
31
+ - EMA
32
+ - Logger (default: python print)
33
+ - Monitor (default: wandb, tensorboard)
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ args,
39
+ log_path_suffix: str = None,
40
+ ignore_log=False, # whether to create log file or not
41
+ ) -> None:
42
+ self.args: DictConfig = args
43
+
44
+ """check valid"""
45
+ mixed_precision = self.args.get("mixed_precision")
46
+ # Note: OmegaConf/YAML parses the string 'no' as boolean False, so map it back to "no"
47
+ mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision
48
+ split_batches = self.args.get("split_batches", False)
49
+ gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1)
50
+ assert gradient_accumulate_step >= 1, f"expected gradient_accumulate_step >= 1, got {gradient_accumulate_step}"
51
+
52
+ """create working space"""
53
+ # rule: ['./config', 'method_name', 'exp_name.yaml']
54
+ # -> results_path: ./runs/{method_name}-{exp_name}, as a base folder
55
+ # config_prefix, config_name = str(self.args.get("config")).split('/')
56
+ # config_name_only = str(config_name).split(".")[0]
57
+
58
+ config_name_only = str(self.args.get("config")).split(".")[0]
59
+ results_folder = self.args.get("results_path", None)
60
+ if results_folder is None:
61
+ # self.results_path = Path("./workdir") / f"{config_prefix}-{config_name_only}"
62
+ self.results_path = Path("./workdir")
63
+ else:
64
+ # self.results_path = Path(results_folder) / f"{config_prefix}-{config_name_only}"
65
+ self.results_path = Path(os.path.join(results_folder, self.args.get("edit_type"), ))
66
+
67
+ # update results_path: ./runs/{method_name}-{exp_name}/{log_path_suffix}
68
+ # note: can be understood as "results dir / methods / ablation study / your result"
69
+ if log_path_suffix is not None:
70
+ self.results_path = self.results_path / log_path_suffix
71
+
72
+ kwargs_handlers = []
73
+ """mixed precision training"""
74
+ if args.mixed_precision == "no":
75
+ scaler_handler = GradScalerKwargs(
76
+ init_scale=args.init_scale,
77
+ growth_factor=args.growth_factor,
78
+ backoff_factor=args.backoff_factor,
79
+ growth_interval=args.growth_interval,
80
+ enabled=True
81
+ )
82
+ kwargs_handlers.append(scaler_handler)
83
+
84
+ """distributed training"""
85
+ ddp_handler = DistributedDataParallelKwargs(
86
+ dim=0,
87
+ broadcast_buffers=True,
88
+ static_graph=False,
89
+ bucket_cap_mb=25,
90
+ find_unused_parameters=False,
91
+ check_reduction=False,
92
+ gradient_as_bucket_view=False
93
+ )
94
+ kwargs_handlers.append(ddp_handler)
95
+
96
+ init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200))
97
+ kwargs_handlers.append(init_handler)
98
+
99
+ """init visualized tracker"""
100
+ log_with = []
101
+ self.args.visual = False
102
+ if args.use_wandb:
103
+ log_with.append(LoggerType.WANDB)
104
+ if args.tensorboard:
105
+ log_with.append(LoggerType.TENSORBOARD)
106
+
107
+ """hugging face Accelerator"""
108
+ self.accelerator = Accelerator(
109
+ device_placement=True,
110
+ split_batches=split_batches,
111
+ mixed_precision=mixed_precision,
112
+ gradient_accumulation_steps=args.gradient_accumulate_step,
113
+ cpu=True if args.use_cpu else False,
114
+ log_with=None if len(log_with) == 0 else log_with,
115
+ project_dir=self.results_path / "vis",
116
+ kwargs_handlers=kwargs_handlers,
117
+ )
118
+
119
+ """logs"""
120
+ if self.accelerator.is_local_main_process:
121
+ # for logging results in a folder periodically
122
+ self.results_path.mkdir(parents=True, exist_ok=True)
123
+ if not ignore_log:
124
+ now_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
125
+ # self.logger = get_logger(
126
+ # logs_dir=self.results_path.as_posix(),
127
+ # file_name=f"log.txt"
128
+ # )
129
+
130
+ print("==> command line args: ")
131
+ print(args.cmd_args)
132
+ print("==> yaml config args: ")
133
+ print(args.yaml_config)
134
+
135
+ print("\n***** Model State *****")
136
+ if self.accelerator.distributed_type != "NO":
137
+ print(f"-> Distributed Type: {self.accelerator.distributed_type}")
138
+ print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}")
139
+ print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp},"
140
+ f" Gradient Accumulate Step: {gradient_accumulate_step}")
141
+ print(f"-> Weight dtype: {self.weight_dtype}")
142
+
143
+ if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
144
+ print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")
145
+
146
+ if args.use_wandb:
147
+ print(f"-> Init trackers: 'wandb' ")
148
+ self.args.visual = True
149
+ self.__init_tracker(project_name="my_project", tags=None, entity="")
150
+
151
+ print(f"-> Working Space: '{self.results_path}'")
152
+
153
+ """EMA"""
154
+ self.use_ema = args.get('ema', False)
155
+ self.ema_wrapper = self.__build_ema_wrapper()
156
+
157
+ """glob step"""
158
+ self.step = 0
159
+
160
+ """log process"""
161
+ self.accelerator.wait_for_everyone()
162
+ print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')
163
+
164
+ self.print("-> state initialization complete \n")
165
+
166
+ def __init_tracker(self, project_name, tags, entity):
167
+ self.accelerator.init_trackers(
168
+ project_name=project_name,
169
+ config=dict(self.args),
170
+ init_kwargs={
171
+ "wandb": {
172
+ "notes": "accelerate trainer pipeline",
173
+ "tags": [
174
+ f"total batch_size: {self.actual_batch_size}"
175
+ ],
176
+ "entity": entity,
177
+ }}
178
+ )
179
+
180
+ def __build_ema_wrapper(self):
181
+ if self.use_ema:
182
+ self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, "
183
+ f"update_after_step: {self.args.ema_update_after_step}, "
184
+ f"update_every: {self.args.ema_update_every}")
185
+ ema_wrapper = partial(
186
+ EMA, beta=self.args.ema_decay,
187
+ update_after_step=self.args.ema_update_after_step,
188
+ update_every=self.args.ema_update_every
189
+ )
190
+ else:
191
+ ema_wrapper = None
192
+
193
+ return ema_wrapper
194
+
195
+ @property
196
+ def device(self):
197
+ return self.accelerator.device
198
+
199
+ @property
200
+ def weight_dtype(self):
201
+ weight_dtype = torch.float32
202
+ if self.accelerator.mixed_precision == "fp16":
203
+ weight_dtype = torch.float16
204
+ elif self.accelerator.mixed_precision == "bf16":
205
+ weight_dtype = torch.bfloat16
206
+ return weight_dtype
207
+
208
+ @property
209
+ def actual_batch_size(self):
210
+ if self.accelerator.split_batches is False:
211
+ actual_batch_size = self.args.batch_size * self.accelerator.num_processes * self.accelerator.gradient_accumulation_steps
212
+ else:
213
+ assert self.actual_batch_size % self.accelerator.num_processes == 0
214
+ actual_batch_size = self.args.batch_size
215
+ return actual_batch_size
216
+
217
+ @property
218
+ def n_gpus(self):
219
+ return self.accelerator.num_processes
220
+
221
+ @property
222
+ def no_decay_params_names(self):
223
+ no_decay = [
224
+ "bn", "LayerNorm", "GroupNorm",
225
+ ]
226
+ return no_decay
227
+
228
+ def no_decay_params(self, model, weight_decay):
229
+ """optimization tricks"""
230
+ optimizer_grouped_parameters = [
231
+ {
232
+ "params": [
233
+ p for n, p in model.named_parameters()
234
+ if not any(nd in n for nd in self.no_decay_params_names)
235
+ ],
236
+ "weight_decay": weight_decay,
237
+ },
238
+ {
239
+ "params": [
240
+ p for n, p in model.named_parameters()
241
+ if any(nd in n for nd in self.no_decay_params_names)
242
+ ],
243
+ "weight_decay": 0.0,
244
+ },
245
+ ]
246
+ return optimizer_grouped_parameters
247
+
248
+ def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
249
+ """return parameters if `requires_grad` is True
250
+
251
+ Args:
252
+ model: pytorch models
253
+ verbose: log optimized parameters
254
+
255
+ Examples:
256
+ >>> self.params_optimized = self.optimized_params(uvit, verbose=True)
257
+ >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr)
258
+
259
+ Returns:
260
+ a list of parameters
261
+ """
262
+ params_optimized = []
263
+ for key, value in model.named_parameters():
264
+ if value.requires_grad:
265
+ params_optimized.append(value)
266
+ if verbose:
267
+ self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
268
+ return params_optimized
269
+
270
+ def save_everything(self, fpath: str):
271
+ """Saving and loading the model, optimizer, RNG generators, and the GradScaler."""
272
+ if not self.accelerator.is_main_process:
273
+ return
274
+ self.accelerator.save_state(fpath)
275
+
276
+ def load_save_everything(self, fpath: str):
277
+ """Loading the model, optimizer, RNG generators, and the GradScaler."""
278
+ self.accelerator.load_state(fpath)
279
+
280
+ def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
281
+ if not self.accelerator.is_main_process:
282
+ return
283
+
284
+ torch.save(checkpoint, self.results_path / f'model-{milestone}.pt')
285
+
286
+ def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
287
+ if not self.accelerator.is_main_process:
288
+ return
289
+
290
+ torch.save(checkpoint, root)
291
+
292
+ def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
293
+ ckpt = torch.load(path, map_location=self.accelerator.device)
294
+
295
+ unwrapped_model = self.accelerator.unwrap_model(model)
296
+ if rm_module_prefix:
297
+ unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
298
+ else:
299
+ unwrapped_model.load_state_dict(ckpt)
300
+ return unwrapped_model
301
+
302
+ def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
303
+ ckpt = torch.load(path, map_location=self.accelerator.device)
304
+ self.print(f"pretrained_dict len: {len(ckpt)}")
305
+ unwrapped_model = self.accelerator.unwrap_model(model)
306
+ model_dict = unwrapped_model.state_dict()
307
+ pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
308
+ model_dict.update(pretrained_dict)
309
+ unwrapped_model.load_state_dict(model_dict, strict=False)
310
+ self.print(f"selected pretrained_dict: {len(model_dict)}")
311
+ return unwrapped_model
312
+
313
+ def print(self, *args, **kwargs):
314
+ """Use in replacement of `print()` to only print once per server."""
315
+ self.accelerator.print(*args, **kwargs)
316
+
317
+ def pretty_print(self, msg):
318
+ if self.accelerator.is_local_main_process:
319
+ pprint(dict(msg))
320
+
321
+ def close_tracker(self):
322
+ self.accelerator.end_training()
323
+
324
+ def free_memory(self):
325
+ self.accelerator.clear()
326
+
327
+ def close(self, msg: str = "Training complete."):
328
+ """Use in end of training."""
329
+ self.free_memory()
330
+
331
+ if torch.cuda.is_available():
332
+ self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
333
+ if self.args.visual:
334
+ self.close_tracker()
335
+ self.print(msg)
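
A rough usage sketch for the parameter-group helpers (illustrative only; `cfg` stands for the merged DictConfig from merge_and_update_config and must carry every field read in __init__, such as mixed_precision, gradient_accumulate_step, use_wandb, tensorboard, use_cpu, batch_size, results_path, edit_type and config):

import torch
from libs.engine import ModelState

state = ModelState(cfg, log_path_suffix="demo")
model = torch.nn.Linear(8, 8).to(state.device)

# only parameters with requires_grad=True
params = state.optimized_params(model, verbose=False)
optimizer = torch.optim.AdamW(params, lr=1e-4)

# or: weight-decay groups that skip the norm layers listed in no_decay_params_names
groups = state.no_decay_params(model, weight_decay=1e-2)
optimizer = torch.optim.AdamW(groups, lr=1e-4)
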
libs/metric/__init__.py ADDED
@@ -0,0 +1 @@
1
+
libs/metric/accuracy.py ADDED
@@ -0,0 +1,25 @@
1
+ def accuracy(output, target, topk=(1,)):
2
+ """
3
+ Computes the accuracy over the k top predictions for the specified values of k.
4
+
5
+ Args
6
+ output: logits or probs (num of batch, num of classes)
7
+ target: (num of batch, 1) or (num of batch, )
8
+ topk: list of returned k
9
+
10
+ refer: https://github.com/pytorch/examples/blob/master/imagenet/main.py
11
+ """
12
+ maxK = max(topk) # get k in top-k
13
+ batch_size = target.size(0)
14
+
15
+ _, pred = output.topk(k=maxK, dim=1, largest=True, sorted=True) # pred: [num of batch, k]
16
+ pred = pred.t() # pred: [k, num of batch]
17
+
18
+ # [1, num of batch] -> [k, num_of_batch] : bool
19
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
20
+
21
+ res = []
22
+ for k in topk:
23
+ correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
24
+ res.append(correct_k.mul_(100.0 / batch_size))
25
+ return res # np.shape(res): [k, 1]
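
A quick sanity check for the helper above (illustrative values only):

import torch
from libs.metric.accuracy import accuracy

logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1]])   # (batch=2, classes=3)
target = torch.tensor([1, 1])              # 2nd sample: true class only within the top-2
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())            # 50.0 100.0
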
libs/metric/clip_score/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .openaiCLIP_loss import CLIPScoreWrapper
2
+
3
+ __all__ = ['CLIPScoreWrapper']
libs/metric/clip_score/openaiCLIP_loss.py ADDED
@@ -0,0 +1,304 @@
1
+ from typing import Union, List, Tuple
2
+ from collections import OrderedDict
3
+ from functools import partial
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision.transforms as transforms
9
+
10
+
11
+ class CLIPScoreWrapper(nn.Module):
12
+
13
+ def __init__(self,
14
+ clip_model_name: str,
15
+ download_root: str = None,
16
+ device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
17
+ jit: bool = False,
18
+ # additional params
19
+ visual_score: bool = False,
20
+ feats_loss_type: str = None,
21
+ feats_loss_weights: List[float] = None,
22
+ fc_loss_weight: float = None,
23
+ context_length: int = 77):
24
+ super().__init__()
25
+
26
+ import clip # local import
27
+
28
+ # check model info
29
+ self.clip_model_name = clip_model_name
30
+ self.device = device
31
+ self.available_models = clip.available_models()
32
+ assert clip_model_name in self.available_models, f"CLIP backbone '{clip_model_name}' does not exist"
33
+
34
+ # load CLIP
35
+ self.model, self.preprocess = clip.load(clip_model_name, device, jit=jit, download_root=download_root)
36
+ self.model.eval()
37
+
38
+ # load tokenize
39
+ self.tokenize_fn = partial(clip.tokenize, context_length=context_length)
40
+
41
+ # load CLIP visual
42
+ self.visual_encoder = VisualEncoderWrapper(self.model, clip_model_name).to(device)
43
+ self.visual_encoder.eval()
44
+
45
+ # check loss weights
46
+ self.visual_score = visual_score
47
+ if visual_score:
48
+ assert feats_loss_type in ["l1", "l2", "cosine"], f"{feats_loss_type} does not exist."
49
+ if clip_model_name.startswith("ViT"): assert len(feats_loss_weights) == 12
50
+ if clip_model_name.startswith("RN"): assert len(feats_loss_weights) == 5
51
+
52
+ # load visual loss wrapper
53
+ self.visual_loss_fn = CLIPVisualLossWrapper(self.visual_encoder, feats_loss_type,
54
+ feats_loss_weights,
55
+ fc_loss_weight)
56
+
57
+ @property
58
+ def input_resolution(self):
59
+ return self.model.visual.input_resolution # default: 224
60
+
61
+ @property
62
+ def resize(self): # Resize only
63
+ return transforms.Compose([self.preprocess.transforms[0]])
64
+
65
+ @property
66
+ def normalize(self):
67
+ return transforms.Compose([
68
+ self.preprocess.transforms[0], # Resize
69
+ self.preprocess.transforms[1], # CenterCrop
70
+ self.preprocess.transforms[-1], # Normalize
71
+ ])
72
+
73
+ @property
74
+ def norm_(self): # Normalize only
75
+ return transforms.Compose([self.preprocess.transforms[-1]])
76
+
77
+ def encode_image_layer_wise(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
78
+ semantic_vec, feature_maps = self.visual_encoder(x)
79
+ return semantic_vec, feature_maps
80
+
81
+ def encode_text(self, text: Union[str, List[str]], norm: bool = True) -> torch.Tensor:
82
+ tokens = self.tokenize_fn(text).to(self.device)
83
+ text_features = self.model.encode_text(tokens)
84
+ if norm:
85
+ text_features = text_features.mean(axis=0, keepdim=True)
86
+ text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
87
+ return text_features_norm
88
+ return text_features
89
+
90
+ def encode_image(self, image: torch.Tensor, norm: bool = True) -> torch.Tensor:
91
+ image_features = self.model.encode_image(image)
92
+ if norm:
93
+ image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
94
+ return image_features_norm
95
+ return image_features
96
+
97
+ @torch.no_grad()
98
+ def predict(self,
99
+ image: torch.Tensor,
100
+ text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
101
+ image_features = self.model.encode_image(image)
102
+ text_tokenize = self.tokenize_fn(text).to(self.device)
103
+ text_features = self.model.encode_text(text_tokenize)
104
+ logits_per_image, logits_per_text = self.model(image, text)
105
+ probs = logits_per_image.softmax(dim=-1).cpu().numpy()
106
+ return image_features, text_features, probs
107
+
108
+ def compute_text_visual_distance(
109
+ self, image: torch.Tensor, text: Union[str, List[str]]
110
+ ) -> torch.Tensor:
111
+ image_features = self.model.encode_image(image)
112
+ text_tokenize = self.tokenize_fn(text).to(self.device)
113
+ with torch.no_grad():
114
+ text_features = self.model.encode_text(text_tokenize)
115
+
116
+ image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
117
+ text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
118
+ # loss = - (image_features_norm @ text_features_norm.T)
119
+ loss = 1 - torch.cosine_similarity(image_features_norm, text_features_norm, dim=1)
120
+ return loss.mean()
121
+
122
+ def directional_text_visual_distance(self, src_text, src_img, tar_text, tar_img):
123
+ src_image_features = self.model.encode_image(src_img).detach()
124
+ tar_image_features = self.model.encode_image(tar_img)
125
+ src_text_tokenize = self.tokenize_fn(src_text).to(self.device)
126
+ tar_text_tokenize = self.tokenize_fn(tar_text).to(self.device)
127
+ with torch.no_grad():
128
+ src_text_features = self.model.encode_text(src_text_tokenize)
129
+ tar_text_features = self.model.encode_text(tar_text_tokenize)
130
+
131
+ delta_image_features = tar_image_features - src_image_features
132
+ delta_text_features = tar_text_features - src_text_features
133
+
134
+ # # avoid zero divisor
135
+ # delta_image_features_norm = delta_image_features / delta_image_features.norm(dim=-1, keepdim=True)
136
+ # delta_text_features_norm = delta_text_features / delta_text_features.norm(dim=-1, keepdim=True)
137
+
138
+ loss = 1 - torch.cosine_similarity(delta_image_features, delta_text_features, dim=1, eps=1e-3)
139
+ return loss.mean()
140
+
141
+ def compute_visual_distance(
142
+ self, x: torch.Tensor, y: torch.Tensor, clip_norm: bool = True,
143
+ ) -> Tuple[torch.Tensor, List]:
144
+ # return a fc loss and the list of feat loss
145
+ assert self.visual_score is True
146
+ assert x.shape == y.shape
147
+ assert x.shape[-1] == self.input_resolution and x.shape[-2] == self.input_resolution
148
+ assert y.shape[-1] == self.input_resolution and y.shape[-2] == self.input_resolution
149
+
150
+ if clip_norm:
151
+ return self.visual_loss_fn(self.normalize(x), self.normalize(y))
152
+ else:
153
+ return self.visual_loss_fn(x, y)
154
+
155
+
156
+ class VisualEncoderWrapper(nn.Module):
157
+ """
158
+ Semantic features and layer-by-layer feature maps are obtained from the CLIP visual encoder.
159
+ """
160
+
161
+ def __init__(self, clip_model: nn.Module, clip_model_name: str):
162
+ super().__init__()
163
+ self.clip_model = clip_model
164
+ self.clip_model_name = clip_model_name
165
+
166
+ if clip_model_name.startswith("ViT"):
167
+ self.feature_maps = OrderedDict()
168
+ for i in range(12): # 12 ResBlocks in ViT visual transformer
169
+ self.clip_model.visual.transformer.resblocks[i].register_forward_hook(
170
+ self.make_hook(i)
171
+ )
172
+
173
+ if clip_model_name.startswith("RN"):
174
+ layers = list(self.clip_model.visual.children())
175
+ init_layers = torch.nn.Sequential(*layers)[:8]
176
+ self.layer1 = layers[8]
177
+ self.layer2 = layers[9]
178
+ self.layer3 = layers[10]
179
+ self.layer4 = layers[11]
180
+ self.att_pool2d = layers[12]
181
+
182
+ def make_hook(self, name):
183
+ def hook(module, input, output):
184
+ if len(output.shape) == 3:
185
+ # LND -> NLD (B, 77, 768)
186
+ self.feature_maps[name] = output.permute(1, 0, 2)
187
+ else:
188
+ self.feature_maps[name] = output
189
+
190
+ return hook
191
+
192
+ def _forward_vit(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
193
+ fc_feature = self.clip_model.encode_image(x).float()
194
+ feature_maps = [self.feature_maps[k] for k in range(12)]
195
+
196
+ # fc_feature len: 1 ,feature_maps len: 12
197
+ return fc_feature, feature_maps
198
+
199
+ def _forward_resnet(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
200
+ def stem(m, x):
201
+ for conv, bn, relu in [(m.conv1, m.bn1, m.relu1), (m.conv2, m.bn2, m.relu2), (m.conv3, m.bn3, m.relu3)]:
202
+ x = torch.relu(bn(conv(x)))
203
+ x = m.avgpool(x)
204
+ return x
205
+
206
+ x = x.type(self.clip_model.visual.conv1.weight.dtype)
207
+ x = stem(self.clip_model.visual, x)
208
+ x1 = self.layer1(x)
209
+ x2 = self.layer2(x1)
210
+ x3 = self.layer3(x2)
211
+ x4 = self.layer4(x3)
212
+ y = self.att_pool2d(x4)
213
+
214
+ # fc_features len: 1 ,feature_maps len: 5
215
+ return y, [x, x1, x2, x3, x4]
216
+
217
+ def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
218
+ if self.clip_model_name.startswith("ViT"):
219
+ fc_feat, visual_feat_maps = self._forward_vit(x)
220
+ if self.clip_model_name.startswith("RN"):
221
+ fc_feat, visual_feat_maps = self._forward_resnet(x)
222
+
223
+ return fc_feat, visual_feat_maps
224
+
225
+
226
+ class CLIPVisualLossWrapper(nn.Module):
227
+ """
228
+ Visual Feature Loss + FC loss
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ visual_encoder: nn.Module,
234
+ feats_loss_type: str = None,
235
+ feats_loss_weights: List[float] = None,
236
+ fc_loss_weight: float = None,
237
+ ):
238
+ super().__init__()
239
+ self.visual_encoder = visual_encoder
240
+ self.feats_loss_weights = feats_loss_weights
241
+ self.fc_loss_weight = fc_loss_weight
242
+
243
+ self.layer_criterion = layer_wise_distance(feats_loss_type)
244
+
245
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
246
+ x_fc_feature, x_feat_maps = self.visual_encoder(x)
247
+ y_fc_feature, y_feat_maps = self.visual_encoder(y)
248
+
249
+ # visual feature loss
250
+ if sum(self.feats_loss_weights) == 0:
251
+ feats_loss_list = [torch.tensor(0, device=x.device)]
252
+ else:
253
+ feats_loss = self.layer_criterion(x_feat_maps, y_feat_maps, self.visual_encoder.clip_model_name)
254
+ feats_loss_list = []
255
+ for layer, w in enumerate(self.feats_loss_weights):
256
+ if w:
257
+ feats_loss_list.append(feats_loss[layer] * w)
258
+
259
+ # visual fc loss, default: cosine similarity
260
+ if self.fc_loss_weight == 0:
261
+ fc_loss = torch.tensor(0, device=x.device)
262
+ else:
263
+ fc_loss = (1 - torch.cosine_similarity(x_fc_feature, y_fc_feature, dim=1)).mean()
264
+ fc_loss = fc_loss * self.fc_loss_weight
265
+
266
+ return fc_loss, feats_loss_list
267
+
268
+
269
+ #################################################################################
270
+ # layer wise metric #
271
+ #################################################################################
272
+
273
+ def layer_wise_distance(metric_name: str):
274
+ return {
275
+ "l1": l1_layer_wise,
276
+ "l2": l2_layer_wise,
277
+ "cosine": cosine_layer_wise
278
+ }.get(metric_name.lower())
279
+
280
+
281
+ def l2_layer_wise(x_features, y_features, clip_model_name):
282
+ return [
283
+ torch.square(x_conv - y_conv).mean()
284
+ for x_conv, y_conv in zip(x_features, y_features)
285
+ ]
286
+
287
+
288
+ def l1_layer_wise(x_features, y_features, clip_model_name):
289
+ return [
290
+ torch.abs(x_conv - y_conv).mean()
291
+ for x_conv, y_conv in zip(x_features, y_features)
292
+ ]
293
+
294
+
295
+ def cosine_layer_wise(x_features, y_features, clip_model_name):
296
+ if clip_model_name.startswith("RN"):
297
+ return [
298
+ (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
299
+ for x_conv, y_conv in zip(x_features, y_features)
300
+ ]
301
+ return [
302
+ (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
303
+ for x_conv, y_conv in zip(x_features, y_features)
304
+ ]
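
A hedged usage sketch of the wrapper (it requires OpenAI's `clip` package, which the constructor imports locally, and assumes a CLIP build whose ResNet layout matches the indexing in VisualEncoderWrapper; the loss weights mirror the RN-based defaults in config/diffsketchedit.yaml, and the prompt is just a placeholder):

import torch
from libs.metric.clip_score import CLIPScoreWrapper

clip_wrapper = CLIPScoreWrapper(
    clip_model_name="RN101",
    visual_score=True,
    feats_loss_type="l2",
    feats_loss_weights=[0, 0, 1.0, 1.0, 0],
    fc_loss_weight=0.1,
)

render = torch.rand(1, 3, 224, 224, device=clip_wrapper.device)
target = torch.rand(1, 3, 224, 224, device=clip_wrapper.device)

fc_loss, feat_losses = clip_wrapper.compute_visual_distance(render, target)   # fc loss + per-layer conv losses
text_loss = clip_wrapper.compute_text_visual_distance(clip_wrapper.normalize(render), "a sketch of a cat")
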
libs/metric/lpips_origin/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .lpips import LPIPS
2
+
3
+ __all__ = ['LPIPS']
libs/metric/lpips_origin/lpips.py ADDED
@@ -0,0 +1,184 @@
1
+ from __future__ import absolute_import
2
+
3
+ import os
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from . import pretrained_networks as pretrained_torch_models
9
+
10
+
11
+ def spatial_average(x, keepdim=True):
12
+ return x.mean([2, 3], keepdim=keepdim)
13
+
14
+
15
+ def upsample(x):
16
+ return nn.Upsample(size=x.shape[2:], mode='bilinear', align_corners=False)(x)
17
+
18
+
19
+ def normalize_tensor(in_feat, eps=1e-10):
20
+ norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
21
+ return in_feat / (norm_factor + eps)
22
+
23
+
24
+ # Learned perceptual metric
25
+ class LPIPS(nn.Module):
26
+
27
+ def __init__(self,
28
+ pretrained=True,
29
+ net='alex',
30
+ version='0.1',
31
+ lpips=True,
32
+ spatial=False,
33
+ pnet_rand=False,
34
+ pnet_tune=False,
35
+ use_dropout=True,
36
+ model_path=None,
37
+ eval_mode=True,
38
+ verbose=True):
39
+ """ Initializes a perceptual loss torch.nn.Module
40
+
41
+ Parameters (default listed first)
42
+ ---------------------------------
43
+ lpips : bool
44
+ [True] use linear layers on top of base/trunk network
45
+ [False] means no linear layers; each layer is averaged together
46
+ pretrained : bool
47
+ This flag controls the linear layers, which are only in effect when lpips=True above
48
+ [True] means linear layers are calibrated with human perceptual judgments
49
+ [False] means linear layers are randomly initialized
50
+ pnet_rand : bool
51
+ [False] means trunk loaded with ImageNet classification weights
52
+ [True] means randomly initialized trunk
53
+ net : str
54
+ ['alex','vgg','squeeze'] are the base/trunk networks available
55
+ version : str
56
+ ['v0.1'] is the default and latest
57
+ ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
58
+ model_path : 'str'
59
+ [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
60
+
61
+ The following parameters should only be changed if training the network:
62
+
63
+ eval_mode : bool
64
+ [True] is for test mode (default)
65
+ [False] is for training mode
66
+ pnet_tune
67
+ [False] keep base/trunk frozen
68
+ [True] tune the base/trunk network
69
+ use_dropout : bool
70
+ [True] to use dropout when training linear layers
71
+ [False] for no dropout when training linear layers
72
+ """
73
+ super(LPIPS, self).__init__()
74
+ if verbose:
75
+ print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
76
+ ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
77
+
78
+ self.pnet_type = net
79
+ self.pnet_tune = pnet_tune
80
+ self.pnet_rand = pnet_rand
81
+ self.spatial = spatial
82
+ self.lpips = lpips # false means baseline of just averaging all layers
83
+ self.version = version
84
+ self.scaling_layer = ScalingLayer()
85
+
86
+ if self.pnet_type in ['vgg', 'vgg16']:
87
+ net_type = pretrained_torch_models.vgg16
88
+ self.chns = [64, 128, 256, 512, 512]
89
+ elif self.pnet_type == 'alex':
90
+ net_type = pretrained_torch_models.alexnet
91
+ self.chns = [64, 192, 384, 256, 256]
92
+ elif self.pnet_type == 'squeeze':
93
+ net_type = pretrained_torch_models.squeezenet
94
+ self.chns = [64, 128, 256, 384, 384, 512, 512]
95
+ self.L = len(self.chns)
96
+
97
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
98
+
99
+ if lpips:
100
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
101
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
102
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
103
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
104
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
105
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
106
+ if self.pnet_type == 'squeeze': # 7 layers for squeezenet
107
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
108
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
109
+ self.lins += [self.lin5, self.lin6]
110
+ self.lins = nn.ModuleList(self.lins)
111
+
112
+ if pretrained:
113
+ if model_path is None:
114
+ model_path = os.path.join(
115
+ os.path.dirname(os.path.abspath(__file__)),
116
+ f"weights/v{version}/{net}.pth"
117
+ )
118
+ if verbose:
119
+ print('Loading model from: %s' % model_path)
120
+ self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
121
+
122
+ if eval_mode:
123
+ self.eval()
124
+
125
+ def forward(self, in0, in1, return_per_layer=False, normalize=False):
126
+ if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, 1]
127
+ in0 = 2 * in0 - 1
128
+ in1 = 2 * in1 - 1
129
+
130
+ # Note: v0.0 - the original release had a bug where the input was not scaled
131
+ if self.version == '0.1':
132
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1))
133
+ else:
134
+ in0_input, in1_input = in0, in1
135
+
136
+ # model forward
137
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
138
+
139
+ feats0, feats1, diffs = {}, {}, {}
140
+ for kk in range(self.L):
141
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
142
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
143
+
144
+ if self.lpips:
145
+ if self.spatial:
146
+ res = [upsample(self.lins[kk](diffs[kk])) for kk in range(self.L)]
147
+ else:
148
+ res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
149
+ else:
150
+ if self.spatial:
151
+ res = [upsample(diffs[kk].sum(dim=1, keepdim=True)) for kk in range(self.L)]
152
+ else:
153
+ res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
154
+
155
+ loss = sum(res)
156
+
157
+ if return_per_layer:
158
+ return loss, res
159
+ else:
160
+ return loss
161
+
162
+
163
+ class ScalingLayer(nn.Module):
164
+ def __init__(self):
165
+ super(ScalingLayer, self).__init__()
166
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
167
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
168
+
169
+ def forward(self, inp):
170
+ return (inp - self.shift) / self.scale
171
+
172
+
173
+ class NetLinLayer(nn.Module):
174
+ """A single linear layer which does a 1x1 conv"""
175
+
176
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
177
+ super(NetLinLayer, self).__init__()
178
+
179
+ layers = [nn.Dropout(), ] if (use_dropout) else []
180
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
181
+ self.model = nn.Sequential(*layers)
182
+
183
+ def forward(self, x):
184
+ return self.model(x)
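
Minimal usage sketch; the calibrated linear-layer weights ship with this module under weights/v0.1/, while the VGG trunk itself is downloaded by torchvision:

import torch
from libs.metric.lpips_origin import LPIPS

lpips_fn = LPIPS(net='vgg', verbose=False)

img0 = torch.rand(1, 3, 256, 256)
img1 = torch.rand(1, 3, 256, 256)
dist = lpips_fn(img0, img1, normalize=True)   # normalize=True rescales [0, 1] inputs to [-1, 1]
print(dist.shape)                             # torch.Size([1, 1, 1, 1])
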
libs/metric/lpips_origin/pretrained_networks.py ADDED
@@ -0,0 +1,196 @@
1
+ from collections import namedtuple
2
+
3
+ import torch
4
+ import torchvision.models as tv_models
5
+
6
+
7
+ class squeezenet(torch.nn.Module):
8
+ def __init__(self, requires_grad=False, pretrained=True):
9
+ super(squeezenet, self).__init__()
10
+ pretrained_features = tv_models.squeezenet1_1(weights=tv_models.SqueezeNet1_1_Weights.IMAGENET1K_V1 if pretrained else None).features
11
+ self.slice1 = torch.nn.Sequential()
12
+ self.slice2 = torch.nn.Sequential()
13
+ self.slice3 = torch.nn.Sequential()
14
+ self.slice4 = torch.nn.Sequential()
15
+ self.slice5 = torch.nn.Sequential()
16
+ self.slice6 = torch.nn.Sequential()
17
+ self.slice7 = torch.nn.Sequential()
18
+ self.N_slices = 7
19
+ for x in range(2):
20
+ self.slice1.add_module(str(x), pretrained_features[x])
21
+ for x in range(2, 5):
22
+ self.slice2.add_module(str(x), pretrained_features[x])
23
+ for x in range(5, 8):
24
+ self.slice3.add_module(str(x), pretrained_features[x])
25
+ for x in range(8, 10):
26
+ self.slice4.add_module(str(x), pretrained_features[x])
27
+ for x in range(10, 11):
28
+ self.slice5.add_module(str(x), pretrained_features[x])
29
+ for x in range(11, 12):
30
+ self.slice6.add_module(str(x), pretrained_features[x])
31
+ for x in range(12, 13):
32
+ self.slice7.add_module(str(x), pretrained_features[x])
33
+ if not requires_grad:
34
+ for param in self.parameters():
35
+ param.requires_grad = False
36
+
37
+ def forward(self, X):
38
+ h = self.slice1(X)
39
+ h_relu1 = h
40
+ h = self.slice2(h)
41
+ h_relu2 = h
42
+ h = self.slice3(h)
43
+ h_relu3 = h
44
+ h = self.slice4(h)
45
+ h_relu4 = h
46
+ h = self.slice5(h)
47
+ h_relu5 = h
48
+ h = self.slice6(h)
49
+ h_relu6 = h
50
+ h = self.slice7(h)
51
+ h_relu7 = h
52
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
53
+ out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
54
+
55
+ return out
56
+
57
+
58
+ class alexnet(torch.nn.Module):
59
+ def __init__(self, requires_grad=False, pretrained=True):
60
+ super(alexnet, self).__init__()
61
+ weights = tv_models.AlexNet_Weights.IMAGENET1K_V1 if pretrained else None
62
+ alexnet_pretrained_features = tv_models.alexnet(weights=weights).features
63
+ self.slice1 = torch.nn.Sequential()
64
+ self.slice2 = torch.nn.Sequential()
65
+ self.slice3 = torch.nn.Sequential()
66
+ self.slice4 = torch.nn.Sequential()
67
+ self.slice5 = torch.nn.Sequential()
68
+ self.N_slices = 5
69
+ for x in range(2):
70
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(2, 5):
72
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(5, 8):
74
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(8, 10):
76
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
77
+ for x in range(10, 12):
78
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
79
+
80
+ if not requires_grad:
81
+ for param in self.parameters():
82
+ param.requires_grad = False
83
+
84
+ def forward(self, X):
85
+ h = self.slice1(X)
86
+ h_relu1 = h
87
+ h = self.slice2(h)
88
+ h_relu2 = h
89
+ h = self.slice3(h)
90
+ h_relu3 = h
91
+ h = self.slice4(h)
92
+ h_relu4 = h
93
+ h = self.slice5(h)
94
+ h_relu5 = h
95
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
96
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
97
+
98
+ return out
99
+
100
+
101
+ class vgg16(torch.nn.Module):
102
+ def __init__(self, requires_grad=False, pretrained=True):
103
+ super(vgg16, self).__init__()
104
+ weights = tv_models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
105
+ vgg_pretrained_features = tv_models.vgg16(weights=weights).features
106
+ self.slice1 = torch.nn.Sequential()
107
+ self.slice2 = torch.nn.Sequential()
108
+ self.slice3 = torch.nn.Sequential()
109
+ self.slice4 = torch.nn.Sequential()
110
+ self.slice5 = torch.nn.Sequential()
111
+ self.N_slices = 5
112
+ for x in range(4):
113
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
114
+ for x in range(4, 9):
115
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
116
+ for x in range(9, 16):
117
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
118
+ for x in range(16, 23):
119
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
120
+ for x in range(23, 30):
121
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
122
+
123
+ if not requires_grad:
124
+ for param in self.parameters():
125
+ param.requires_grad = False
126
+
127
+ def forward(self, X):
128
+ h = self.slice1(X)
129
+ h_relu1_2 = h
130
+ h = self.slice2(h)
131
+ h_relu2_2 = h
132
+ h = self.slice3(h)
133
+ h_relu3_3 = h
134
+ h = self.slice4(h)
135
+ h_relu4_3 = h
136
+ h = self.slice5(h)
137
+ h_relu5_3 = h
138
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
139
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
140
+
141
+ return out
142
+
143
+
144
+ class resnet(torch.nn.Module):
145
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
146
+ super(resnet, self).__init__()
147
+
148
+ if num == 18:
149
+ weights = tv_models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
150
+ self.net = tv_models.resnet18(weights=weights)
151
+ elif num == 34:
152
+ weights = tv_models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
153
+ self.net = tv_models.resnet34(weights=weights)
154
+ elif num == 50:
155
+ weights = tv_models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
156
+ self.net = tv_models.resnet50(weights=weights)
157
+ elif num == 101:
158
+ weights = tv_models.ResNet101_Weights.IMAGENET1K_V2 if pretrained else None
159
+ self.net = tv_models.resnet101(weights=weights)
160
+ elif num == 152:
161
+ weights = tv_models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None
162
+ self.net = tv_models.resnet152(weights=weights)
163
+ self.N_slices = 5
164
+
165
+ if not requires_grad:
166
+ for param in self.net.parameters():
167
+ param.requires_grad = False
168
+
169
+ self.conv1 = self.net.conv1
170
+ self.bn1 = self.net.bn1
171
+ self.relu = self.net.relu
172
+ self.maxpool = self.net.maxpool
173
+ self.layer1 = self.net.layer1
174
+ self.layer2 = self.net.layer2
175
+ self.layer3 = self.net.layer3
176
+ self.layer4 = self.net.layer4
177
+
178
+ def forward(self, X):
179
+ h = self.conv1(X)
180
+ h = self.bn1(h)
181
+ h = self.relu(h)
182
+ h_relu1 = h
183
+ h = self.maxpool(h)
184
+ h = self.layer1(h)
185
+ h_conv2 = h
186
+ h = self.layer2(h)
187
+ h_conv3 = h
188
+ h = self.layer3(h)
189
+ h_conv4 = h
190
+ h = self.layer4(h)
191
+ h_conv5 = h
192
+
193
+ outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
194
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
195
+
196
+ return out
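A minimal usage sketch for the frozen backbones defined above (illustrative only, not part of the commit; the import path follows this repository's layout and assumes the repo root is on PYTHONPATH, and instantiating a backbone may download ImageNet weights):

import torch
from libs.metric.lpips_origin import pretrained_networks as pn

x = torch.rand(1, 3, 64, 64)               # dummy RGB batch in [0, 1]
backbone = pn.vgg16(requires_grad=False)   # frozen VGG-16 feature slices
feats = backbone(x)                        # namedtuple with relu1_2 ... relu5_3
print([f.shape for f in feats])            # one activation tensor per slice
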
libs/metric/lpips_origin/weights/v0.1/alex.pth ADDED
Binary file (6.01 kB). View file
 
libs/metric/lpips_origin/weights/v0.1/squeeze.pth ADDED
Binary file (10.8 kB). View file
 
libs/metric/lpips_origin/weights/v0.1/vgg.pth ADDED
Binary file (7.29 kB). View file
 
libs/metric/piq/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # install: pip install piq
2
+ # repo: https://github.com/photosynthesis-team/piq
libs/metric/piq/functional/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex, crop_patches
2
+ from .colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm
3
+ from .filters import haar_filter, hann_filter, scharr_filter, prewitt_filter, gaussian_filter
4
+ from .filters import binomial_filter1d, average_filter2d
5
+ from .layers import L2Pool2d
6
+ from .resize import imresize
7
+
8
+ __all__ = [
9
+ 'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex', 'crop_patches',
10
+ 'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm',
11
+ 'haar_filter', 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter',
12
+ 'binomial_filter1d', 'average_filter2d',
13
+ 'L2Pool2d',
14
+ 'imresize',
15
+ ]
libs/metric/piq/functional/base.py ADDED
@@ -0,0 +1,111 @@
1
+ r"""General purpose functions"""
2
+ from typing import Tuple, Union, Optional
3
+ import torch
4
+ from ..utils import _parse_version
5
+
6
+
7
+ def ifftshift(x: torch.Tensor) -> torch.Tensor:
8
+ r""" Similar to np.fft.ifftshift but applies to PyTorch Tensors"""
9
+ shift = [-(ax // 2) for ax in x.size()]
10
+ return torch.roll(x, shift, tuple(range(len(shift))))
11
+
12
+
13
+ def get_meshgrid(size: Tuple[int, int], device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
14
+ r"""Return coordinate grid matrices centered at zero point.
15
+ Args:
16
+ size: Shape of meshgrid to create
17
+ device: device to use for creation
18
+ dtype: dtype to use for creation
19
+ Returns:
20
+ Meshgrid of size on device with dtype values.
21
+ """
22
+ if size[0] % 2:
23
+ # Odd
24
+ x = torch.arange(-(size[0] - 1) / 2, size[0] / 2, device=device, dtype=dtype) / (size[0] - 1)
25
+ else:
26
+ # Even
27
+ x = torch.arange(- size[0] / 2, size[0] / 2, device=device, dtype=dtype) / size[0]
28
+
29
+ if size[1] % 2:
30
+ # Odd
31
+ y = torch.arange(-(size[1] - 1) / 2, size[1] / 2, device=device, dtype=dtype) / (size[1] - 1)
32
+ else:
33
+ # Even
34
+ y = torch.arange(- size[1] / 2, size[1] / 2, device=device, dtype=dtype) / size[1]
35
+ # Use indexing param depending on torch version
36
+ recommended_torch_version = _parse_version("1.10.0")
37
+ torch_version = _parse_version(torch.__version__)
38
+ if len(torch_version) > 0 and torch_version >= recommended_torch_version:
39
+ return torch.meshgrid(x, y, indexing='ij')
40
+ return torch.meshgrid(x, y)
41
+
42
+
43
+ def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor:
44
+ r""" Compute similarity_map between two tensors using Dice-like equation.
45
+
46
+ Args:
47
+ map_x: Tensor with map to be compared
48
+ map_y: Tensor with map to be compared
49
+ constant: Used for numerical stability
50
+ alpha: Masking coefficient. Subtracts `alpha` * map_x * map_y from both the numerator and the denominator
51
+ """
52
+ return (2.0 * map_x * map_y - alpha * map_x * map_y + constant) / \
53
+ (map_x ** 2 + map_y ** 2 - alpha * map_x * map_y + constant)
54
+
55
+
56
+ def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
57
+ r""" Compute gradient map for a given tensor and stack of kernels.
58
+
59
+ Args:
60
+ x: Tensor with shape (N, C, H, W).
61
+ kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W)
62
+ Returns:
63
+ Gradients of x per-channel with shape (N, C, H, W)
64
+ """
65
+ padding = kernels.size(-1) // 2
66
+ grads = torch.nn.functional.conv2d(x, kernels, padding=padding)
67
+
68
+ return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True))
69
+
70
+
71
+ def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor:
72
+ r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values.
73
+ Complex numbers are represented by modulus and argument: r * \exp(i * \phi).
74
+
75
+ It will likely become redundant with the introduction of torch.ComplexTensor.
76
+
77
+ Args:
78
+ base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2).
79
+ exp: Exponent
80
+ Returns:
81
+ Complex tensor with shape (N, C, H, W, 2).
82
+ """
83
+ if base.dim() == 4:
84
+ x_complex_r = base.abs()
85
+ x_complex_phi = torch.atan2(torch.zeros_like(base), base)
86
+ elif base.dim() == 5 and base.size(-1) == 2:
87
+ x_complex_r = base.pow(2).sum(dim=-1).sqrt()
88
+ x_complex_phi = torch.atan2(base[..., 1], base[..., 0])
89
+ else:
90
+ raise ValueError(f'Expected real or complex tensor, got {base.size()}')
91
+
92
+ x_complex_pow_r = x_complex_r ** exp
93
+ x_complex_pow_phi = x_complex_phi * exp
94
+ x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi)
95
+ x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi)
96
+ return torch.stack((x_real_pow, x_imag_pow), dim=-1)
97
+
98
+
99
+ def crop_patches(x: torch.Tensor, size=64, stride=32) -> torch.Tensor:
100
+ r"""Crop tensor with images into small patches
101
+ Args:
102
+ x: Tensor with shape (N, C, H, W), expected to be images-like entities
103
+ size: Size of a square patch
104
+ stride: Step between patches
105
+ """
106
+ assert (x.shape[2] >= size) and (x.shape[3] >= size), \
107
+ f"Images must be bigger than patch size. Got ({x.shape[2], x.shape[3]}) and ({size}, {size})"
108
+ channels = x.shape[1]
109
+ patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride)
110
+ patches = patches.reshape(-1, channels, size, size)
111
+ return patches
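A small illustrative example of the patch cropping helper above (not part of the commit; the import path mirrors the file location):

import torch
from libs.metric.piq.functional.base import crop_patches

images = torch.rand(2, 3, 128, 128)
patches = crop_patches(images, size=64, stride=32)
print(patches.shape)   # torch.Size([18, 3, 64, 64]): a 3x3 grid of patches per image
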
libs/metric/piq/functional/colour_conversion.py ADDED
@@ -0,0 +1,136 @@
1
+ r"""Colour space conversion functions"""
2
+ from typing import Union, Dict
3
+ import torch
4
+
5
+
6
+ def rgb2lmn(x: torch.Tensor) -> torch.Tensor:
7
+ r"""Convert a batch of RGB images to a batch of LMN images
8
+
9
+ Args:
10
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
11
+
12
+ Returns:
13
+ Batch of images with shape (N, 3, H, W). LMN colour space.
14
+ """
15
+ weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27],
16
+ [0.30, 0.04, -0.35],
17
+ [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
18
+ x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2)
19
+ return x_lmn
20
+
21
+
22
+ def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
23
+ r"""Convert a batch of RGB images to a batch of XYZ images
24
+
25
+ Args:
26
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
27
+
28
+ Returns:
29
+ Batch of images with shape (N, 3, H, W). XYZ colour space.
30
+ """
31
+ mask_below = (x <= 0.04045).type(x.dtype)
32
+ mask_above = (x > 0.04045).type(x.dtype)
33
+
34
+ tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above
35
+
36
+ weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
37
+ [0.2126729, 0.7151522, 0.0721750],
38
+ [0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device)
39
+
40
+ x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2)
41
+ return x_xyz
42
+
43
+
44
+ def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
45
+ r"""Convert a batch of XYZ images to a batch of LAB images
46
+
47
+ Args:
48
+ x: Batch of images with shape (N, 3, H, W). XYZ colour space.
49
+ illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant.
50
+ observer: {“2”, “10”}, optional. The aperture angle of the observer.
51
+
52
+ Returns:
53
+ Batch of images with shape (N, 3, H, W). LAB colour space.
54
+ """
55
+ epsilon = 0.008856
56
+ kappa = 903.3
57
+ illuminants: Dict[str, Dict] = \
58
+ {"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
59
+ '10': (1.111420406956693, 1, 0.3519978321919493)},
60
+ "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
61
+ '10': (0.9672062750333777, 1, 0.8142801513128616)},
62
+ "D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
63
+ '10': (0.9579665682254781, 1, 0.9092525159847462)},
64
+ "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
65
+ '10': (0.94809667673716, 1, 1.0730513595166162)},
66
+ "D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
67
+ '10': (0.9441713925645873, 1, 1.2064272211720228)},
68
+ "E": {'2': (1.0, 1.0, 1.0),
69
+ '10': (1.0, 1.0, 1.0)}}
70
+
71
+ illuminants_to_use = torch.tensor(illuminants[illuminant][observer],
72
+ dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
73
+
74
+ tmp = x / illuminants_to_use
75
+
76
+ mask_below = (tmp <= epsilon).type(x.dtype)
77
+ mask_above = (tmp > epsilon).type(x.dtype)
78
+ tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. * mask_below
79
+
80
+ weights_xyz_to_lab = torch.tensor([[0, 116., 0],
81
+ [500., -500., 0],
82
+ [0, 200., -200.]], dtype=x.dtype, device=x.device)
83
+ bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
84
+
85
+ x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab
86
+ return x_lab
87
+
88
+
89
+ def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor:
90
+ r"""Convert a batch of RGB images to a batch of LAB images
91
+
92
+ Args:
93
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
94
+ data_range: dynamic range of the input image.
95
+
96
+ Returns:
97
+ Batch of images with shape (N, 3, H, W). LAB colour space.
98
+ """
99
+ return xyz2lab(rgb2xyz(x / float(data_range)))
100
+
101
+
102
+ def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
103
+ r"""Convert a batch of RGB images to a batch of YIQ images
104
+
105
+ Args:
106
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
107
+
108
+ Returns:
109
+ Batch of images with shape (N, 3, H, W). YIQ colour space.
110
+ """
111
+ yiq_weights = torch.tensor([
112
+ [0.299, 0.587, 0.114],
113
+ [0.5959, -0.2746, -0.3213],
114
+ [0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t()
115
+ x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2)
116
+ return x_yiq
117
+
118
+
119
+ def rgb2lhm(x: torch.Tensor) -> torch.Tensor:
120
+ r"""Convert a batch of RGB images to a batch of LHM images
121
+
122
+ Args:
123
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
124
+
125
+ Returns:
126
+ Batch of images with shape (N, 3, H, W). LHM colour space.
127
+
128
+ Reference:
129
+ https://arxiv.org/pdf/1608.07433.pdf
130
+ """
131
+ lhm_weights = torch.tensor([
132
+ [0.2989, 0.587, 0.114],
133
+ [0.3, 0.04, -0.35],
134
+ [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
135
+ x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2)
136
+ return x_lhm
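A short illustrative snippet for the colour conversions above (not part of the commit; the import path mirrors the file location):

import torch
from libs.metric.piq.functional.colour_conversion import rgb2lab, rgb2yiq

rgb = torch.rand(4, 3, 32, 32)        # RGB batch with values in [0, 1]
lab = rgb2lab(rgb, data_range=1.)     # pass data_range=1. because inputs are already in [0, 1]
yiq = rgb2yiq(rgb)
print(lab.shape, yiq.shape)           # both keep the (N, 3, H, W) layout
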
libs/metric/piq/functional/filters.py ADDED
@@ -0,0 +1,111 @@
1
+ r"""Filters for gradient computation, bluring, etc."""
2
+ import torch
3
+ import numpy as np
4
+ from typing import Optional
5
+
6
+
7
+ def haar_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
8
+ r"""Creates Haar kernel
9
+
10
+ Args:
11
+ kernel_size: size of the kernel
12
+ device: target device for kernel generation
13
+ dtype: target data type for kernel generation
14
+ Returns:
15
+ kernel: Tensor with shape (1, kernel_size, kernel_size)
16
+ """
17
+ kernel = torch.ones((kernel_size, kernel_size), device=device, dtype=dtype) / kernel_size
18
+ kernel[kernel_size // 2:, :] = - kernel[kernel_size // 2:, :]
19
+ return kernel.unsqueeze(0)
20
+
21
+
22
+ def hann_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
23
+ r"""Creates Hann kernel
24
+ Args:
25
+ kernel_size: size of the kernel
26
+ device: target device for kernel generation
27
+ dtype: target data type for kernel generation
28
+ Returns:
29
+ kernel: Tensor with shape (1, kernel_size, kernel_size)
30
+ """
31
+ # Take bigger window and drop borders
32
+ window = torch.hann_window(kernel_size + 2, periodic=False, device=device, dtype=dtype)[1:-1]
33
+ kernel = window[:, None] * window[None, :]
34
+ # Normalize and reshape kernel
35
+ return kernel.view(1, kernel_size, kernel_size) / kernel.sum()
36
+
37
+
38
+ def gaussian_filter(kernel_size: int, sigma: float, device: Optional[str] = None,
39
+ dtype: Optional[type] = None) -> torch.Tensor:
40
+ r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
41
+ Args:
42
+ size: Size of the kernel
43
+ sigma: Std of the distribution
44
+ device: target device for kernel generation
45
+ dtype: target data type for kernel generation
46
+ Returns:
47
+ gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
48
+ """
49
+ coords = torch.arange(kernel_size, dtype=dtype, device=device)
50
+ coords -= (kernel_size - 1) / 2.
51
+
52
+ g = coords ** 2
53
+ g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
54
+
55
+ g /= g.sum()
56
+ return g.unsqueeze(0)
57
+
58
+
59
+ # Gradient operator kernels
60
+ def scharr_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
61
+ r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction
62
+
63
+ Args:
64
+ device: target device for kernel generation
65
+ dtype: target data type for kernel generation
66
+ Returns:
67
+ kernel: Tensor with shape (1, 3, 3)
68
+ """
69
+ return torch.tensor([[[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]], device=device, dtype=dtype) / 16
70
+
71
+
72
+ def prewitt_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
73
+ r"""Utility function that returns a normalized 3x3 Prewitt kernel in X direction
74
+
75
+ Args:
76
+ device: target device for kernel generation
77
+ dtype: target data type for kernel generation
78
+ Returns:
79
+ kernel: Tensor with shape (1, 3, 3)"""
80
+ return torch.tensor([[[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]], device=device, dtype=dtype) / 3
81
+
82
+
83
+ def binomial_filter1d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
84
+ r"""Creates 1D normalized binomial filter
85
+
86
+ Args:
87
+ kernel_size (int): kernel size
88
+ device: target device for kernel generation
89
+ dtype: target data type for kernel generation
90
+
91
+ Returns:
92
+ Binomial kernel with shape (1, 1, kernel_size)
93
+ """
94
+ kernel = np.poly1d([0.5, 0.5]) ** (kernel_size - 1)
95
+ return torch.tensor(kernel.c, dtype=dtype, device=device).view(1, 1, kernel_size)
96
+
97
+
98
+ def average_filter2d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
99
+ r"""Creates 2D normalized average filter
100
+
101
+ Args:
102
+ kernel_size (int): kernel size
103
+ device: target device for kernel generation
104
+ dtype: target data type for kernel generation
105
+
106
+ Returns:
107
+ kernel: Tensor with shape (1, kernel_size, kernel_size)
108
+ """
109
+ window = torch.ones(kernel_size, dtype=dtype, device=device) / kernel_size
110
+ kernel = window[:, None] * window[None, :]
111
+ return kernel.unsqueeze(0)
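An illustrative use of the Gaussian kernel helper above for blurring (not part of the commit; dtype is passed explicitly so the kernel is built in floating point):

import torch
import torch.nn.functional as F
from libs.metric.piq.functional.filters import gaussian_filter

kernel = gaussian_filter(kernel_size=5, sigma=1.5, dtype=torch.float32)  # (1, 5, 5), sums to 1
img = torch.rand(1, 1, 64, 64)
blurred = F.conv2d(img, kernel.unsqueeze(0), padding=2)                  # conv weight shape (1, 1, 5, 5)
print(blurred.shape)                                                     # (1, 1, 64, 64)
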
libs/metric/piq/functional/layers.py ADDED
@@ -0,0 +1,33 @@
1
+ r"""Custom layers used in metrics computations"""
2
+ import torch
3
+ from typing import Optional
4
+
5
+ from .filters import hann_filter
6
+
7
+
8
+ class L2Pool2d(torch.nn.Module):
9
+ r"""Applies L2 pooling with Hann window of size 3x3
10
+ Args:
11
+ x: Tensor with shape (N, C, H, W)"""
12
+ EPS = 1e-12
13
+
14
+ def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None:
15
+ super().__init__()
16
+ self.kernel_size = kernel_size
17
+ self.stride = stride
18
+ self.padding = padding
19
+
20
+ self.kernel: Optional[torch.Tensor] = None
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ if self.kernel is None:
24
+ C = x.size(1)
25
+ self.kernel = hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x)
26
+
27
+ out = torch.nn.functional.conv2d(
28
+ x ** 2, self.kernel,
29
+ stride=self.stride,
30
+ padding=self.padding,
31
+ groups=x.shape[1]
32
+ )
33
+ return (out + self.EPS).sqrt()
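A quick sketch of L2Pool2d as a MaxPool2d stand-in (not part of the commit):

import torch
from libs.metric.piq.functional.layers import L2Pool2d

pool = L2Pool2d(kernel_size=3, stride=2, padding=1)
x = torch.rand(1, 8, 32, 32)
print(pool(x).shape)   # torch.Size([1, 8, 16, 16]), same resolution as MaxPool2d(3, 2, 1)
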
libs/metric/piq/functional/resize.py ADDED
@@ -0,0 +1,426 @@
1
+ """
2
+ A standalone PyTorch implementation for fast and efficient bicubic resampling.
3
+ The resulting values are the same as those of the MATLAB function imresize('bicubic').
4
+ ## Author: Sanghyun Son
5
+ ## Email: [email protected] (primary), [email protected] (secondary)
6
+ ## Version: 1.2.0
7
+ ## Last update: July 9th, 2020 (KST)
8
+ Dependency: torch
9
+ Example::
10
+ >>> import torch
11
+ >>> from libs.metric.piq.functional.resize import imresize
12
+ >>> x = torch.arange(16).float().view(1, 1, 4, 4)
13
+ >>> y = imresize(x, sizes=(3, 3))
14
+ >>> print(y)
15
+ tensor([[[[ 0.7506, 2.1004, 3.4503],
16
+ [ 6.1505, 7.5000, 8.8499],
17
+ [11.5497, 12.8996, 14.2494]]]])
18
+ """
19
+
20
+ import math
21
+ import typing
22
+
23
+ import torch
24
+ from torch.nn import functional as F
25
+
26
+ __all__ = ['imresize']
27
+
28
+ _I = typing.Optional[int]
29
+ _D = typing.Optional[torch.dtype]
30
+
31
+
32
+ def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
33
+ range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
34
+ cont = range_around_0.to(dtype=x.dtype)
35
+ return cont
36
+
37
+
38
+ def linear_contribution(x: torch.Tensor) -> torch.Tensor:
39
+ ax = x.abs()
40
+ range_01 = ax.le(1)
41
+ cont = (1 - ax) * range_01.to(dtype=x.dtype)
42
+ return cont
43
+
44
+
45
+ def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
46
+ ax = x.abs()
47
+ ax2 = ax * ax
48
+ ax3 = ax * ax2
49
+
50
+ range_01 = ax.le(1)
51
+ range_12 = torch.logical_and(ax.gt(1), ax.le(2))
52
+
53
+ cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
54
+ cont_01 = cont_01 * range_01.to(dtype=x.dtype)
55
+
56
+ cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
57
+ cont_12 = cont_12 * range_12.to(dtype=x.dtype)
58
+
59
+ cont = cont_01 + cont_12
60
+ return cont
61
+
62
+
63
+ def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
64
+ range_3sigma = (x.abs() <= 3 * sigma + 1)
65
+ # Normalization will be done after
66
+ cont = torch.exp(-x.pow(2) / (2 * sigma ** 2))
67
+ cont = cont * range_3sigma.to(dtype=x.dtype)
68
+ return cont
69
+
70
+
71
+ def discrete_kernel(
72
+ kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
73
+ '''
74
+ For downsampling with integer scale only.
75
+ '''
76
+ downsampling_factor = int(1 / scale)
77
+ if kernel == 'cubic':
78
+ kernel_size_orig = 4
79
+ else:
80
+ raise ValueError('Only the cubic kernel is supported by discrete_kernel!')
81
+
82
+ if antialiasing:
83
+ kernel_size = kernel_size_orig * downsampling_factor
84
+ else:
85
+ kernel_size = kernel_size_orig
86
+
87
+ if downsampling_factor % 2 == 0:
88
+ a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
89
+ else:
90
+ kernel_size -= 1
91
+ a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
92
+
93
+ with torch.no_grad():
94
+ r = torch.linspace(-a, a, steps=kernel_size)
95
+ k = cubic_contribution(r).view(-1, 1)
96
+ k = torch.matmul(k, k.t())
97
+ k /= k.sum()
98
+
99
+ return k
100
+
101
+
102
+ def reflect_padding(
103
+ x: torch.Tensor,
104
+ dim: int,
105
+ pad_pre: int,
106
+ pad_post: int) -> torch.Tensor:
107
+ '''
108
+ Apply reflect padding to the given Tensor.
109
+ Note that it is slightly different from the PyTorch functional.pad,
110
+ where boundary elements are used only once.
111
+ Instead, we follow the MATLAB implementation
112
+ which uses boundary elements twice.
113
+ For example,
114
+ [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
115
+ while our implementation yields [a, a, b, c, d, d].
116
+ '''
117
+ b, c, h, w = x.size()
118
+ if dim == 2 or dim == -2:
119
+ padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
120
+ padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
121
+ for p in range(pad_pre):
122
+ padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
123
+ for p in range(pad_post):
124
+ padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
125
+ else:
126
+ padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
127
+ padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
128
+ for p in range(pad_pre):
129
+ padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
130
+ for p in range(pad_post):
131
+ padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
132
+
133
+ return padding_buffer
134
+
135
+
136
+ def padding(
137
+ x: torch.Tensor,
138
+ dim: int,
139
+ pad_pre: int,
140
+ pad_post: int,
141
+ padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
142
+ if padding_type is None:
143
+ return x
144
+ elif padding_type == 'reflect':
145
+ x_pad = reflect_padding(x, dim, pad_pre, pad_post)
146
+ else:
147
+ raise ValueError('{} padding is not supported!'.format(padding_type))
148
+
149
+ return x_pad
150
+
151
+
152
+ def get_padding(
153
+ base: torch.Tensor,
154
+ kernel_size: int,
155
+ x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
156
+ base = base.long()
157
+ r_min = base.min()
158
+ r_max = base.max() + kernel_size - 1
159
+
160
+ if r_min <= 0:
161
+ pad_pre = -r_min
162
+ pad_pre = pad_pre.item()
163
+ base += pad_pre
164
+ else:
165
+ pad_pre = 0
166
+
167
+ if r_max >= x_size:
168
+ pad_post = r_max - x_size + 1
169
+ pad_post = pad_post.item()
170
+ else:
171
+ pad_post = 0
172
+
173
+ return pad_pre, pad_post, base
174
+
175
+
176
+ def get_weight(
177
+ dist: torch.Tensor,
178
+ kernel_size: int,
179
+ kernel: str = 'cubic',
180
+ sigma: float = 2.0,
181
+ antialiasing_factor: float = 1) -> torch.Tensor:
182
+ buffer_pos = dist.new_zeros(kernel_size, len(dist))
183
+ for idx, buffer_sub in enumerate(buffer_pos):
184
+ buffer_sub.copy_(dist - idx)
185
+
186
+ # Expand (downsampling) / Shrink (upsampling) the receptive field.
187
+ buffer_pos *= antialiasing_factor
188
+ if kernel == 'cubic':
189
+ weight = cubic_contribution(buffer_pos)
190
+ elif kernel == 'gaussian':
191
+ weight = gaussian_contribution(buffer_pos, sigma=sigma)
192
+ else:
193
+ raise ValueError('{} kernel is not supported!'.format(kernel))
194
+
195
+ weight /= weight.sum(dim=0, keepdim=True)
196
+ return weight
197
+
198
+
199
+ def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
200
+ # Resize height
201
+ if dim == 2 or dim == -2:
202
+ k = (kernel_size, 1)
203
+ h_out = x.size(-2) - kernel_size + 1
204
+ w_out = x.size(-1)
205
+ # Resize width
206
+ else:
207
+ k = (1, kernel_size)
208
+ h_out = x.size(-2)
209
+ w_out = x.size(-1) - kernel_size + 1
210
+
211
+ unfold = F.unfold(x, k)
212
+ unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
213
+ return unfold
214
+
215
+
216
+ def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
217
+ if x.dim() == 4:
218
+ b, c, h, w = x.size()
219
+ elif x.dim() == 3:
220
+ c, h, w = x.size()
221
+ b = None
222
+ elif x.dim() == 2:
223
+ h, w = x.size()
224
+ b = c = None
225
+ else:
226
+ raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
227
+
228
+ x = x.view(-1, 1, h, w)
229
+ return x, b, c, h, w
230
+
231
+
232
+ def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
233
+ rh = x.size(-2)
234
+ rw = x.size(-1)
235
+ # Back to the original dimension
236
+ if b is not None:
237
+ x = x.view(b, c, rh, rw) # 4-dim
238
+ else:
239
+ if c is not None:
240
+ x = x.view(c, rh, rw) # 3-dim
241
+ else:
242
+ x = x.view(rh, rw) # 2-dim
243
+
244
+ return x
245
+
246
+
247
+ def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
248
+ if x.dtype != torch.float32 and x.dtype != torch.float64:  # cast non-float inputs and remember the original dtype
249
+ dtype = x.dtype
250
+ x = x.float()
251
+ else:
252
+ dtype = None
253
+
254
+ return x, dtype
255
+
256
+
257
+ def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
258
+ if dtype is not None:
259
+ if not dtype.is_floating_point:
260
+ x = x.round()
261
+ # To prevent over/underflow when converting types
262
+ if dtype is torch.uint8:
263
+ x = x.clamp(0, 255)
264
+
265
+ x = x.to(dtype=dtype)
266
+
267
+ return x
268
+
269
+
270
+ def resize_1d(
271
+ x: torch.Tensor,
272
+ dim: int,
273
+ size: int,
274
+ scale: float,
275
+ kernel: str = 'cubic',
276
+ sigma: float = 2.0,
277
+ padding_type: str = 'reflect',
278
+ antialiasing: bool = True) -> torch.Tensor:
279
+ '''
280
+ Args:
281
+ x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
282
+ dim (int):
283
+ scale (float):
284
+ size (int):
285
+ Return:
286
+ '''
287
+ # Identity case
288
+ if scale == 1:
289
+ return x
290
+
291
+ # Default bicubic kernel with antialiasing (only when downsampling)
292
+ if kernel == 'cubic':
293
+ kernel_size = 4
294
+ else:
295
+ kernel_size = math.floor(6 * sigma)
296
+
297
+ if antialiasing and (scale < 1):
298
+ antialiasing_factor = scale
299
+ kernel_size = math.ceil(kernel_size / antialiasing_factor)
300
+ else:
301
+ antialiasing_factor = 1
302
+
303
+ # We allow margin to both sizes
304
+ kernel_size += 2
305
+
306
+ # Weights only depend on the shape of input and output,
307
+ # so we do not calculate gradients here.
308
+ with torch.no_grad():
309
+ pos = torch.linspace(
310
+ 0, size - 1, steps=size, dtype=x.dtype, device=x.device,
311
+ )
312
+ pos = (pos + 0.5) / scale - 0.5
313
+ base = pos.floor() - (kernel_size // 2) + 1
314
+ dist = pos - base
315
+ weight = get_weight(
316
+ dist,
317
+ kernel_size,
318
+ kernel=kernel,
319
+ sigma=sigma,
320
+ antialiasing_factor=antialiasing_factor,
321
+ )
322
+ pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
323
+
324
+ # To backpropagate through x
325
+ x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
326
+ unfold = reshape_tensor(x_pad, dim, kernel_size)
327
+ # Subsampling first
328
+ if dim == 2 or dim == -2:
329
+ sample = unfold[..., base, :]
330
+ weight = weight.view(1, kernel_size, sample.size(2), 1)
331
+ else:
332
+ sample = unfold[..., base]
333
+ weight = weight.view(1, kernel_size, 1, sample.size(3))
334
+
335
+ # Apply the kernel
336
+ x = sample * weight
337
+ x = x.sum(dim=1, keepdim=True)
338
+ return x
339
+
340
+
341
+ def downsampling_2d(
342
+ x: torch.Tensor,
343
+ k: torch.Tensor,
344
+ scale: int,
345
+ padding_type: str = 'reflect') -> torch.Tensor:
346
+ c = x.size(1)
347
+ k_h = k.size(-2)
348
+ k_w = k.size(-1)
349
+
350
+ k = k.to(dtype=x.dtype, device=x.device)
351
+ k = k.view(1, 1, k_h, k_w)
352
+ k = k.repeat(c, c, 1, 1)
353
+ e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
354
+ e = e.view(c, c, 1, 1)
355
+ k = k * e
356
+
357
+ pad_h = (k_h - scale) // 2
358
+ pad_w = (k_w - scale) // 2
359
+ x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
360
+ x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
361
+ y = F.conv2d(x, k, padding=0, stride=scale)
362
+ return y
363
+
364
+
365
+ def imresize(
366
+ x: torch.Tensor,
367
+ scale: typing.Optional[float] = None,
368
+ sizes: typing.Optional[typing.Tuple[int, int]] = None,
369
+ kernel: typing.Union[str, torch.Tensor] = 'cubic',
370
+ sigma: float = 2,
371
+ rotation_degree: float = 0,
372
+ padding_type: str = 'reflect',
373
+ antialiasing: bool = True) -> torch.Tensor:
374
+ """
375
+ Args:
376
+ x (torch.Tensor):
377
+ scale (float):
378
+ sizes (tuple(int, int)):
379
+ kernel (str, default='cubic'):
380
+ sigma (float, default=2):
381
+ rotation_degree (float, default=0):
382
+ padding_type (str, default='reflect'):
383
+ antialiasing (bool, default=True):
384
+ Return:
385
+ torch.Tensor:
386
+ """
387
+ if scale is None and sizes is None:
388
+ raise ValueError('One of scale or sizes must be specified!')
389
+ if scale is not None and sizes is not None:
390
+ raise ValueError('Please specify scale or sizes to avoid conflict!')
391
+
392
+ x, b, c, h, w = reshape_input(x)
393
+
394
+ if sizes is None and scale is not None:
395
+ '''
396
+ # Check if we can apply the convolution algorithm
397
+ scale_inv = 1 / scale
398
+ if isinstance(kernel, str) and scale_inv.is_integer():
399
+ kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
400
+ elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
401
+ raise ValueError(
402
+ 'An integer downsampling factor '
403
+ 'should be used with a predefined kernel!'
404
+ )
405
+ '''
406
+ # Determine output size
407
+ sizes = (math.ceil(h * scale), math.ceil(w * scale))
408
+ scales = (scale, scale)
409
+
410
+ if scale is None and sizes is not None:
411
+ scales = (sizes[0] / h, sizes[1] / w)
412
+
413
+ x, dtype = cast_input(x)
414
+
415
+ if isinstance(kernel, str) and sizes is not None:
416
+ # Core resizing module
417
+ x = resize_1d(x, -2, size=sizes[0], scale=scales[0], kernel=kernel, sigma=sigma, padding_type=padding_type,
418
+ antialiasing=antialiasing)
419
+ x = resize_1d(x, -1, size=sizes[1], scale=scales[1], kernel=kernel, sigma=sigma, padding_type=padding_type,
420
+ antialiasing=antialiasing)
421
+ elif isinstance(kernel, torch.Tensor) and scale is not None:
422
+ x = downsampling_2d(x, kernel, scale=int(1 / scale))
423
+
424
+ x = reshape_output(x, b, c)
425
+ x = cast_output(x, dtype)
426
+ return x
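An illustrative call of the MATLAB-style resizer above (not part of the commit; the import path mirrors the file location):

import torch
from libs.metric.piq.functional.resize import imresize

x = torch.rand(1, 3, 64, 48)
y_half = imresize(x, scale=0.5)         # antialiased bicubic downsampling -> (1, 3, 32, 24)
y_big = imresize(x, sizes=(128, 96))    # resize to an explicit (H, W)
print(y_half.shape, y_big.shape)
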
libs/metric/piq/perceptual.py ADDED
@@ -0,0 +1,496 @@
1
+ """
2
+ Implementation of Content loss, Style loss, LPIPS and DISTS metrics
3
+ References:
4
+ .. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias
5
+ (2016). A Neural Algorithm of Artistic Style
6
+ Association for Research in Vision and Ophthalmology (ARVO)
7
+ https://arxiv.org/abs/1508.06576
8
+ .. [2] Zhang, Richard and Isola, Phillip and Efros, et al.
9
+ (2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
10
+ 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition
11
+ https://arxiv.org/abs/1801.03924
12
+ """
13
+ from typing import List, Union, Collection
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.nn.modules.loss import _Loss
18
+ from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights
19
+
20
+ from .utils import _validate_input, _reduce
21
+ from .functional import similarity_map, L2Pool2d
22
+
23
+ # Map VGG names to corresponding number in torchvision layer
24
+ VGG16_LAYERS = {
25
+ "conv1_1": '0', "relu1_1": '1',
26
+ "conv1_2": '2', "relu1_2": '3',
27
+ "pool1": '4',
28
+ "conv2_1": '5', "relu2_1": '6',
29
+ "conv2_2": '7', "relu2_2": '8',
30
+ "pool2": '9',
31
+ "conv3_1": '10', "relu3_1": '11',
32
+ "conv3_2": '12', "relu3_2": '13',
33
+ "conv3_3": '14', "relu3_3": '15',
34
+ "pool3": '16',
35
+ "conv4_1": '17', "relu4_1": '18',
36
+ "conv4_2": '19', "relu4_2": '20',
37
+ "conv4_3": '21', "relu4_3": '22',
38
+ "pool4": '23',
39
+ "conv5_1": '24', "relu5_1": '25',
40
+ "conv5_2": '26', "relu5_2": '27',
41
+ "conv5_3": '28', "relu5_3": '29',
42
+ "pool5": '30',
43
+ }
44
+
45
+ VGG19_LAYERS = {
46
+ "conv1_1": '0', "relu1_1": '1',
47
+ "conv1_2": '2', "relu1_2": '3',
48
+ "pool1": '4',
49
+ "conv2_1": '5', "relu2_1": '6',
50
+ "conv2_2": '7', "relu2_2": '8',
51
+ "pool2": '9',
52
+ "conv3_1": '10', "relu3_1": '11',
53
+ "conv3_2": '12', "relu3_2": '13',
54
+ "conv3_3": '14', "relu3_3": '15',
55
+ "conv3_4": '16', "relu3_4": '17',
56
+ "pool3": '18',
57
+ "conv4_1": '19', "relu4_1": '20',
58
+ "conv4_2": '21', "relu4_2": '22',
59
+ "conv4_3": '23', "relu4_3": '24',
60
+ "conv4_4": '25', "relu4_4": '26',
61
+ "pool4": '27',
62
+ "conv5_1": '28', "relu5_1": '29',
63
+ "conv5_2": '30', "relu5_2": '31',
64
+ "conv5_3": '32', "relu5_3": '33',
65
+ "conv5_4": '34', "relu5_4": '35',
66
+ "pool5": '36',
67
+ }
68
+
69
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
70
+ IMAGENET_STD = [0.229, 0.224, 0.225]
71
+
72
+ # Constant used in feature normalization to avoid zero division
73
+ EPS = 1e-10
74
+
75
+
76
+ class ContentLoss(_Loss):
77
+ r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks.
78
+ Uses pretrained VGG models from torchvision.
79
+ Expects input to be in range [0, 1] or normalized with ImageNet statistics into range [-1, 1]
80
+
81
+ Args:
82
+ feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
83
+ layers: List of strings with layer names. Default: ``'relu3_3'``
84
+ weights: List of float weight to balance different layers
85
+ replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
86
+ distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
87
+ reduction: Specifies the reduction type:
88
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
89
+ mean: List of float values used for data standardization. Default: ImageNet mean.
90
+ If there is no need to normalize data, use [0., 0., 0.].
91
+ std: List of float values used for data standardization. Default: ImageNet std.
92
+ If there is no need to normalize data, use [1., 1., 1.].
93
+ normalize_features: If true, unit-normalize each feature in channel dimension before scaling
94
+ and computing distance. See references for details.
95
+
96
+ Examples:
97
+ >>> loss = ContentLoss()
98
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
99
+ >>> y = torch.rand(3, 3, 256, 256)
100
+ >>> output = loss(x, y)
101
+ >>> output.backward()
102
+
103
+ References:
104
+ Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
105
+ A Neural Algorithm of Artistic Style
106
+ Association for Research in Vision and Ophthalmology (ARVO)
107
+ https://arxiv.org/abs/1508.06576
108
+
109
+ Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
110
+ The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
111
+ IEEE/CVF Conference on Computer Vision and Pattern Recognition
112
+ https://arxiv.org/abs/1801.03924
113
+ """
114
+
115
+ def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",),
116
+ weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False,
117
+ distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
118
+ std: List[float] = IMAGENET_STD, normalize_features: bool = False,
119
+ allow_layers_weights_mismatch: bool = False) -> None:
120
+
121
+ assert allow_layers_weights_mismatch or len(layers) == len(weights), \
122
+ f'Lengths of provided layers and weights mismatch ({len(weights)} weights and {len(layers)} layers), ' \
123
+ f'which will cause incorrect results. Please provide weight for each layer.'
124
+
125
+ super().__init__()
126
+
127
+ if callable(feature_extractor):
128
+ self.model = feature_extractor
129
+ self.layers = layers
130
+ else:
131
+ if feature_extractor == "vgg16":
132
+ # self.model = vgg16(pretrained=True, progress=False).features
133
+ self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
134
+ self.layers = [VGG16_LAYERS[l] for l in layers]
135
+ elif feature_extractor == "vgg19":
136
+ # self.model = vgg19(pretrained=True, progress=False).features
137
+ self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features
138
+ self.layers = [VGG19_LAYERS[l] for l in layers]
139
+ else:
140
+ raise ValueError("Unknown feature extractor")
141
+
142
+ if replace_pooling:
143
+ self.model = self.replace_pooling(self.model)
144
+
145
+ # Disable gradients
146
+ for param in self.model.parameters():
147
+ param.requires_grad_(False)
148
+
149
+ self.distance = {
150
+ "mse": nn.MSELoss,
151
+ "mae": nn.L1Loss,
152
+ }[distance](reduction='none')
153
+
154
+ self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights]
155
+
156
+ mean = torch.tensor(mean)
157
+ std = torch.tensor(std)
158
+ self.mean = mean.view(1, -1, 1, 1)
159
+ self.std = std.view(1, -1, 1, 1)
160
+
161
+ self.normalize_features = normalize_features
162
+ self.reduction = reduction
163
+
164
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
165
+ r"""Computation of Content loss between feature representations of prediction :math:`x` and
166
+ target :math:`y` tensors.
167
+
168
+ Args:
169
+ x: An input tensor. Shape :math:`(N, C, H, W)`.
170
+ y: A target tensor. Shape :math:`(N, C, H, W)`.
171
+
172
+ Returns:
173
+ Content loss between feature representations
174
+ """
175
+ _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))
176
+
177
+ self.model.to(x)
178
+ x_features = self.get_features(x)
179
+ y_features = self.get_features(y)
180
+
181
+ distances = self.compute_distance(x_features, y_features)
182
+
183
+ # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
184
+ loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)
185
+
186
+ return _reduce(loss, self.reduction)
187
+
188
+ def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
189
+ r"""Take L2 or L1 distance between feature maps depending on ``distance``.
190
+
191
+ Args:
192
+ x_features: Features of the input tensor.
193
+ y_features: Features of the target tensor.
194
+
195
+ Returns:
196
+ Distance between feature maps
197
+ """
198
+ return [self.distance(x, y) for x, y in zip(x_features, y_features)]
199
+
200
+ def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
201
+ r"""
202
+ Args:
203
+ x: Tensor. Shape :math:`(N, C, H, W)`.
204
+
205
+ Returns:
206
+ List of features extracted from intermediate layers
207
+ """
208
+ # Normalize input
209
+ x = (x - self.mean.to(x)) / self.std.to(x)
210
+
211
+ features = []
212
+ for name, module in self.model._modules.items():
213
+ x = module(x)
214
+ if name in self.layers:
215
+ features.append(self.normalize(x) if self.normalize_features else x)
216
+
217
+ return features
218
+
219
+ @staticmethod
220
+ def normalize(x: torch.Tensor) -> torch.Tensor:
221
+ r"""Normalize feature maps in channel direction to unit length.
222
+
223
+ Args:
224
+ x: Tensor. Shape :math:`(N, C, H, W)`.
225
+
226
+ Returns:
227
+ Normalized input
228
+ """
229
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
230
+ return x / (norm_factor + EPS)
231
+
232
+ def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
233
+ r"""Turn All MaxPool layers into AveragePool
234
+
235
+ Args:
236
+ module: Module in which to replace MaxPool with AveragePool
237
+
238
+ Returns:
239
+ Module with AveragePool instead of MaxPool
240
+
241
+ """
242
+ module_output = module
243
+ if isinstance(module, torch.nn.MaxPool2d):
244
+ module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
245
+
246
+ for name, child in module.named_children():
247
+ module_output.add_module(name, self.replace_pooling(child))
248
+ return module_output
249
+
250
+
251
+ class StyleLoss(ContentLoss):
252
+ r"""Creates Style loss that can be used for image style transfer or as a measure in
253
+ image to image tasks. Computes distance between Gram matrices of feature maps.
254
+ Uses pretrained VGG models from torchvision.
255
+
256
+ By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
257
+ If no normalisation is required, change `mean` and `std` values accordingly.
258
+
259
+ Args:
260
+ feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
261
+ layers: List of strings with layer names. Default: ``'relu3_3'``
262
+ weights: List of float weight to balance different layers
263
+ replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
264
+ distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
265
+ reduction: Specifies the reduction type:
266
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
267
+ mean: List of float values used for data standardization. Default: ImageNet mean.
268
+ If there is no need to normalize data, use [0., 0., 0.].
269
+ std: List of float values used for data standardization. Default: ImageNet std.
270
+ If there is no need to normalize data, use [1., 1., 1.].
271
+ normalize_features: If true, unit-normalize each feature in channel dimension before scaling
272
+ and computing distance. See references for details.
273
+
274
+ Examples:
275
+ >>> loss = StyleLoss()
276
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
277
+ >>> y = torch.rand(3, 3, 256, 256)
278
+ >>> output = loss(x, y)
279
+ >>> output.backward()
280
+
281
+ References:
282
+ Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
283
+ A Neural Algorithm of Artistic Style
284
+ Association for Research in Vision and Ophthalmology (ARVO)
285
+ https://arxiv.org/abs/1508.06576
286
+
287
+ Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
288
+ The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
289
+ IEEE/CVF Conference on Computer Vision and Pattern Recognition
290
+ https://arxiv.org/abs/1801.03924
291
+ """
292
+
293
+ def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor):
294
+ r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``.
295
+
296
+ Args:
297
+ x_features: Features of the input tensor.
298
+ y_features: Features of the target tensor.
299
+
300
+ Returns:
301
+ Distance between Gram matrices
302
+ """
303
+ x_gram = [self.gram_matrix(x) for x in x_features]
304
+ y_gram = [self.gram_matrix(x) for x in y_features]
305
+ return [self.distance(x, y) for x, y in zip(x_gram, y_gram)]
306
+
307
+ @staticmethod
308
+ def gram_matrix(x: torch.Tensor) -> torch.Tensor:
309
+ r"""Compute Gram matrix for batch of features.
310
+
311
+ Args:
312
+ x: Tensor. Shape :math:`(N, C, H, W)`.
313
+
314
+ Returns:
315
+ Gram matrix for given input
316
+ """
317
+ B, C, H, W = x.size()
318
+ gram = []
319
+ for i in range(B):
320
+ features = x[i].view(C, H * W)
321
+
322
+ # Add fake channel dimension
323
+ gram.append(torch.mm(features, features.t()).unsqueeze(0))
324
+
325
+ return torch.stack(gram)
326
+
327
+
328
+ class LPIPS(ContentLoss):
329
+ r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.
330
+
331
+ By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
332
+ If no normalisation is required, change `mean` and `std` values accordingly.
333
+
334
+ Args:
335
+ replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
336
+ distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
337
+ reduction: Specifies the reduction type:
338
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
339
+ mean: List of float values used for data standardization. Default: ImageNet mean.
340
+ If there is no need to normalize data, use [0., 0., 0.].
341
+ std: List of float values used for data standardization. Default: ImageNet std.
342
+ If there is no need to normalize data, use [1., 1., 1.].
343
+
344
+ Examples:
345
+ >>> loss = LPIPS()
346
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
347
+ >>> y = torch.rand(3, 3, 256, 256)
348
+ >>> output = loss(x, y)
349
+ >>> output.backward()
350
+
351
+ References:
352
+ Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
353
+ A Neural Algorithm of Artistic Style
354
+ Association for Research in Vision and Ophthalmology (ARVO)
355
+ https://arxiv.org/abs/1508.06576
356
+
357
+ Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
358
+ The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
359
+ IEEE/CVF Conference on Computer Vision and Pattern Recognition
360
+ https://arxiv.org/abs/1801.03924
361
+ https://github.com/richzhang/PerceptualSimilarity
362
+ """
363
+ _weights_url = "https://github.com/photosynthesis-team/" + \
364
+ "photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"
365
+
366
+ def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
367
+ mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD, ) -> None:
368
+ lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
369
+ lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
370
+ super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights,
371
+ replace_pooling=replace_pooling, distance=distance,
372
+ reduction=reduction, mean=mean, std=std,
373
+ normalize_features=True)
374
+
375
+
376
+ class DISTS(ContentLoss):
377
+ r"""Deep Image Structure and Texture Similarity metric.
378
+
379
+ By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
380
+ If no normalisation is required, change `mean` and `std` values accordingly.
381
+
382
+ Args:
383
+ reduction: Specifies the reduction type:
384
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
385
+ mean: List of float values used for data standardization. Default: ImageNet mean.
386
+ If there is no need to normalize data, use [0., 0., 0.].
387
+ std: List of float values used for data standardization. Default: ImageNet std.
388
+ If there is no need to normalize data, use [1., 1., 1.].
389
+
390
+ Examples:
391
+ >>> loss = DISTS()
392
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
393
+ >>> y = torch.rand(3, 3, 256, 256)
394
+ >>> output = loss(x, y)
395
+ >>> output.backward()
396
+
397
+ References:
398
+ Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
399
+ Image Quality Assessment: Unifying Structure and Texture Similarity.
400
+ https://arxiv.org/abs/2004.07728
401
+ https://github.com/dingkeyan93/DISTS
402
+ """
403
+ _weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"
404
+
405
+ def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
406
+ std: List[float] = IMAGENET_STD) -> None:
407
+ dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
408
+ channels = [3, 64, 128, 256, 512, 512]
409
+
410
+ weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
411
+ dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
412
+ dists_weights.extend(torch.split(weights['beta'], channels, dim=1))
413
+
414
+ super().__init__("vgg16", layers=dists_layers, weights=dists_weights,
415
+ replace_pooling=True, reduction=reduction, mean=mean, std=std,
416
+ normalize_features=False, allow_layers_weights_mismatch=True)
417
+
418
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
419
+ r"""
420
+
421
+ Args:
422
+ x: An input tensor. Shape :math:`(N, C, H, W)`.
423
+ y: A target tensor. Shape :math:`(N, C, H, W)`.
424
+
425
+ Returns:
426
+ Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS`` in range [0, 1].
427
+ """
428
+ _, _, H, W = x.shape
429
+
430
+ if min(H, W) > 256:
431
+ x = torch.nn.functional.interpolate(
432
+ x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
433
+ y = torch.nn.functional.interpolate(
434
+ y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
435
+
436
+ loss = super().forward(x, y)
437
+ return 1 - loss
438
+
439
+ def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor) -> List[torch.Tensor]:
440
+ r"""Compute structure similarity between feature maps
441
+
442
+ Args:
443
+ x_features: Features of the input tensor.
444
+ y_features: Features of the target tensor.
445
+
446
+ Returns:
447
+ Structural similarity distance between feature maps
448
+ """
449
+ structure_distance, texture_distance = [], []
450
+ # Small constant for numerical stability
451
+ EPS = 1e-6
452
+
453
+ for x, y in zip(x_features, y_features):
454
+ x_mean = x.mean([2, 3], keepdim=True)
455
+ y_mean = y.mean([2, 3], keepdim=True)
456
+ structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS))
457
+
458
+ x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
459
+ y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
460
+ xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
461
+ texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))
462
+
463
+ return structure_distance + texture_distance
464
+
465
+ def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
466
+ r"""
467
+
468
+ Args:
469
+ x: Input tensor
470
+
471
+ Returns:
472
+ List of features extracted from input tensor
473
+ """
474
+ features = super().get_features(x)
475
+
476
+ # Add input tensor as an additional feature
477
+ features.insert(0, x)
478
+ return features
479
+
480
+ def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
481
+ r"""Turn All MaxPool layers into L2Pool
482
+
483
+ Args:
484
+ module: Module to change MaxPool into L2Pool
485
+
486
+ Returns:
487
+ Module with L2Pool instead of MaxPool
488
+ """
489
+ module_output = module
490
+ if isinstance(module, torch.nn.MaxPool2d):
491
+ module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)
492
+
493
+ for name, child in module.named_children():
494
+ module_output.add_module(name, self.replace_pooling(child))
495
+
496
+ return module_output
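A hypothetical multi-layer perceptual loss built from the ContentLoss class above, going beyond the default relu3_3 setup shown in the docstrings (not part of the commit; downloads torchvision VGG-16 weights on first use):

import torch
from libs.metric.piq.perceptual import ContentLoss

loss_fn = ContentLoss(feature_extractor="vgg16",
                      layers=("relu2_2", "relu3_3"),   # names resolved via VGG16_LAYERS
                      weights=[0.5, 1.0],
                      replace_pooling=True,
                      normalize_features=True)
x = torch.rand(2, 3, 256, 256, requires_grad=True)
y = torch.rand(2, 3, 256, 256)
loss = loss_fn(x, y)
loss.backward()
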
libs/metric/piq/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .common import _validate_input, _reduce, _parse_version
2
+
3
+ __all__ = [
4
+ "_validate_input",
5
+ "_reduce",
6
+ '_parse_version'
7
+ ]
libs/metric/piq/utils/common.py ADDED
@@ -0,0 +1,158 @@
1
+ import torch
2
+ import re
3
+ import warnings
4
+
5
+ from typing import Tuple, List, Optional, Union, Dict, Any
6
+
7
+ SEMVER_VERSION_PATTERN = re.compile(
8
+ r"""
9
+ ^
10
+ (?P<major>0|[1-9]\d*)
11
+ \.
12
+ (?P<minor>0|[1-9]\d*)
13
+ \.
14
+ (?P<patch>0|[1-9]\d*)
15
+ (?:-(?P<prerelease>
16
+ (?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)
17
+ (?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*
18
+ ))?
19
+ (?:\+(?P<build>
20
+ [0-9a-zA-Z-]+
21
+ (?:\.[0-9a-zA-Z-]+)*
22
+ ))?
23
+ $
24
+ """,
25
+ re.VERBOSE,
26
+ )
27
+
28
+
29
+ PEP_440_VERSION_PATTERN = r"""
30
+ v?
31
+ (?:
32
+ (?:(?P<epoch>[0-9]+)!)? # epoch
33
+ (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
34
+ (?P<pre> # pre-release
35
+ [-_\.]?
36
+ (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
37
+ [-_\.]?
38
+ (?P<pre_n>[0-9]+)?
39
+ )?
40
+ (?P<post> # post release
41
+ (?:-(?P<post_n1>[0-9]+))
42
+ |
43
+ (?:
44
+ [-_\.]?
45
+ (?P<post_l>post|rev|r)
46
+ [-_\.]?
47
+ (?P<post_n2>[0-9]+)?
48
+ )
49
+ )?
50
+ (?P<dev> # dev release
51
+ [-_\.]?
52
+ (?P<dev_l>dev)
53
+ [-_\.]?
54
+ (?P<dev_n>[0-9]+)?
55
+ )?
56
+ )
57
+ (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
58
+ """
59
+
60
+
61
+ def _validate_input(
62
+ tensors: List[torch.Tensor],
63
+ dim_range: Tuple[int, int] = (0, -1),
64
+ data_range: Tuple[float, float] = (0., -1.),
65
+ # size_dim_range: Tuple[float, float] = (0., -1.),
66
+ size_range: Optional[Tuple[int, int]] = None,
67
+ ) -> None:
68
+ r"""Check that input(-s) satisfies the requirements
69
+ Args:
70
+ tensors: Tensors to check
71
+ dim_range: Allowed number of dimensions. (min, max)
72
+ data_range: Allowed range of values in tensors. (min, max)
73
+ size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
74
+ """
75
+
76
+ if not __debug__:
77
+ return
78
+
79
+ x = tensors[0]
80
+
81
+ for t in tensors:
82
+ assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
83
+ assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
84
+
85
+ if size_range is None:
86
+ assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
87
+ else:
88
+ assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
89
+ f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
90
+
91
+ if dim_range[0] == dim_range[1]:
92
+ assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
93
+ elif dim_range[0] < dim_range[1]:
94
+ assert dim_range[0] <= t.dim() <= dim_range[1], \
95
+ f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
96
+
97
+ if data_range[0] < data_range[1]:
98
+ assert data_range[0] <= t.min(), \
99
+ f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
100
+ assert t.max() <= data_range[1], \
101
+ f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
102
+
103
+
104
+ def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
105
+ r"""Reduce input in batch dimension if needed.
106
+
107
+ Args:
108
+ x: Tensor with shape (N, *).
109
+ reduction: Specifies the reduction type:
110
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
111
+ """
112
+ if reduction == 'none':
113
+ return x
114
+ elif reduction == 'mean':
115
+ return x.mean(dim=0)
116
+ elif reduction == 'sum':
117
+ return x.sum(dim=0)
118
+ else:
119
+ raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
120
+
121
+
122
+ def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
123
+ """ Parses valid Python versions according to Semver and PEP 440 specifications.
124
+ For more on Semver check: https://semver.org/
125
+ For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.
126
+
127
+ Implementation is inspired by:
128
+ - https://github.com/python-semver
129
+ - https://github.com/pypa/packaging
130
+
131
+ Args:
132
+ version: unparsed information about the library of interest.
133
+
134
+ Returns:
135
+ parsed information about the library of interest.
136
+ """
137
+ if isinstance(version, bytes):
138
+ version = version.decode("UTF-8")
139
+ elif not isinstance(version, str) and not isinstance(version, bytes):
140
+ raise TypeError(f"not expecting type {type(version)}")
141
+
142
+ # Semver processing
143
+ match = SEMVER_VERSION_PATTERN.match(version)
144
+ if match:
145
+ matched_version_parts: Dict[str, Any] = match.groupdict()
146
+ release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
147
+ return release
148
+
149
+ # PEP 440 processing
150
+ regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
151
+ match = regex.search(version)
152
+
153
+ if match is None:
154
+ warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
155
+ return tuple()
156
+
157
+ release = tuple(int(i) for i in match.group("release").split("."))
158
+ return release
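A short, hedged usage sketch of the three helpers defined above; the tensors and version strings are illustrative, and the import path simply mirrors this commit's directory layout.

import torch
from libs.metric.piq.utils import _validate_input, _reduce, _parse_version

x = torch.rand(4, 3, 64, 64)
y = torch.rand(4, 3, 64, 64)

# Raises an AssertionError (when assertions are enabled) if shapes, dims or value ranges disagree.
_validate_input([x, y], dim_range=(4, 4), data_range=(0., 1.))

per_sample = (x - y).abs().mean(dim=(1, 2, 3))  # one value per sample, shape (N,)
loss = _reduce(per_sample, reduction='mean')    # scalar tensor

print(_parse_version('2.0.1+cu118'))            # (2, 0, 1), matched by the SemVer branch
print(_parse_version('1.14.0.dev20230401'))     # (1, 14, 0), matched by the PEP 440 branch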
libs/metric/pytorch_fid/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ __version__ = '0.3.0'
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+
6
+ from .inception import InceptionV3
7
+ from .fid_score import calculate_frechet_distance
8
+
9
+
10
+ class PytorchFIDFactory(torch.nn.Module):
11
+ """
12
+
13
+ Args:
14
+ channels: number of channels of the input images; single-channel inputs are tiled to 3 channels.
15
+ inception_block_idx: feature dimensionality of the InceptionV3 block to use (one of 64, 192, 768, 2048).
16
+
17
+ Examples:
18
+ >>> fid_factory = PytorchFIDFactory()
19
+ >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
20
+ >>> print(fid_score)
21
+ """
22
+
23
+ def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
24
+ super().__init__()
25
+ self.channels = channels
26
+
27
+ # load models
28
+ assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
29
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
30
+ self.inception_v3 = InceptionV3([block_idx])
31
+
32
+ @torch.no_grad()
33
+ def calculate_activation_statistics(self, samples):
34
+ features = self.inception_v3(samples)[0]
35
+ features = rearrange(features, '... 1 1 -> ...')
36
+
37
+ mu = torch.mean(features, dim=0).cpu()
38
+ # torch.cov treats rows as variables, so transpose to get a (dims, dims) covariance matrix
+ sigma = torch.cov(features.t()).cpu()
39
+ return mu, sigma
40
+
41
+ def score(self, real_samples, fake_samples):
42
+ if self.channels == 1:
43
+ real_samples, fake_samples = map(
44
+ lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
45
+ )
46
+
47
+ min_batch = min(real_samples.shape[0], fake_samples.shape[0])
48
+ real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
49
+
50
+ m1, s1 = self.calculate_activation_statistics(real_samples)
51
+ m2, s2 = self.calculate_activation_statistics(fake_samples)
52
+
53
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
54
+ return fid_value
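A self-contained, hedged version of the docstring example above. Random tensors stand in for batches of real and generated images in [0, 1]; a batch this small only exercises the shapes, since a meaningful FID needs thousands of samples, and constructing the factory downloads the FID InceptionV3 weights on first use.

import torch
from libs.metric.pytorch_fid import PytorchFIDFactory

fid_factory = PytorchFIDFactory(channels=3, inception_block_idx=2048)

real = torch.rand(16, 3, 299, 299)  # placeholder for real images
fake = torch.rand(16, 3, 299, 299)  # placeholder for generated images

# With so few samples the 2048x2048 covariances are rank-deficient, so the value is only a smoke test.
fid_score = fid_factory.score(real_samples=real, fake_samples=fake)
print(float(fid_score))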
libs/metric/pytorch_fid/fid_score.py ADDED
@@ -0,0 +1,322 @@
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ import os
35
+ import pathlib
36
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torchvision.transforms as TF
41
+ from PIL import Image
42
+ from scipy import linalg
43
+ from torch.nn.functional import adaptive_avg_pool2d
44
+
45
+ try:
46
+ from tqdm import tqdm
47
+ except ImportError:
48
+ # If tqdm is not available, provide a mock version of it
49
+ def tqdm(x):
50
+ return x
51
+
52
+ from .inception import InceptionV3
53
+
54
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
55
+ parser.add_argument('--batch-size', type=int, default=50,
56
+ help='Batch size to use')
57
+ parser.add_argument('--num-workers', type=int,
58
+ help=('Number of processes to use for data loading. '
59
+ 'Defaults to `min(8, num_cpus)`'))
60
+ parser.add_argument('--device', type=str, default=None,
61
+ help='Device to use. Like cuda, cuda:0 or cpu')
62
+ parser.add_argument('--dims', type=int, default=2048,
63
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
64
+ help=('Dimensionality of Inception features to use. '
65
+ 'By default, uses pool3 features'))
66
+ parser.add_argument('--save-stats', action='store_true',
67
+ help=('Generate an npz archive from a directory of samples. '
68
+ 'The first path is used as input and the second as output.'))
69
+ parser.add_argument('path', type=str, nargs=2,
70
+ help=('Paths to the generated images or '
71
+ 'to .npz statistic files'))
72
+
73
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
74
+ 'tif', 'tiff', 'webp'}
75
+
76
+
77
+ class ImagePathDataset(torch.utils.data.Dataset):
78
+ def __init__(self, files, transforms=None):
79
+ self.files = files
80
+ self.transforms = transforms
81
+
82
+ def __len__(self):
83
+ return len(self.files)
84
+
85
+ def __getitem__(self, i):
86
+ path = self.files[i]
87
+ img = Image.open(path).convert('RGB')
88
+ if self.transforms is not None:
89
+ img = self.transforms(img)
90
+ return img
91
+
92
+
93
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
94
+ num_workers=1):
95
+ """Calculates the activations of the pool_3 layer for all images.
96
+
97
+ Params:
98
+ -- files : List of image files paths
99
+ -- model : Instance of inception model
100
+ -- batch_size : Batch size of images for the model to process at once.
101
+ Make sure that the number of samples is a multiple of
102
+ the batch size, otherwise some samples are ignored. This
103
+ behavior is retained to match the original FID score
104
+ implementation.
105
+ -- dims : Dimensionality of features returned by Inception
106
+ -- device : Device to run calculations
107
+ -- num_workers : Number of parallel dataloader workers
108
+
109
+ Returns:
110
+ -- A numpy array of dimension (num images, dims) that contains the
111
+ activations of the given tensor when feeding inception with the
112
+ query tensor.
113
+ """
114
+ model.eval()
115
+
116
+ if batch_size > len(files):
117
+ print(('Warning: batch size is bigger than the data size. '
118
+ 'Setting batch size to data size'))
119
+ batch_size = len(files)
120
+
121
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
122
+ dataloader = torch.utils.data.DataLoader(dataset,
123
+ batch_size=batch_size,
124
+ shuffle=False,
125
+ drop_last=False,
126
+ num_workers=num_workers)
127
+
128
+ pred_arr = np.empty((len(files), dims))
129
+
130
+ start_idx = 0
131
+
132
+ for batch in tqdm(dataloader):
133
+ batch = batch.to(device)
134
+
135
+ with torch.no_grad():
136
+ pred = model(batch)[0]
137
+
138
+ # If model output is not scalar, apply global spatial average pooling.
139
+ # This happens if you choose a dimensionality not equal to 2048.
140
+ if pred.size(2) != 1 or pred.size(3) != 1:
141
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
142
+
143
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
144
+
145
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
146
+
147
+ start_idx = start_idx + pred.shape[0]
148
+
149
+ return pred_arr
150
+
151
+
152
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
153
+ """Numpy implementation of the Frechet Distance.
154
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
155
+ and X_2 ~ N(mu_2, C_2) is
156
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
157
+
158
+ Stable version by Dougal J. Sutherland.
159
+
160
+ Params:
161
+ -- mu1 : Numpy array containing the activations of a layer of the
162
+ inception net (like returned by the function 'get_predictions')
163
+ for generated samples.
164
+ -- mu2 : The sample mean over activations, precalculated on a
165
+ representative data set.
166
+ -- sigma1: The covariance matrix over activations for generated samples.
167
+ -- sigma2: The covariance matrix over activations, precalculated on a
168
+ representative data set.
169
+
170
+ Returns:
171
+ -- : The Frechet Distance.
172
+ """
173
+
174
+ mu1 = np.atleast_1d(mu1)
175
+ mu2 = np.atleast_1d(mu2)
176
+
177
+ sigma1 = np.atleast_2d(sigma1)
178
+ sigma2 = np.atleast_2d(sigma2)
179
+
180
+ assert mu1.shape == mu2.shape, \
181
+ 'Training and test mean vectors have different lengths'
182
+ assert sigma1.shape == sigma2.shape, \
183
+ 'Training and test covariances have different dimensions'
184
+
185
+ diff = mu1 - mu2
186
+
187
+ # Product might be almost singular
188
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
189
+ if not np.isfinite(covmean).all():
190
+ msg = ('fid calculation produces singular product; '
191
+ 'adding %s to diagonal of cov estimates') % eps
192
+ print(msg)
193
+ offset = np.eye(sigma1.shape[0]) * eps
194
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
195
+
196
+ # Numerical error might give slight imaginary component
197
+ if np.iscomplexobj(covmean):
198
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
199
+ m = np.max(np.abs(covmean.imag))
200
+ raise ValueError('Imaginary component {}'.format(m))
201
+ covmean = covmean.real
202
+
203
+ tr_covmean = np.trace(covmean)
204
+
205
+ return (diff.dot(diff) + np.trace(sigma1)
206
+ + np.trace(sigma2) - 2 * tr_covmean)
207
+
208
+
209
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
210
+ device='cpu', num_workers=1):
211
+ """Calculation of the statistics used by the FID.
212
+ Params:
213
+ -- files : List of image files paths
214
+ -- model : Instance of inception model
215
+ -- batch_size : The images numpy array is split into batches with
216
+ batch size batch_size. A reasonable batch size
217
+ depends on the hardware.
218
+ -- dims : Dimensionality of features returned by Inception
219
+ -- device : Device to run calculations
220
+ -- num_workers : Number of parallel dataloader workers
221
+
222
+ Returns:
223
+ -- mu : The mean over samples of the activations of the pool_3 layer of
224
+ the inception model.
225
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
226
+ the inception model.
227
+ """
228
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
229
+ mu = np.mean(act, axis=0)
230
+ sigma = np.cov(act, rowvar=False)
231
+ return mu, sigma
232
+
233
+
234
+ def compute_statistics_of_path(path, model, batch_size, dims, device,
235
+ num_workers=1):
236
+ if path.endswith('.npz'):
237
+ with np.load(path) as f:
238
+ m, s = f['mu'][:], f['sigma'][:]
239
+ else:
240
+ path = pathlib.Path(path)
241
+ files = sorted([file for ext in IMAGE_EXTENSIONS
242
+ for file in path.glob('*.{}'.format(ext))])
243
+ m, s = calculate_activation_statistics(files, model, batch_size,
244
+ dims, device, num_workers)
245
+
246
+ return m, s
247
+
248
+
249
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
250
+ """Calculates the FID of two paths"""
251
+ for p in paths:
252
+ if not os.path.exists(p):
253
+ raise RuntimeError('Invalid path: %s' % p)
254
+
255
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
256
+
257
+ model = InceptionV3([block_idx]).to(device)
258
+
259
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
260
+ dims, device, num_workers)
261
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
262
+ dims, device, num_workers)
263
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
264
+
265
+ return fid_value
266
+
267
+
268
+ def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
269
+ """Calculates the FID of two paths"""
270
+ if not os.path.exists(paths[0]):
271
+ raise RuntimeError('Invalid path: %s' % paths[0])
272
+
273
+ if os.path.exists(paths[1]):
274
+ raise RuntimeError('Existing output file: %s' % paths[1])
275
+
276
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
277
+
278
+ model = InceptionV3([block_idx]).to(device)
279
+
280
+ print(f"Saving statistics for {paths[0]}")
281
+
282
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
283
+ dims, device, num_workers)
284
+
285
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
286
+
287
+
288
+ def main():
289
+ args = parser.parse_args()
290
+
291
+ if args.device is None:
292
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
293
+ else:
294
+ device = torch.device(args.device)
295
+
296
+ if args.num_workers is None:
297
+ try:
298
+ num_cpus = len(os.sched_getaffinity(0))
299
+ except AttributeError:
300
+ # os.sched_getaffinity is not available under Windows, use
301
+ # os.cpu_count instead (which may not return the *available* number
302
+ # of CPUs).
303
+ num_cpus = os.cpu_count()
304
+
305
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
306
+ else:
307
+ num_workers = args.num_workers
308
+
309
+ if args.save_stats:
310
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
311
+ return
312
+
313
+ fid_value = calculate_fid_given_paths(args.path,
314
+ args.batch_size,
315
+ device,
316
+ args.dims,
317
+ num_workers)
318
+ print('FID: ', fid_value)
319
+
320
+
321
+ if __name__ == '__main__':
322
+ main()
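Besides the argparse entry point above, the module can be driven programmatically. A hedged sketch with placeholder paths; each directory should contain PNG/JPEG images, or a precomputed .npz statistics file can be passed instead.

import torch
from libs.metric.pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fid_value = calculate_fid_given_paths(
    paths=['path/to/real_images', 'path/to/generated_images'],  # placeholder directories
    batch_size=50,
    device=device,
    dims=2048,      # pool3 features, the standard choice for FID
    num_workers=4,
)
print('FID:', fid_value)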
libs/metric/pytorch_fid/inception.py ADDED
@@ -0,0 +1,341 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = _inception_v3(weights='DEFAULT')
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def _inception_v3(*args, **kwargs):
167
+ """Wraps `torchvision.models.inception_v3`"""
168
+ try:
169
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
170
+ except ValueError:
171
+ # Just a caution against weird version strings
172
+ version = (0,)
173
+
174
+ # Skips default weight initialization if supported by torchvision
175
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
176
+ if version >= (0, 6):
177
+ kwargs['init_weights'] = False
178
+
179
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
180
+ # argument prior to version 0.13.
181
+ if version < (0, 13) and 'weights' in kwargs:
182
+ if kwargs['weights'] == 'DEFAULT':
183
+ kwargs['pretrained'] = True
184
+ elif kwargs['weights'] is None:
185
+ kwargs['pretrained'] = False
186
+ else:
187
+ raise ValueError(
188
+ 'weights=={} not supported in torchvision {}'.format(
189
+ kwargs['weights'], torchvision.__version__
190
+ )
191
+ )
192
+ del kwargs['weights']
193
+
194
+ return torchvision.models.inception_v3(*args, **kwargs)
195
+
196
+
197
+ def fid_inception_v3():
198
+ """Build pretrained Inception model for FID computation
199
+
200
+ The Inception model for FID computation uses a different set of weights
201
+ and has a slightly different structure than torchvision's Inception.
202
+
203
+ This method first constructs torchvision's Inception and then patches the
204
+ necessary parts that are different in the FID Inception model.
205
+ """
206
+ inception = _inception_v3(num_classes=1008,
207
+ aux_logits=False,
208
+ weights=None)
209
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
210
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
211
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
212
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
213
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
214
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
215
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
216
+ inception.Mixed_7b = FIDInceptionE_1(1280)
217
+ inception.Mixed_7c = FIDInceptionE_2(2048)
218
+
219
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
220
+ inception.load_state_dict(state_dict)
221
+ return inception
222
+
223
+
224
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
225
+ """InceptionA block patched for FID computation"""
226
+ def __init__(self, in_channels, pool_features):
227
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
228
+
229
+ def forward(self, x):
230
+ branch1x1 = self.branch1x1(x)
231
+
232
+ branch5x5 = self.branch5x5_1(x)
233
+ branch5x5 = self.branch5x5_2(branch5x5)
234
+
235
+ branch3x3dbl = self.branch3x3dbl_1(x)
236
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
237
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
238
+
239
+ # Patch: Tensorflow's average pool does not use the padded zeros in
240
+ # its average calculation
241
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
242
+ count_include_pad=False)
243
+ branch_pool = self.branch_pool(branch_pool)
244
+
245
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
246
+ return torch.cat(outputs, 1)
247
+
248
+
249
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
250
+ """InceptionC block patched for FID computation"""
251
+ def __init__(self, in_channels, channels_7x7):
252
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
253
+
254
+ def forward(self, x):
255
+ branch1x1 = self.branch1x1(x)
256
+
257
+ branch7x7 = self.branch7x7_1(x)
258
+ branch7x7 = self.branch7x7_2(branch7x7)
259
+ branch7x7 = self.branch7x7_3(branch7x7)
260
+
261
+ branch7x7dbl = self.branch7x7dbl_1(x)
262
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
263
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
264
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
265
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
266
+
267
+ # Patch: Tensorflow's average pool does not use the padded zeros in
268
+ # its average calculation
269
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
270
+ count_include_pad=False)
271
+ branch_pool = self.branch_pool(branch_pool)
272
+
273
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
274
+ return torch.cat(outputs, 1)
275
+
276
+
277
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
278
+ """First InceptionE block patched for FID computation"""
279
+ def __init__(self, in_channels):
280
+ super(FIDInceptionE_1, self).__init__(in_channels)
281
+
282
+ def forward(self, x):
283
+ branch1x1 = self.branch1x1(x)
284
+
285
+ branch3x3 = self.branch3x3_1(x)
286
+ branch3x3 = [
287
+ self.branch3x3_2a(branch3x3),
288
+ self.branch3x3_2b(branch3x3),
289
+ ]
290
+ branch3x3 = torch.cat(branch3x3, 1)
291
+
292
+ branch3x3dbl = self.branch3x3dbl_1(x)
293
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
294
+ branch3x3dbl = [
295
+ self.branch3x3dbl_3a(branch3x3dbl),
296
+ self.branch3x3dbl_3b(branch3x3dbl),
297
+ ]
298
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
299
+
300
+ # Patch: Tensorflow's average pool does not use the padded zeros in
301
+ # its average calculation
302
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
303
+ count_include_pad=False)
304
+ branch_pool = self.branch_pool(branch_pool)
305
+
306
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
307
+ return torch.cat(outputs, 1)
308
+
309
+
310
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
311
+ """Second InceptionE block patched for FID computation"""
312
+ def __init__(self, in_channels):
313
+ super(FIDInceptionE_2, self).__init__(in_channels)
314
+
315
+ def forward(self, x):
316
+ branch1x1 = self.branch1x1(x)
317
+
318
+ branch3x3 = self.branch3x3_1(x)
319
+ branch3x3 = [
320
+ self.branch3x3_2a(branch3x3),
321
+ self.branch3x3_2b(branch3x3),
322
+ ]
323
+ branch3x3 = torch.cat(branch3x3, 1)
324
+
325
+ branch3x3dbl = self.branch3x3dbl_1(x)
326
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
327
+ branch3x3dbl = [
328
+ self.branch3x3dbl_3a(branch3x3dbl),
329
+ self.branch3x3dbl_3b(branch3x3dbl),
330
+ ]
331
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
332
+
333
+ # Patch: The FID Inception model uses max pooling instead of average
334
+ # pooling. This is likely an error in this specific Inception
335
+ # implementation, as other Inception models use average pooling here
336
+ # (which matches the description in the paper).
337
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
338
+ branch_pool = self.branch_pool(branch_pool)
339
+
340
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
341
+ return torch.cat(outputs, 1)
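A hedged sketch of using the wrapper above directly as a feature extractor, for example to obtain the 2048-d pool3 activations that the FID statistics are built from; the first construction downloads the FID Inception weights.

import torch
from libs.metric.pytorch_fid.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]  # index of the final average-pooling block
model = InceptionV3([block_idx]).eval()

images = torch.rand(8, 3, 299, 299)               # values are expected in the range (0, 1)
with torch.no_grad():
    features = model(images)[0]                   # one tensor per requested block: (8, 2048, 1, 1)
features = features.squeeze(-1).squeeze(-1)       # (8, 2048)
print(features.shape)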
libs/modules/__init__.py ADDED
@@ -0,0 +1 @@
1
+