juwaeze commited on
Commit
f8c51d7
·
verified ·
1 Parent(s): e149e7f

Upload 8 files

Browse files
lpips_pytorch/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .modules.lpips import LPIPS
4
+
5
+
6
+ def lpips(x: torch.Tensor,
7
+ y: torch.Tensor,
8
+ net_type: str = 'alex',
9
+ version: str = '0.1'):
10
+ r"""Function that measures
11
+ Learned Perceptual Image Patch Similarity (LPIPS).
12
+
13
+ Arguments:
14
+ x, y (torch.Tensor): the input tensors to compare.
15
+ net_type (str): the network type to compare the features:
16
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17
+ version (str): the version of LPIPS. Default: 0.1.
18
+ """
19
+ device = x.device
20
+ criterion = LPIPS(net_type, version).to(device)
21
+ return criterion(x, y)
lpips_pytorch/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (863 Bytes). View file
 
lpips_pytorch/modules/__pycache__/lpips.cpython-38.pyc ADDED
Binary file (1.83 kB). View file
 
lpips_pytorch/modules/__pycache__/networks.cpython-38.pyc ADDED
Binary file (3.88 kB). View file
 
lpips_pytorch/modules/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.07 kB). View file
 
lpips_pytorch/modules/lpips.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .networks import get_network, LinLayers
5
+ from .utils import get_state_dict
6
+
7
+
8
+ class LPIPS(nn.Module):
9
+ r"""Creates a criterion that measures
10
+ Learned Perceptual Image Patch Similarity (LPIPS).
11
+
12
+ Arguments:
13
+ net_type (str): the network type to compare the features:
14
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15
+ version (str): the version of LPIPS. Default: 0.1.
16
+ """
17
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18
+
19
+ assert version in ['0.1'], 'v0.1 is only supported now'
20
+
21
+ super(LPIPS, self).__init__()
22
+
23
+ # pretrained network
24
+ self.net = get_network(net_type)
25
+
26
+ # linear layers
27
+ self.lin = LinLayers(self.net.n_channels_list)
28
+ self.lin.load_state_dict(get_state_dict(net_type, version))
29
+
30
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
31
+ feat_x, feat_y = self.net(x), self.net(y)
32
+
33
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35
+
36
+ return torch.sum(torch.cat(res, 0), 0, True)
lpips_pytorch/modules/networks.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from .utils import normalize_activation
10
+
11
+
12
+ def get_network(net_type: str):
13
+ if net_type == 'alex':
14
+ return AlexNet()
15
+ elif net_type == 'squeeze':
16
+ return SqueezeNet()
17
+ elif net_type == 'vgg':
18
+ return VGG16()
19
+ else:
20
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21
+
22
+
23
+ class LinLayers(nn.ModuleList):
24
+ def __init__(self, n_channels_list: Sequence[int]):
25
+ super(LinLayers, self).__init__([
26
+ nn.Sequential(
27
+ nn.Identity(),
28
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29
+ ) for nc in n_channels_list
30
+ ])
31
+
32
+ for param in self.parameters():
33
+ param.requires_grad = False
34
+
35
+
36
+ class BaseNet(nn.Module):
37
+ def __init__(self):
38
+ super(BaseNet, self).__init__()
39
+
40
+ # register buffer
41
+ self.register_buffer(
42
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43
+ self.register_buffer(
44
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45
+
46
+ def set_requires_grad(self, state: bool):
47
+ for param in chain(self.parameters(), self.buffers()):
48
+ param.requires_grad = state
49
+
50
+ def z_score(self, x: torch.Tensor):
51
+ return (x - self.mean) / self.std
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = self.z_score(x)
55
+
56
+ output = []
57
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58
+ x = layer(x)
59
+ if i in self.target_layers:
60
+ output.append(normalize_activation(x))
61
+ if len(output) == len(self.target_layers):
62
+ break
63
+ return output
64
+
65
+
66
+ class SqueezeNet(BaseNet):
67
+ def __init__(self):
68
+ super(SqueezeNet, self).__init__()
69
+
70
+ self.layers = models.squeezenet1_1(True).features
71
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73
+
74
+ self.set_requires_grad(False)
75
+
76
+
77
+ class AlexNet(BaseNet):
78
+ def __init__(self):
79
+ super(AlexNet, self).__init__()
80
+
81
+ self.layers = models.alexnet(True).features
82
+ self.target_layers = [2, 5, 8, 10, 12]
83
+ self.n_channels_list = [64, 192, 384, 256, 256]
84
+
85
+ self.set_requires_grad(False)
86
+
87
+
88
+ class VGG16(BaseNet):
89
+ def __init__(self):
90
+ super(VGG16, self).__init__()
91
+
92
+ self.layers = models.vgg16(True).features
93
+ self.target_layers = [4, 9, 16, 23, 30]
94
+ self.n_channels_list = [64, 128, 256, 512, 512]
95
+
96
+ self.set_requires_grad(False)
lpips_pytorch/modules/utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+
6
+ def normalize_activation(x, eps=1e-10):
7
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8
+ return x / (norm_factor + eps)
9
+
10
+
11
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12
+ # build url
13
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
15
+
16
+ # download
17
+ old_state_dict = torch.hub.load_state_dict_from_url(
18
+ url, progress=True,
19
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
20
+ )
21
+
22
+ # rename keys
23
+ new_state_dict = OrderedDict()
24
+ for key, val in old_state_dict.items():
25
+ new_key = key
26
+ new_key = new_key.replace('lin', '')
27
+ new_key = new_key.replace('model.', '')
28
+ new_state_dict[new_key] = val
29
+
30
+ return new_state_dict