Spaces:
Build error
Build error
Create net/local_affine.py
Browse files- lib/net/local_affine.py +57 -0
lib/net/local_affine.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
|
2 |
+
# All rights reserved.
|
3 |
+
# This file is part of the pytorch-nicp,
|
4 |
+
# and is released under the "MIT License Agreement". Please see the LICENSE
|
5 |
+
# file that should have been included as part of this package.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.sparse as sp
|
10 |
+
|
11 |
+
# reference: https://github.com/wuhaozhe/pytorch-nicp
|
12 |
+
class LocalAffine(nn.Module):
|
13 |
+
def __init__(self, num_points, batch_size=1, edges=None):
|
14 |
+
'''
|
15 |
+
specify the number of points, the number of points should be constant across the batch
|
16 |
+
and the edges torch.Longtensor() with shape N * 2
|
17 |
+
the local affine operator supports batch operation
|
18 |
+
batch size must be constant
|
19 |
+
add additional pooling on top of w matrix
|
20 |
+
'''
|
21 |
+
super(LocalAffine, self).__init__()
|
22 |
+
self.A = nn.Parameter(torch.eye(3).unsqueeze(
|
23 |
+
0).unsqueeze(0).repeat(batch_size, num_points, 1, 1))
|
24 |
+
self.b = nn.Parameter(torch.zeros(3).unsqueeze(0).unsqueeze(
|
25 |
+
0).unsqueeze(3).repeat(batch_size, num_points, 1, 1))
|
26 |
+
self.edges = edges
|
27 |
+
self.num_points = num_points
|
28 |
+
|
29 |
+
def stiffness(self):
|
30 |
+
'''
|
31 |
+
calculate the stiffness of local affine transformation
|
32 |
+
f norm get infinity gradient when w is zero matrix,
|
33 |
+
'''
|
34 |
+
if self.edges is None:
|
35 |
+
raise Exception("edges cannot be none when calculate stiff")
|
36 |
+
idx1 = self.edges[:, 0]
|
37 |
+
idx2 = self.edges[:, 1]
|
38 |
+
affine_weight = torch.cat((self.A, self.b), dim=3)
|
39 |
+
w1 = torch.index_select(affine_weight, dim=1, index=idx1)
|
40 |
+
w2 = torch.index_select(affine_weight, dim=1, index=idx2)
|
41 |
+
w_diff = (w1 - w2) ** 2
|
42 |
+
w_rigid = (torch.linalg.det(self.A) - 1.0) ** 2
|
43 |
+
return w_diff, w_rigid
|
44 |
+
|
45 |
+
def forward(self, x, return_stiff=False):
|
46 |
+
'''
|
47 |
+
x should have shape of B * N * 3
|
48 |
+
'''
|
49 |
+
x = x.unsqueeze(3)
|
50 |
+
out_x = torch.matmul(self.A, x)
|
51 |
+
out_x = out_x + self.b
|
52 |
+
out_x.squeeze_(3)
|
53 |
+
if return_stiff:
|
54 |
+
stiffness, rigid = self.stiffness()
|
55 |
+
return out_x, stiffness, rigid
|
56 |
+
else:
|
57 |
+
return out_x
|