YoonaAI commited on
Commit
5bbd1a6
·
1 Parent(s): 3fb8682

Create net/local_affine.py

Browse files
Files changed (1) hide show
  1. lib/net/local_affine.py +57 -0
lib/net/local_affine.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
2
+ # All rights reserved.
3
+ # This file is part of the pytorch-nicp,
4
+ # and is released under the "MIT License Agreement". Please see the LICENSE
5
+ # file that should have been included as part of this package.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.sparse as sp
10
+
11
+ # reference: https://github.com/wuhaozhe/pytorch-nicp
12
+ class LocalAffine(nn.Module):
13
+ def __init__(self, num_points, batch_size=1, edges=None):
14
+ '''
15
+ specify the number of points, the number of points should be constant across the batch
16
+ and the edges torch.Longtensor() with shape N * 2
17
+ the local affine operator supports batch operation
18
+ batch size must be constant
19
+ add additional pooling on top of w matrix
20
+ '''
21
+ super(LocalAffine, self).__init__()
22
+ self.A = nn.Parameter(torch.eye(3).unsqueeze(
23
+ 0).unsqueeze(0).repeat(batch_size, num_points, 1, 1))
24
+ self.b = nn.Parameter(torch.zeros(3).unsqueeze(0).unsqueeze(
25
+ 0).unsqueeze(3).repeat(batch_size, num_points, 1, 1))
26
+ self.edges = edges
27
+ self.num_points = num_points
28
+
29
+ def stiffness(self):
30
+ '''
31
+ calculate the stiffness of local affine transformation
32
+ f norm get infinity gradient when w is zero matrix,
33
+ '''
34
+ if self.edges is None:
35
+ raise Exception("edges cannot be none when calculate stiff")
36
+ idx1 = self.edges[:, 0]
37
+ idx2 = self.edges[:, 1]
38
+ affine_weight = torch.cat((self.A, self.b), dim=3)
39
+ w1 = torch.index_select(affine_weight, dim=1, index=idx1)
40
+ w2 = torch.index_select(affine_weight, dim=1, index=idx2)
41
+ w_diff = (w1 - w2) ** 2
42
+ w_rigid = (torch.linalg.det(self.A) - 1.0) ** 2
43
+ return w_diff, w_rigid
44
+
45
+ def forward(self, x, return_stiff=False):
46
+ '''
47
+ x should have shape of B * N * 3
48
+ '''
49
+ x = x.unsqueeze(3)
50
+ out_x = torch.matmul(self.A, x)
51
+ out_x = out_x + self.b
52
+ out_x.squeeze_(3)
53
+ if return_stiff:
54
+ stiffness, rigid = self.stiffness()
55
+ return out_x, stiffness, rigid
56
+ else:
57
+ return out_x