Zhang Jiaqi committed
Commit 1f2fedd
Parent: 4c65532

SketchModeling

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. configs/instant-mesh-large.yaml +22 -0
  2. src/__pycache__/BackgroundRemove.cpython-312.pyc +0 -0
  3. src/__pycache__/ImageToModel.cpython-312.pyc +0 -0
  4. src/__pycache__/SketchToImage.cpython-312.pyc +0 -0
  5. src/models/__init__.py +0 -0
  6. src/models/__pycache__/__init__.cpython-312.pyc +0 -0
  7. src/models/__pycache__/lrm_mesh.cpython-312.pyc +0 -0
  8. src/models/decoder/__init__.py +0 -0
  9. src/models/decoder/__pycache__/__init__.cpython-312.pyc +0 -0
  10. src/models/decoder/__pycache__/transformer.cpython-312.pyc +0 -0
  11. src/models/decoder/transformer.py +123 -0
  12. src/models/encoder/__init__.py +0 -0
  13. src/models/encoder/__pycache__/__init__.cpython-312.pyc +0 -0
  14. src/models/encoder/__pycache__/dino.cpython-312.pyc +0 -0
  15. src/models/encoder/__pycache__/dino_wrapper.cpython-312.pyc +0 -0
  16. src/models/encoder/dino.py +550 -0
  17. src/models/encoder/dino_wrapper.py +80 -0
  18. src/models/geometry/__init__.py +7 -0
  19. src/models/geometry/__pycache__/__init__.cpython-312.pyc +0 -0
  20. src/models/geometry/camera/__init__.py +16 -0
  21. src/models/geometry/camera/__pycache__/__init__.cpython-312.pyc +0 -0
  22. src/models/geometry/camera/__pycache__/perspective_camera.cpython-312.pyc +0 -0
  23. src/models/geometry/camera/perspective_camera.py +35 -0
  24. src/models/geometry/render/__init__.py +8 -0
  25. src/models/geometry/render/__pycache__/__init__.cpython-312.pyc +0 -0
  26. src/models/geometry/render/__pycache__/neural_render.cpython-312.pyc +0 -0
  27. src/models/geometry/render/neural_render.py +121 -0
  28. src/models/geometry/rep_3d/__init__.py +18 -0
  29. src/models/geometry/rep_3d/__pycache__/__init__.cpython-312.pyc +0 -0
  30. src/models/geometry/rep_3d/__pycache__/dmtet.cpython-312.pyc +0 -0
  31. src/models/geometry/rep_3d/__pycache__/dmtet_utils.cpython-312.pyc +0 -0
  32. src/models/geometry/rep_3d/__pycache__/flexicubes.cpython-312.pyc +0 -0
  33. src/models/geometry/rep_3d/__pycache__/flexicubes_geometry.cpython-312.pyc +0 -0
  34. src/models/geometry/rep_3d/__pycache__/tables.cpython-312.pyc +0 -0
  35. src/models/geometry/rep_3d/dmtet.py +504 -0
  36. src/models/geometry/rep_3d/dmtet_utils.py +20 -0
  37. src/models/geometry/rep_3d/extract_texture_map.py +40 -0
  38. src/models/geometry/rep_3d/flexicubes.py +579 -0
  39. src/models/geometry/rep_3d/flexicubes_geometry.py +120 -0
  40. src/models/geometry/rep_3d/tables.py +791 -0
  41. src/models/lrm.py +209 -0
  42. src/models/lrm_mesh.py +382 -0
  43. src/models/renderer/__init__.py +9 -0
  44. src/models/renderer/__pycache__/__init__.cpython-312.pyc +0 -0
  45. src/models/renderer/__pycache__/synthesizer_mesh.cpython-312.pyc +0 -0
  46. src/models/renderer/synthesizer.py +203 -0
  47. src/models/renderer/synthesizer_mesh.py +141 -0
  48. src/models/renderer/utils/__init__.py +9 -0
  49. src/models/renderer/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  50. src/models/renderer/utils/__pycache__/math_utils.cpython-312.pyc +0 -0
configs/instant-mesh-large.yaml ADDED
@@ -0,0 +1,22 @@
model_config:
  target: src.models.lrm_mesh.InstantMesh
  params:
    encoder_feat_dim: 768
    encoder_freeze: false
    encoder_model_name: facebook/dino-vitb16
    transformer_dim: 1024
    transformer_layers: 16
    transformer_heads: 16
    triplane_low_res: 32
    triplane_high_res: 64
    triplane_dim: 80
    rendering_samples_per_ray: 128
    grid_res: 128
    grid_scale: 2.1


infer_config:
  unet_path: ckpts/diffusion_pytorch_model.bin
  model_path: ckpts/instant_mesh_large.ckpt
  texture_resolution: 1024
  render_resolution: 512
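
Note (not part of this commit): a minimal sketch of how a config like this might be loaded and turned into a model, assuming OmegaConf is available. instantiate_from_config below is a hypothetical helper, not something defined in this repository.

# Sketch: load the YAML and instantiate model_config.target with its params.
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(node):
    # Split "src.models.lrm_mesh.InstantMesh" into module and class names, then construct it.
    module_name, class_name = node.target.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**OmegaConf.to_container(node.params, resolve=True))

cfg = OmegaConf.load("configs/instant-mesh-large.yaml")
model = instantiate_from_config(cfg.model_config)
infer_cfg = cfg.infer_config  # e.g. infer_cfg.texture_resolution == 1024

The infer_config block only carries checkpoint paths and output resolutions and is read the same way.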
src/__pycache__/BackgroundRemove.cpython-312.pyc ADDED
Binary file (454 Bytes).
 
src/__pycache__/ImageToModel.cpython-312.pyc ADDED
Binary file (3.49 kB).
 
src/__pycache__/SketchToImage.cpython-312.pyc ADDED
Binary file (995 Bytes).
 
src/models/__init__.py ADDED
File without changes
src/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (146 Bytes).
 
src/models/__pycache__/lrm_mesh.cpython-312.pyc ADDED
Binary file (18.9 kB).
 
src/models/decoder/__init__.py ADDED
File without changes
src/models/decoder/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (154 Bytes).
 
src/models/decoder/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (5.51 kB).
 
src/models/decoder/transformer.py ADDED
@@ -0,0 +1,123 @@
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn


class BasicTransformerBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
    """
    # use attention from torch.nn.MultiheadAttention
    # Block contains a cross-attention layer, a self-attention layer, and a MLP
    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.,
        attn_bias: bool = False,
        mlp_ratio: float = 4.,
        mlp_drop: float = 0.,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(inner_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(inner_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
        before_sa = self.norm2(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
        x = x + self.mlp(self.norm3(x))
        return x


class TriplaneTransformer(nn.Module):
    """
    Transformer with condition that generates a triplane representation.

    Reference:
    Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
    """
    def __init__(
        self,
        inner_dim: int,
        image_feat_dim: int,
        triplane_low_res: int,
        triplane_high_res: int,
        triplane_dim: int,
        num_layers: int,
        num_heads: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        # attributes
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim

        # modules
        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
        self.layers = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)

    def forward(self, image_feats):
        # image_feats: [N, L_cond, D_cond]

        N = image_feats.shape[0]
        H = W = self.triplane_low_res
        L = 3 * H * W

        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        for layer in self.layers:
            x = layer(x, image_feats)
        x = self.norm(x)

        # separate each plane and apply deconv
        x = x.view(N, 3, H, W, -1)
        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        x = self.deconv(x)  # [3*N, D', H', W']
        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
        x = x.contiguous()

        return x
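
Note (not part of this commit): a quick shape check for TriplaneTransformer under the hyperparameters from configs/instant-mesh-large.yaml; the random input and its token count (197) are only illustrative, and the model is heavier than a toy when run on CPU.

# Sketch: verify TriplaneTransformer output shapes with the config values above.
import torch
from src.models.decoder.transformer import TriplaneTransformer

model = TriplaneTransformer(
    inner_dim=1024, image_feat_dim=768,
    triplane_low_res=32, triplane_high_res=64, triplane_dim=80,
    num_layers=16, num_heads=16,
)
image_feats = torch.randn(2, 197, 768)   # [N, L_cond, D_cond] image features from the encoder
planes = model(image_feats)              # 3*32*32 query tokens -> three upsampled planes
print(planes.shape)                      # torch.Size([2, 3, 80, 64, 64])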
src/models/encoder/__init__.py ADDED
File without changes
src/models/encoder/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (154 Bytes).
 
src/models/encoder/__pycache__/dino.cpython-312.pyc ADDED
Binary file (31.6 kB).
 
src/models/encoder/__pycache__/dino_wrapper.cpython-312.pyc ADDED
Binary file (4.14 kB).
 
src/models/encoder/dino.py ADDED
@@ -0,0 +1,550 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch ViT model."""
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPooling,
29
+ )
30
+ from transformers import PreTrainedModel, ViTConfig
31
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
32
+
33
+
34
+ class ViTEmbeddings(nn.Module):
35
+ """
36
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
37
+ """
38
+
39
+ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
40
+ super().__init__()
41
+
42
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
43
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
44
+ self.patch_embeddings = ViTPatchEmbeddings(config)
45
+ num_patches = self.patch_embeddings.num_patches
46
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
47
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
48
+ self.config = config
49
+
50
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
51
+ """
52
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
53
+ resolution images.
54
+
55
+ Source:
56
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
57
+ """
58
+
59
+ num_patches = embeddings.shape[1] - 1
60
+ num_positions = self.position_embeddings.shape[1] - 1
61
+ if num_patches == num_positions and height == width:
62
+ return self.position_embeddings
63
+ class_pos_embed = self.position_embeddings[:, 0]
64
+ patch_pos_embed = self.position_embeddings[:, 1:]
65
+ dim = embeddings.shape[-1]
66
+ h0 = height // self.config.patch_size
67
+ w0 = width // self.config.patch_size
68
+ # we add a small number to avoid floating point error in the interpolation
69
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
70
+ h0, w0 = h0 + 0.1, w0 + 0.1
71
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
72
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
73
+ patch_pos_embed = nn.functional.interpolate(
74
+ patch_pos_embed,
75
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
76
+ mode="bicubic",
77
+ align_corners=False,
78
+ )
79
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
80
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
81
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
82
+
83
+ def forward(
84
+ self,
85
+ pixel_values: torch.Tensor,
86
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
87
+ interpolate_pos_encoding: bool = False,
88
+ ) -> torch.Tensor:
89
+ batch_size, num_channels, height, width = pixel_values.shape
90
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
91
+
92
+ if bool_masked_pos is not None:
93
+ seq_length = embeddings.shape[1]
94
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
95
+ # replace the masked visual tokens by mask_tokens
96
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
97
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
98
+
99
+ # add the [CLS] token to the embedded patch tokens
100
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
101
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
102
+
103
+ # add positional encoding to each token
104
+ if interpolate_pos_encoding:
105
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
106
+ else:
107
+ embeddings = embeddings + self.position_embeddings
108
+
109
+ embeddings = self.dropout(embeddings)
110
+
111
+ return embeddings
112
+
113
+
114
+ class ViTPatchEmbeddings(nn.Module):
115
+ """
116
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
117
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
118
+ Transformer.
119
+ """
120
+
121
+ def __init__(self, config):
122
+ super().__init__()
123
+ image_size, patch_size = config.image_size, config.patch_size
124
+ num_channels, hidden_size = config.num_channels, config.hidden_size
125
+
126
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
127
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
128
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
129
+ self.image_size = image_size
130
+ self.patch_size = patch_size
131
+ self.num_channels = num_channels
132
+ self.num_patches = num_patches
133
+
134
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
135
+
136
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
137
+ batch_size, num_channels, height, width = pixel_values.shape
138
+ if num_channels != self.num_channels:
139
+ raise ValueError(
140
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
141
+ f" Expected {self.num_channels} but got {num_channels}."
142
+ )
143
+ if not interpolate_pos_encoding:
144
+ if height != self.image_size[0] or width != self.image_size[1]:
145
+ raise ValueError(
146
+ f"Input image size ({height}*{width}) doesn't match model"
147
+ f" ({self.image_size[0]}*{self.image_size[1]})."
148
+ )
149
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
150
+ return embeddings
151
+
152
+
153
+ class ViTSelfAttention(nn.Module):
154
+ def __init__(self, config: ViTConfig) -> None:
155
+ super().__init__()
156
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
157
+ raise ValueError(
158
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
159
+ f"heads {config.num_attention_heads}."
160
+ )
161
+
162
+ self.num_attention_heads = config.num_attention_heads
163
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
164
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
165
+
166
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
167
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
168
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
169
+
170
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
171
+
172
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
173
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
174
+ x = x.view(new_x_shape)
175
+ return x.permute(0, 2, 1, 3)
176
+
177
+ def forward(
178
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
179
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
180
+ mixed_query_layer = self.query(hidden_states)
181
+
182
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
183
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
184
+ query_layer = self.transpose_for_scores(mixed_query_layer)
185
+
186
+ # Take the dot product between "query" and "key" to get the raw attention scores.
187
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
188
+
189
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
190
+
191
+ # Normalize the attention scores to probabilities.
192
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
193
+
194
+ # This is actually dropping out entire tokens to attend to, which might
195
+ # seem a bit unusual, but is taken from the original Transformer paper.
196
+ attention_probs = self.dropout(attention_probs)
197
+
198
+ # Mask heads if we want to
199
+ if head_mask is not None:
200
+ attention_probs = attention_probs * head_mask
201
+
202
+ context_layer = torch.matmul(attention_probs, value_layer)
203
+
204
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
205
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
206
+ context_layer = context_layer.view(new_context_layer_shape)
207
+
208
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
209
+
210
+ return outputs
211
+
212
+
213
+ class ViTSelfOutput(nn.Module):
214
+ """
215
+ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
216
+ layernorm applied before each block.
217
+ """
218
+
219
+ def __init__(self, config: ViTConfig) -> None:
220
+ super().__init__()
221
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
222
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
+
224
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
225
+ hidden_states = self.dense(hidden_states)
226
+ hidden_states = self.dropout(hidden_states)
227
+
228
+ return hidden_states
229
+
230
+
231
+ class ViTAttention(nn.Module):
232
+ def __init__(self, config: ViTConfig) -> None:
233
+ super().__init__()
234
+ self.attention = ViTSelfAttention(config)
235
+ self.output = ViTSelfOutput(config)
236
+ self.pruned_heads = set()
237
+
238
+ def prune_heads(self, heads: Set[int]) -> None:
239
+ if len(heads) == 0:
240
+ return
241
+ heads, index = find_pruneable_heads_and_indices(
242
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
243
+ )
244
+
245
+ # Prune linear layers
246
+ self.attention.query = prune_linear_layer(self.attention.query, index)
247
+ self.attention.key = prune_linear_layer(self.attention.key, index)
248
+ self.attention.value = prune_linear_layer(self.attention.value, index)
249
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
250
+
251
+ # Update hyper params and store pruned heads
252
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
253
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
254
+ self.pruned_heads = self.pruned_heads.union(heads)
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states: torch.Tensor,
259
+ head_mask: Optional[torch.Tensor] = None,
260
+ output_attentions: bool = False,
261
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
262
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
263
+
264
+ attention_output = self.output(self_outputs[0], hidden_states)
265
+
266
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
267
+ return outputs
268
+
269
+
270
+ class ViTIntermediate(nn.Module):
271
+ def __init__(self, config: ViTConfig) -> None:
272
+ super().__init__()
273
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
274
+ if isinstance(config.hidden_act, str):
275
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
276
+ else:
277
+ self.intermediate_act_fn = config.hidden_act
278
+
279
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
280
+ hidden_states = self.dense(hidden_states)
281
+ hidden_states = self.intermediate_act_fn(hidden_states)
282
+
283
+ return hidden_states
284
+
285
+
286
+ class ViTOutput(nn.Module):
287
+ def __init__(self, config: ViTConfig) -> None:
288
+ super().__init__()
289
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
290
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
291
+
292
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.dropout(hidden_states)
295
+
296
+ hidden_states = hidden_states + input_tensor
297
+
298
+ return hidden_states
299
+
300
+
301
+ def modulate(x, shift, scale):
302
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
303
+
304
+
305
+ class ViTLayer(nn.Module):
306
+ """This corresponds to the Block class in the timm implementation."""
307
+
308
+ def __init__(self, config: ViTConfig) -> None:
309
+ super().__init__()
310
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
311
+ self.seq_len_dim = 1
312
+ self.attention = ViTAttention(config)
313
+ self.intermediate = ViTIntermediate(config)
314
+ self.output = ViTOutput(config)
315
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
316
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
+
318
+ self.adaLN_modulation = nn.Sequential(
319
+ nn.SiLU(),
320
+ nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
321
+ )
322
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
323
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states: torch.Tensor,
328
+ adaln_input: torch.Tensor = None,
329
+ head_mask: Optional[torch.Tensor] = None,
330
+ output_attentions: bool = False,
331
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
332
+ shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
333
+
334
+ self_attention_outputs = self.attention(
335
+ modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
336
+ head_mask,
337
+ output_attentions=output_attentions,
338
+ )
339
+ attention_output = self_attention_outputs[0]
340
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
341
+
342
+ # first residual connection
343
+ hidden_states = attention_output + hidden_states
344
+
345
+ # in ViT, layernorm is also applied after self-attention
346
+ layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
347
+ layer_output = self.intermediate(layer_output)
348
+
349
+ # second residual connection is done here
350
+ layer_output = self.output(layer_output, hidden_states)
351
+
352
+ outputs = (layer_output,) + outputs
353
+
354
+ return outputs
355
+
356
+
357
+ class ViTEncoder(nn.Module):
358
+ def __init__(self, config: ViTConfig) -> None:
359
+ super().__init__()
360
+ self.config = config
361
+ self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
362
+ self.gradient_checkpointing = False
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ adaln_input: torch.Tensor = None,
368
+ head_mask: Optional[torch.Tensor] = None,
369
+ output_attentions: bool = False,
370
+ output_hidden_states: bool = False,
371
+ return_dict: bool = True,
372
+ ) -> Union[tuple, BaseModelOutput]:
373
+ all_hidden_states = () if output_hidden_states else None
374
+ all_self_attentions = () if output_attentions else None
375
+
376
+ for i, layer_module in enumerate(self.layer):
377
+ if output_hidden_states:
378
+ all_hidden_states = all_hidden_states + (hidden_states,)
379
+
380
+ layer_head_mask = head_mask[i] if head_mask is not None else None
381
+
382
+ if self.gradient_checkpointing and self.training:
383
+ layer_outputs = self._gradient_checkpointing_func(
384
+ layer_module.__call__,
385
+ hidden_states,
386
+ adaln_input,
387
+ layer_head_mask,
388
+ output_attentions,
389
+ )
390
+ else:
391
+ layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
392
+
393
+ hidden_states = layer_outputs[0]
394
+
395
+ if output_attentions:
396
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
397
+
398
+ if output_hidden_states:
399
+ all_hidden_states = all_hidden_states + (hidden_states,)
400
+
401
+ if not return_dict:
402
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
403
+ return BaseModelOutput(
404
+ last_hidden_state=hidden_states,
405
+ hidden_states=all_hidden_states,
406
+ attentions=all_self_attentions,
407
+ )
408
+
409
+
410
+ class ViTPreTrainedModel(PreTrainedModel):
411
+ """
412
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
413
+ models.
414
+ """
415
+
416
+ config_class = ViTConfig
417
+ base_model_prefix = "vit"
418
+ main_input_name = "pixel_values"
419
+ supports_gradient_checkpointing = True
420
+ _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
421
+
422
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
423
+ """Initialize the weights"""
424
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
425
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
426
+ # `trunc_normal_cpu` not implemented in `half` issues
427
+ module.weight.data = nn.init.trunc_normal_(
428
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
429
+ ).to(module.weight.dtype)
430
+ if module.bias is not None:
431
+ module.bias.data.zero_()
432
+ elif isinstance(module, nn.LayerNorm):
433
+ module.bias.data.zero_()
434
+ module.weight.data.fill_(1.0)
435
+ elif isinstance(module, ViTEmbeddings):
436
+ module.position_embeddings.data = nn.init.trunc_normal_(
437
+ module.position_embeddings.data.to(torch.float32),
438
+ mean=0.0,
439
+ std=self.config.initializer_range,
440
+ ).to(module.position_embeddings.dtype)
441
+
442
+ module.cls_token.data = nn.init.trunc_normal_(
443
+ module.cls_token.data.to(torch.float32),
444
+ mean=0.0,
445
+ std=self.config.initializer_range,
446
+ ).to(module.cls_token.dtype)
447
+
448
+
449
+ class ViTModel(ViTPreTrainedModel):
450
+ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
451
+ super().__init__(config)
452
+ self.config = config
453
+
454
+ self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
455
+ self.encoder = ViTEncoder(config)
456
+
457
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
458
+ self.pooler = ViTPooler(config) if add_pooling_layer else None
459
+
460
+ # Initialize weights and apply final processing
461
+ self.post_init()
462
+
463
+ def get_input_embeddings(self) -> ViTPatchEmbeddings:
464
+ return self.embeddings.patch_embeddings
465
+
466
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
467
+ """
468
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
469
+ class PreTrainedModel
470
+ """
471
+ for layer, heads in heads_to_prune.items():
472
+ self.encoder.layer[layer].attention.prune_heads(heads)
473
+
474
+ def forward(
475
+ self,
476
+ pixel_values: Optional[torch.Tensor] = None,
477
+ adaln_input: Optional[torch.Tensor] = None,
478
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
479
+ head_mask: Optional[torch.Tensor] = None,
480
+ output_attentions: Optional[bool] = None,
481
+ output_hidden_states: Optional[bool] = None,
482
+ interpolate_pos_encoding: Optional[bool] = None,
483
+ return_dict: Optional[bool] = None,
484
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
485
+ r"""
486
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
487
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
488
+ """
489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
+ output_hidden_states = (
491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
+ )
493
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494
+
495
+ if pixel_values is None:
496
+ raise ValueError("You have to specify pixel_values")
497
+
498
+ # Prepare head mask if needed
499
+ # 1.0 in head_mask indicate we keep the head
500
+ # attention_probs has shape bsz x n_heads x N x N
501
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
502
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
503
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
504
+
505
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
506
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
507
+ if pixel_values.dtype != expected_dtype:
508
+ pixel_values = pixel_values.to(expected_dtype)
509
+
510
+ embedding_output = self.embeddings(
511
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
512
+ )
513
+
514
+ encoder_outputs = self.encoder(
515
+ embedding_output,
516
+ adaln_input=adaln_input,
517
+ head_mask=head_mask,
518
+ output_attentions=output_attentions,
519
+ output_hidden_states=output_hidden_states,
520
+ return_dict=return_dict,
521
+ )
522
+ sequence_output = encoder_outputs[0]
523
+ sequence_output = self.layernorm(sequence_output)
524
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
525
+
526
+ if not return_dict:
527
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
528
+ return head_outputs + encoder_outputs[1:]
529
+
530
+ return BaseModelOutputWithPooling(
531
+ last_hidden_state=sequence_output,
532
+ pooler_output=pooled_output,
533
+ hidden_states=encoder_outputs.hidden_states,
534
+ attentions=encoder_outputs.attentions,
535
+ )
536
+
537
+
538
+ class ViTPooler(nn.Module):
539
+ def __init__(self, config: ViTConfig):
540
+ super().__init__()
541
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
542
+ self.activation = nn.Tanh()
543
+
544
+ def forward(self, hidden_states):
545
+ # We "pool" the model by simply taking the hidden state corresponding
546
+ # to the first token.
547
+ first_token_tensor = hidden_states[:, 0]
548
+ pooled_output = self.dense(first_token_tensor)
549
+ pooled_output = self.activation(pooled_output)
550
+ return pooled_output
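
Note (not part of this commit): the conditioning added to ViTLayer reduces to the modulate() helper above, which applies a per-sample shift and scale predicted from the camera embedding. A standalone sketch of that arithmetic, with hypothetical sizes:

# Sketch of the adaLN-style modulation used in ViTLayer (standalone, hypothetical sizes).
import torch

hidden, tokens, batch = 768, 197, 2
adaln_input = torch.randn(batch, hidden)            # camera embedding per sample
proj = torch.nn.Linear(hidden, 4 * hidden)          # stands in for the adaLN_modulation output layer
shift_msa, scale_msa, shift_mlp, scale_mlp = proj(adaln_input).chunk(4, dim=1)

x = torch.randn(batch, tokens, hidden)              # normalized hidden states
modulated = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)  # matches modulate()
print(modulated.shape)                               # torch.Size([2, 197, 768])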
src/models/encoder/dino_wrapper.py ADDED
@@ -0,0 +1,80 @@
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch.nn as nn
from transformers import ViTImageProcessor
from einops import rearrange, repeat
from .dino import ViTModel


class DinoWrapper(nn.Module):
    """
    Dino v1 wrapper using huggingface transformer implementation.
    """
    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        self.model, self.processor = self._build_dino(model_name)
        self.camera_embedder = nn.Sequential(
            nn.Linear(16, self.model.config.hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
        )
        if freeze:
            self._freeze()

    def forward(self, image, camera):
        # image: [B, N, C, H, W]
        # camera: [B, N, D]
        # RGB image with [0,1] scale and properly sized
        if image.ndim == 5:
            image = rearrange(image, 'b n c h w -> (b n) c h w')
        dtype = image.dtype
        inputs = self.processor(
            images=image.float(),
            return_tensors="pt",
            do_rescale=False,
            do_resize=False,
        ).to(self.model.device).to(dtype)
        # embed camera
        N = camera.shape[1]
        camera_embeddings = self.camera_embedder(camera)
        camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
        embeddings = camera_embeddings
        # This resampling of positional embedding uses bicubic interpolation
        outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states

    def _freeze(self):
        print(f"======== Freezing DinoWrapper ========")
        self.model.eval()
        for name, param in self.model.named_parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
        import requests
        try:
            model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
            processor = ViTImageProcessor.from_pretrained(model_name)
            return model, processor
        except requests.exceptions.ProxyError as err:
            if proxy_error_retries > 0:
                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
                import time
                time.sleep(proxy_error_cooldown)
                return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
            else:
                raise err
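
Note (not part of this commit): a minimal sketch of calling DinoWrapper on dummy multi-view input. It assumes the facebook/dino-vitb16 weights can be downloaded from the Hugging Face Hub; the 320x320 input size and 6 views are only illustrative.

# Sketch: run the DINO wrapper on dummy multi-view input (downloads facebook/dino-vitb16).
import torch
from src.models.encoder.dino_wrapper import DinoWrapper

encoder = DinoWrapper(model_name="facebook/dino-vitb16", freeze=True)
images = torch.rand(1, 6, 3, 320, 320)   # [B, N, C, H, W] in [0, 1]
cameras = torch.rand(1, 6, 16)           # [B, N, 16] flattened camera parameters per view
feats = encoder(images, cameras)         # [B*N, num_patches + 1, hidden_size]
print(feats.shape)                       # torch.Size([6, 401, 768]) with (320/16)^2 patches + CLS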
src/models/geometry/__init__.py ADDED
@@ -0,0 +1,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
src/models/geometry/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (155 Bytes).
 
src/models/geometry/camera/__init__.py ADDED
@@ -0,0 +1,16 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
from torch import nn


class Camera(nn.Module):
    def __init__(self):
        super(Camera, self).__init__()
        pass
src/models/geometry/camera/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (646 Bytes).
 
src/models/geometry/camera/__pycache__/perspective_camera.cpython-312.pyc ADDED
Binary file (2.1 kB).
 
src/models/geometry/camera/perspective_camera.py ADDED
@@ -0,0 +1,35 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
from . import Camera
import numpy as np


def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
    if near_plane is None:
        near_plane = n
    return np.array(
        [[n / x, 0, 0, 0],
         [0, n / -x, 0, 0],
         [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
         [0, 0, -1, 0]]).astype(np.float32)


class PerspectiveCamera(Camera):
    def __init__(self, fovy=49.0, device='cuda'):
        super(PerspectiveCamera, self).__init__()
        self.device = device
        focal = np.tan(fovy / 180.0 * np.pi * 0.5)
        self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)

    def project(self, points_bxnx4):
        out = torch.matmul(
            points_bxnx4,
            torch.transpose(self.proj_mtx, 1, 2))
        return out
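
Note (not part of this commit): a small sketch of using PerspectiveCamera on CPU; the input points are arbitrary homogeneous camera-space coordinates (the camera looks down -z).

# Sketch: project homogeneous camera-space points with PerspectiveCamera.
import torch
from src.models.geometry.camera.perspective_camera import PerspectiveCamera

cam = PerspectiveCamera(fovy=49.0, device='cpu')
points_bxnx4 = torch.tensor([[[0.0, 0.0, -2.0, 1.0],
                              [0.5, 0.5, -3.0, 1.0]]])   # [B, N, 4]
clip = cam.project(points_bxnx4)                         # [B, N, 4] clip-space coordinates
ndc = clip[..., :3] / clip[..., 3:4]                     # perspective divide
print(ndc)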
src/models/geometry/render/__init__.py ADDED
@@ -0,0 +1,8 @@
import torch

class Renderer():
    def __init__(self):
        pass

    def forward(self):
        pass
src/models/geometry/render/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (586 Bytes).
 
src/models/geometry/render/__pycache__/neural_render.cpython-312.pyc ADDED
Binary file (6.56 kB).
 
src/models/geometry/render/neural_render.py ADDED
@@ -0,0 +1,121 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
import torch.nn.functional as F
import nvdiffrast.torch as dr
from . import Renderer

_FG_LUT = None


def interpolate(attr, rast, attr_idx, rast_db=None):
    return dr.interpolate(
        attr.contiguous(), rast, attr_idx, rast_db=rast_db,
        diff_attrs=None if rast_db is None else 'all')


def xfm_points(points, matrix, use_python=True):
    '''Transform points.
    Args:
        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
        use_python: Use PyTorch's torch.matmul (for validation)
    Returns:
        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
    '''
    out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
    return out


def dot(x, y):
    return torch.sum(x * y, -1, keepdim=True)


def compute_vertex_normal(v_pos, t_pos_idx):
    i0 = t_pos_idx[:, 0]
    i1 = t_pos_idx[:, 1]
    i2 = t_pos_idx[:, 2]

    v0 = v_pos[i0, :]
    v1 = v_pos[i1, :]
    v2 = v_pos[i2, :]

    face_normals = torch.cross(v1 - v0, v2 - v0)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(v_pos)
    v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
    v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
    v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    v_nrm = torch.where(
        dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
    )
    v_nrm = F.normalize(v_nrm, dim=1)
    assert torch.all(torch.isfinite(v_nrm))

    return v_nrm


class NeuralRender(Renderer):
    def __init__(self, device='cuda', camera_model=None):
        super(NeuralRender, self).__init__()
        self.device = device
        self.ctx = dr.RasterizeCudaContext(device=device)
        self.projection_mtx = None
        self.camera = camera_model

    def render_mesh(
            self,
            mesh_v_pos_bxnx3,
            mesh_t_pos_idx_fx3,
            camera_mv_bx4x4,
            mesh_v_feat_bxnxd,
            resolution=256,
            spp=1,
            device='cuda',
            hierarchical_mask=False
    ):
        assert not hierarchical_mask

        mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera

        v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long())  # vertex normals in world coordinates

        # Render the image,
        # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
        num_layers = 1
        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
        mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)  # Concatenate the pos

        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
            for _ in range(num_layers):
                rast, db = peeler.rasterize_next_layer()
                gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)

        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
        antialias_mask = dr.antialias(
            hard_mask.clone().contiguous(), rast, v_pos_clip,
            mesh_t_pos_idx_fx3)

        depth = gb_feat[..., -2:-1]
        ori_mesh_feature = gb_feat[..., :-4]

        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
        normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
        normal = F.normalize(normal, dim=-1)
        normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float())  # black background

        return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
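
Note (not part of this commit): NeuralRender.render_mesh needs nvdiffrast and a CUDA device, so this sketch only exercises the pure-PyTorch compute_vertex_normal helper (nvdiffrast must still be installed for the module import to succeed).

# Sketch: vertex normals for a single triangle lying in the z=0 plane.
import torch
from src.models.geometry.render.neural_render import compute_vertex_normal

v_pos = torch.tensor([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]])
t_pos_idx = torch.tensor([[0, 1, 2]])
normals = compute_vertex_normal(v_pos, t_pos_idx)
print(normals)   # every vertex receives the face normal (0, 0, 1)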
src/models/geometry/rep_3d/__init__.py ADDED
@@ -0,0 +1,18 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
import numpy as np


class Geometry():
    def __init__(self):
        pass

    def forward(self):
        pass
src/models/geometry/rep_3d/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (610 Bytes).
 
src/models/geometry/rep_3d/__pycache__/dmtet.cpython-312.pyc ADDED
Binary file (28.8 kB).
 
src/models/geometry/rep_3d/__pycache__/dmtet_utils.cpython-312.pyc ADDED
Binary file (981 Bytes).
 
src/models/geometry/rep_3d/__pycache__/flexicubes.cpython-312.pyc ADDED
Binary file (41.6 kB).
 
src/models/geometry/rep_3d/__pycache__/flexicubes_geometry.cpython-312.pyc ADDED
Binary file (6.26 kB).
 
src/models/geometry/rep_3d/__pycache__/tables.cpython-312.pyc ADDED
Binary file (32.3 kB).
 
src/models/geometry/rep_3d/dmtet.py ADDED
@@ -0,0 +1,504 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import numpy as np
11
+ import os
12
+ from . import Geometry
13
+ from .dmtet_utils import get_center_boundary_index
14
+ import torch.nn.functional as F
15
+
16
+
17
+ ###############################################################################
18
+ # DMTet utility functions
19
+ ###############################################################################
20
+ def create_mt_variable(device):
21
+ triangle_table = torch.tensor(
22
+ [
23
+ [-1, -1, -1, -1, -1, -1],
24
+ [1, 0, 2, -1, -1, -1],
25
+ [4, 0, 3, -1, -1, -1],
26
+ [1, 4, 2, 1, 3, 4],
27
+ [3, 1, 5, -1, -1, -1],
28
+ [2, 3, 0, 2, 5, 3],
29
+ [1, 4, 0, 1, 5, 4],
30
+ [4, 2, 5, -1, -1, -1],
31
+ [4, 5, 2, -1, -1, -1],
32
+ [4, 1, 0, 4, 5, 1],
33
+ [3, 2, 0, 3, 5, 2],
34
+ [1, 3, 5, -1, -1, -1],
35
+ [4, 1, 2, 4, 3, 1],
36
+ [3, 0, 4, -1, -1, -1],
37
+ [2, 0, 1, -1, -1, -1],
38
+ [-1, -1, -1, -1, -1, -1]
39
+ ], dtype=torch.long, device=device)
40
+
41
+ num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
42
+ base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
43
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
44
+ return triangle_table, num_triangles_table, base_tet_edges, v_id
45
+
46
+
47
+ def sort_edges(edges_ex2):
48
+ with torch.no_grad():
49
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
50
+ order = order.unsqueeze(dim=1)
51
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
52
+ b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
53
+ return torch.stack([a, b], -1)
54
+
55
+
56
+ ###############################################################################
57
+ # marching tetrahedrons (differentiable)
58
+ ###############################################################################
59
+
60
+ def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id):
61
+ with torch.no_grad():
62
+ occ_n = sdf_n > 0
63
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
64
+ occ_sum = torch.sum(occ_fx4, -1)
65
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
66
+ occ_sum = occ_sum[valid_tets]
67
+
68
+ # find all vertices
69
+ all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
70
+ all_edges = sort_edges(all_edges)
71
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
72
+
73
+ unique_edges = unique_edges.long()
74
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
75
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
76
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
77
+ idx_map = mapping[idx_map] # map edges to verts
78
+
79
+ interp_v = unique_edges[mask_edges] # .long()
80
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
81
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
82
+ edges_to_interp_sdf[:, -1] *= -1
83
+
84
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
85
+
86
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
87
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
88
+
89
+ idx_map = idx_map.reshape(-1, 6)
90
+
91
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
92
+ num_triangles = num_triangles_table[tetindex]
93
+
94
+ # Generate triangle indices
95
+ faces = torch.cat(
96
+ (
97
+ torch.gather(
98
+ input=idx_map[num_triangles == 1], dim=1,
99
+ index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
100
+ torch.gather(
101
+ input=idx_map[num_triangles == 2], dim=1,
102
+ index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
103
+ ), dim=0)
104
+ return verts, faces
105
+
106
+
107
+ def create_tetmesh_variables(device='cuda'):
108
+ tet_table = torch.tensor(
109
+ [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
110
+ [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1],
111
+ [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1],
112
+ [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8],
113
+ [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1],
114
+ [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9],
115
+ [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9],
116
+ [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9],
117
+ [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1],
118
+ [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9],
119
+ [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9],
120
+ [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9],
121
+ [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8],
122
+ [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8],
123
+ [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6],
124
+ [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device)
125
+ num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device)
126
+ return tet_table, num_tets_table
127
+
128
+
129
+ def marching_tets_tetmesh(
130
+ pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
131
+ return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
132
+ with torch.no_grad():
133
+ occ_n = sdf_n > 0
134
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
135
+ occ_sum = torch.sum(occ_fx4, -1)
136
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
137
+ occ_sum = occ_sum[valid_tets]
138
+
139
+ # find all vertices
140
+ all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
141
+ all_edges = sort_edges(all_edges)
142
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
143
+
144
+ unique_edges = unique_edges.long()
145
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
146
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
147
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
148
+ idx_map = mapping[idx_map] # map edges to verts
149
+
150
+ interp_v = unique_edges[mask_edges] # .long()
151
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
152
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
153
+ edges_to_interp_sdf[:, -1] *= -1
154
+
155
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
156
+
157
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
158
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
159
+
160
+ idx_map = idx_map.reshape(-1, 6)
161
+
162
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
163
+ num_triangles = num_triangles_table[tetindex]
164
+
165
+ # Generate triangle indices
166
+ faces = torch.cat(
167
+ (
168
+ torch.gather(
169
+ input=idx_map[num_triangles == 1], dim=1,
170
+ index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
171
+ torch.gather(
172
+ input=idx_map[num_triangles == 2], dim=1,
173
+ index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
174
+ ), dim=0)
175
+ if not return_tet_mesh:
176
+ return verts, faces
177
+ occupied_verts = ori_v[occ_n]
178
+ mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1
179
+ mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda")
180
+ tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4))
181
+
182
+ idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10
183
+ tet_verts = torch.cat([verts, occupied_verts], 0)
184
+ num_tets = num_tets_table[tetindex]
185
+
186
+ tets = torch.cat(
187
+ (
188
+ torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape(
189
+ -1,
190
+ 4),
191
+ torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape(
192
+ -1,
193
+ 4),
194
+ ), dim=0)
195
+ # add fully occupied tets
196
+ fully_occupied = occ_fx4.sum(-1) == 4
197
+ tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0]
198
+ tets = torch.cat([tets, tet_fully_occupied])
199
+
200
+ return verts, faces, tet_verts, tets
201
+
202
+
203
+ ###############################################################################
204
+ # Compact tet grid
205
+ ###############################################################################
206
+
207
+ def compact_tets(pos_nx3, sdf_n, tet_fx4):
208
+ with torch.no_grad():
209
+ # Find surface tets
210
+ occ_n = sdf_n > 0
211
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
212
+ occ_sum = torch.sum(occ_fx4, -1)
213
+ valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets
214
+
215
+ valid_vtx = tet_fx4[valid_tets].reshape(-1)
216
+ unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True)
217
+ new_pos = pos_nx3[unique_vtx]
218
+ new_sdf = sdf_n[unique_vtx]
219
+ new_tets = idx_map.reshape(-1, 4)
220
+ return new_pos, new_sdf, new_tets
221
+
222
+
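A quick sketch of compact_tets on a hand-built two-tet grid (assumes a CUDA device); only the surface-crossing tet and the vertices it references survive:

import torch
from src.models.geometry.rep_3d.dmtet import compact_tets

pos = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [2., 2., 2.]], device='cuda')
sdf = torch.tensor([0.5, -0.5, -0.5, -0.5, -0.5], device='cuda')
tets = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long, device='cuda')

new_pos, new_sdf, new_tets = compact_tets(pos, sdf, tets)
# Only the first tet crosses the surface, so vertex 4 is dropped and indices are remapped.
print(new_pos.shape, new_tets)   # (4, 3), tensor([[0, 1, 2, 3]])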
223
+ ###############################################################################
224
+ # Subdivide volume
225
+ ###############################################################################
226
+
227
+ def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
228
+ device = tet_pos_bxnx3.device
229
+ # get new verts
230
+ tet_fx4 = tet_bxfx4[0]
231
+ edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
232
+ all_edges = tet_fx4[:, edges].reshape(-1, 2)
233
+ all_edges = sort_edges(all_edges)
234
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
235
+ idx_map = idx_map + tet_pos_bxnx3.shape[1]
236
+ all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1)
237
+ mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape(
238
+ all_values.shape[0], -1, 2,
239
+ all_values.shape[-1]).mean(2)
240
+ new_v = torch.cat([all_values, mid_points_pos], 1)
241
+ new_v, new_sdf = new_v[..., :3], new_v[..., 3]
242
+
243
+ # get new tets
244
+
245
+ idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3]
246
+ idx_ab = idx_map[0::6]
247
+ idx_ac = idx_map[1::6]
248
+ idx_ad = idx_map[2::6]
249
+ idx_bc = idx_map[3::6]
250
+ idx_bd = idx_map[4::6]
251
+ idx_cd = idx_map[5::6]
252
+
253
+ tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1)
254
+ tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1)
255
+ tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1)
256
+ tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1)
257
+ tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1)
258
+ tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1)
259
+ tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1)
260
+ tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1)
261
+
262
+ tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0)
263
+ tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1)
264
+ tet = tet_np.long().to(device)
265
+
266
+ return new_v, tet, new_sdf
267
+
268
+
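A short sketch of one subdivision step (assumes a CUDA device): each tet is split into 8 children and the SDF is averaged onto the 6 new edge-midpoint vertices:

import torch
from src.models.geometry.rep_3d.dmtet import batch_subdivide_volume

pos = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]], device='cuda')  # (B=1, N=4, 3)
tets = torch.tensor([[[0, 1, 2, 3]]], device='cuda')                                           # (B=1, F=1, 4)
sdf = torch.zeros(1, 4, 1, device='cuda')                                                      # (B=1, N=4, 1)

new_pos, new_tets, new_sdf = batch_subdivide_volume(pos, tets, sdf)
print(new_pos.shape, new_tets.shape)   # 4 original + 6 midpoint verts, 8 child tets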
269
+ ###############################################################################
270
+ # Adjacency
271
+ ###############################################################################
272
+ def tet_to_tet_adj_sparse(tet_tx4):
273
+ # Note: the adjacency includes self-connections.
274
+ with torch.no_grad():
275
+ t = tet_tx4.shape[0]
276
+ device = tet_tx4.device
277
+ idx_array = torch.LongTensor(
278
+ [0, 1, 2,
279
+ 1, 0, 3,
280
+ 2, 3, 0,
281
+ 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3)
282
+
283
+ # get all faces
284
+ all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape(
285
+ -1,
286
+ 3) # (tx4, 3)
287
+ all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1)
288
+ # sort and group
289
+ all_faces_sorted, _ = torch.sort(all_faces, dim=1)
290
+
291
+ all_faces_unique, inverse_indices, counts = torch.unique(
292
+ all_faces_sorted, dim=0, return_counts=True,
293
+ return_inverse=True)
294
+ tet_face_fx3 = all_faces_unique[counts == 2]
295
+ counts = counts[inverse_indices] # tx4
296
+ valid = (counts == 2)
297
+
298
+ group = inverse_indices[valid]
299
+ # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape)
300
+ _, indices = torch.sort(group)
301
+ all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices]
302
+ tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1)
303
+
304
+ tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])])
305
+ adj_self = torch.arange(t, device=tet_tx4.device)
306
+ adj_self = torch.stack([adj_self, adj_self], -1)
307
+ tet_adj_idx = torch.cat([tet_adj_idx, adj_self])
308
+
309
+ tet_adj_idx = torch.unique(tet_adj_idx, dim=0)
310
+ values = torch.ones(
311
+ tet_adj_idx.shape[0], device=tet_tx4.device).float()
312
+ adj_sparse = torch.sparse.FloatTensor(
313
+ tet_adj_idx.t(), values, torch.Size([t, t]))
314
+
315
+ # normalization
316
+ neighbor_num = 1.0 / torch.sparse.sum(
317
+ adj_sparse, dim=1).to_dense()
318
+ values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0])
319
+ adj_sparse = torch.sparse.FloatTensor(
320
+ tet_adj_idx.t(), values, torch.Size([t, t]))
321
+ return adj_sparse
322
+
323
+
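A small sketch of the row-normalized adjacency for two tetrahedra sharing a face (CPU is sufficient here):

import torch
from src.models.geometry.rep_3d.dmtet import tet_to_tet_adj_sparse

tets = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long)
adj = tet_to_tet_adj_sparse(tets)
print(adj.to_dense())   # [[0.5, 0.5], [0.5, 0.5]]: each tet sees itself plus its one neighbor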
324
+ ###############################################################################
325
+ # Compact grid
326
+ ###############################################################################
327
+
328
+ def get_tet_bxfx4x3(bxnxz, bxfx4):
329
+ n_batch, z = bxnxz.shape[0], bxnxz.shape[2]
330
+ gather_input = bxnxz.unsqueeze(2).expand(
331
+ n_batch, bxnxz.shape[1], 4, z)
332
+ gather_index = bxfx4.unsqueeze(-1).expand(
333
+ n_batch, bxfx4.shape[1], 4, z).long()
334
+ tet_bxfx4xz = torch.gather(
335
+ input=gather_input, dim=1, index=gather_index)
336
+
337
+ return tet_bxfx4xz
338
+
339
+
340
+ def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
341
+ with torch.no_grad():
342
+ assert tet_pos_bxnx3.shape[0] == 1
343
+
344
+ occ = grid_sdf[0] > 0
345
+ occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1)
346
+ mask = (occ_sum > 0) & (occ_sum < 4)
347
+
348
+ # build connectivity graph
349
+ adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0])
350
+ mask = mask.float().unsqueeze(-1)
351
+
352
+ # Include a one ring of neighbors
353
+ for i in range(1):
354
+ mask = torch.sparse.mm(adj_matrix, mask)
355
+ mask = mask.squeeze(-1) > 0
356
+
357
+ mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long)
358
+ new_tet_bxfx4 = tet_bxfx4[:, mask].long()
359
+ selected_verts_idx = torch.unique(new_tet_bxfx4)
360
+ new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx]
361
+ mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device)
362
+ new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape)
363
+ new_grid_sdf = grid_sdf[:, selected_verts_idx]
364
+ return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf
365
+
366
+
367
+ ###############################################################################
368
+ # Regularizer
369
+ ###############################################################################
370
+
371
+ def sdf_reg_loss(sdf, all_edges):
372
+ sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2)
373
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
374
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
375
+ sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(
376
+ sdf_f1x6x2[..., 0],
377
+ (sdf_f1x6x2[..., 1] > 0).float()) + \
378
+ torch.nn.functional.binary_cross_entropy_with_logits(
379
+ sdf_f1x6x2[..., 1],
380
+ (sdf_f1x6x2[..., 0] > 0).float())
381
+ return sdf_diff
382
+
383
+
384
+ def sdf_reg_loss_batch(sdf, all_edges):
385
+ sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
386
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
387
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
388
+ sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
389
+ torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
390
+ return sdf_diff
391
+
392
+
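A short sketch of the batched regularizer: it applies a BCE-style penalty to grid edges whose two SDF samples disagree in sign, and is typically added to the training loss with a small weight. The tiny edge list and the 0.01 weight below are illustrative; DMTetGeometry below precomputes the real unique edge list as geo.all_edges:

import torch
from src.models.geometry.rep_3d.dmtet import sdf_reg_loss_batch

sdf_batch = torch.tensor([[0.3, -0.2, 0.1, 0.4],
                          [-0.5, -0.1, 0.2, 0.3]])          # (B, num_verts)
all_edges = torch.tensor([[0, 1], [1, 2], [2, 3]])          # unique grid edges
reg = sdf_reg_loss_batch(sdf_batch, all_edges)              # scalar penalty over sign-flipping edges
loss = 0.01 * reg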
393
+ ###############################################################################
394
+ # Geometry interface
395
+ ###############################################################################
396
+ class DMTetGeometry(Geometry):
397
+ def __init__(
398
+ self, grid_res=64, scale=2.0, device='cuda', renderer=None,
399
+ render_type='neural_render', args=None):
400
+ super(DMTetGeometry, self).__init__()
401
+ self.grid_res = grid_res
402
+ self.device = device
403
+ self.args = args
404
+ tets = np.load('data/tets/%d_compress.npz' % (grid_res))
405
+ self.verts = torch.from_numpy(tets['vertices']).float().to(self.device)
406
+ # Make sure the tet grid is zero-centered and its longest side has length 1
407
+ length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0]
408
+ length = length.max()
409
+ mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0
410
+ self.verts = (self.verts - mid.unsqueeze(dim=0)) / length
411
+ if isinstance(scale, list):
412
+ self.verts[:, 0] = self.verts[:, 0] * scale[0]
413
+ self.verts[:, 1] = self.verts[:, 1] * scale[1]
414
+ self.verts[:, 2] = self.verts[:, 2] * scale[1]
415
+ else:
416
+ self.verts = self.verts * scale
417
+ self.indices = torch.from_numpy(tets['tets']).long().to(self.device)
418
+ self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device)
419
+ self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device)
420
+ # Parameters for regularization computation
421
+ edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device)
422
+ all_edges = self.indices[:, edges].reshape(-1, 2)
423
+ all_edges_sorted = torch.sort(all_edges, dim=1)[0]
424
+ self.all_edges = torch.unique(all_edges_sorted, dim=0)
425
+
426
+ # Parameters used for fix boundary sdf
427
+ self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts)
428
+ self.renderer = renderer
429
+ self.render_type = render_type
430
+
431
+ def getAABB(self):
432
+ return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
433
+
434
+ def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
435
+ if indices is None:
436
+ indices = self.indices
437
+ verts, faces = marching_tets(
438
+ v_deformed_nx3, sdf_n, indices, self.triangle_table,
439
+ self.num_triangles_table, self.base_tet_edges, self.v_id)
440
+ faces = torch.cat(
441
+ [faces[:, 0:1],
442
+ faces[:, 2:3],
443
+ faces[:, 1:2], ], dim=-1)
444
+ return verts, faces
445
+
446
+ def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
447
+ if indices is None:
448
+ indices = self.indices
449
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(
450
+ v_deformed_nx3, sdf_n, indices, self.triangle_table,
451
+ self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True,
452
+ num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3)
453
+ faces = torch.cat(
454
+ [faces[:, 0:1],
455
+ faces[:, 2:3],
456
+ faces[:, 1:2], ], dim=-1)
457
+ return verts, faces, tet_verts, tets
458
+
459
+ def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
460
+ return_value = dict()
461
+ if self.render_type == 'neural_render':
462
+ tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
463
+ mesh_v_nx3.unsqueeze(dim=0),
464
+ mesh_f_fx3.int(),
465
+ camera_mv_bx4x4,
466
+ mesh_v_nx3.unsqueeze(dim=0),
467
+ resolution=resolution,
468
+ device=self.device,
469
+ hierarchical_mask=hierarchical_mask
470
+ )
471
+
472
+ return_value['tex_pos'] = tex_pos
473
+ return_value['mask'] = mask
474
+ return_value['hard_mask'] = hard_mask
475
+ return_value['rast'] = rast
476
+ return_value['v_pos_clip'] = v_pos_clip
477
+ return_value['mask_pyramid'] = mask_pyramid
478
+ return_value['depth'] = depth
479
+ else:
480
+ raise NotImplementedError
481
+
482
+ return return_value
483
+
484
+ def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
485
+ # Here we assume a batch of meshes (entries may have different geometry); for the other shapes the batch size is 1
486
+ v_list = []
487
+ f_list = []
488
+ n_batch = v_deformed_bxnx3.shape[0]
489
+ all_render_output = []
490
+ for i_batch in range(n_batch):
491
+ verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
492
+ v_list.append(verts_nx3)
493
+ f_list.append(faces_fx3)
494
+ render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
495
+ all_render_output.append(render_output)
496
+
497
+ # Concatenate all render output
498
+ return_keys = all_render_output[0].keys()
499
+ return_value = dict()
500
+ for k in return_keys:
501
+ value = [v[k] for v in all_render_output]
502
+ return_value[k] = value
503
+ # We can do concatenation outside of the render
504
+ return return_value
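A minimal end-to-end sketch for the geometry interface above (assumptions: a CUDA device, the packaged tet grid data/tets/64_compress.npz on disk, and no renderer since only the mesh is extracted):

import torch
from src.models.geometry.rep_3d.dmtet import DMTetGeometry

geo = DMTetGeometry(grid_res=64, scale=2.0, device='cuda')
sdf = 0.3 - geo.verts.norm(dim=-1)              # positive inside a sphere, matching the sdf > 0 occupancy convention
verts, faces = geo.get_mesh(geo.verts, sdf)     # use the undeformed grid vertices
print(verts.shape, faces.shape)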
src/models/geometry/rep_3d/dmtet_utils.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+
11
+
12
+ def get_center_boundary_index(verts):
13
+ length_ = torch.sum(verts ** 2, dim=-1)
14
+ center_idx = torch.argmin(length_)
15
+ boundary_neg = verts == verts.max()
16
+ boundary_pos = verts == verts.min()
17
+ boundary = torch.bitwise_or(boundary_pos, boundary_neg)
18
+ boundary = torch.sum(boundary.float(), dim=-1)
19
+ boundary_idx = torch.nonzero(boundary)
20
+ return center_idx, boundary_idx.squeeze(dim=-1)
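A tiny sketch of the helper above on a handful of vertices:

import torch
from src.models.geometry.rep_3d.dmtet_utils import get_center_boundary_index

verts = torch.tensor([[-0.5, -0.5, -0.5],
                      [ 0.5, -0.5, -0.5],
                      [ 0.0,  0.0,  0.0],
                      [ 0.5,  0.5,  0.5]])
center_idx, boundary_idx = get_center_boundary_index(verts)
# center_idx == 2 (closest to the origin); boundary_idx == [0, 1, 3] (vertices touching the min/max coordinate)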
src/models/geometry/rep_3d/extract_texture_map.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import xatlas
11
+ import numpy as np
12
+ import nvdiffrast.torch as dr
13
+
14
+
15
+ # ==============================================================================================
16
+ def interpolate(attr, rast, attr_idx, rast_db=None):
17
+ return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
18
+
19
+
20
+ def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
21
+ vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
22
+
23
+ # Convert to tensors
24
+ indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
25
+
26
+ uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
27
+ mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
28
+ # Map UVs from [0, 1] to clip space [-1, 1]
29
+ uv_clip = uvs[None, ...] * 2.0 - 1.0
30
+
31
+ # pad to four component coordinate
32
+ uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
33
+
34
+ # rasterize
35
+ rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
36
+
37
+ # Interpolate world space position
38
+ gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
39
+ mask = rast[..., 3:4] > 0
40
+ return uvs, mesh_tex_idx, gb_pos, mask
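A hedged sketch of the typical texture-baking call around xatlas_uvmap (assumptions: a CUDA device with nvdiffrast and xatlas installed; the single-triangle mesh below is only a stand-in for a mesh produced by the geometry classes in this directory):

import torch
import nvdiffrast.torch as dr
from src.models.geometry.rep_3d.extract_texture_map import xatlas_uvmap

mesh_v = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]], device='cuda')
mesh_f = torch.tensor([[0, 1, 2]], dtype=torch.long, device='cuda')

ctx = dr.RasterizeCudaContext()
uvs, mesh_tex_idx, gb_pos, mask = xatlas_uvmap(ctx, mesh_v, mesh_f, resolution=512)
# gb_pos holds per-texel world-space positions; sample a color field at gb_pos[mask]
# to bake an albedo map, then export uvs / mesh_tex_idx together with the mesh.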
src/models/geometry/rep_3d/flexicubes.py ADDED
@@ -0,0 +1,579 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import torch
9
+ from .tables import *
10
+
11
+ __all__ = [
12
+ 'FlexiCubes'
13
+ ]
14
+
15
+
16
+ class FlexiCubes:
17
+ """
18
+ This class implements the FlexiCubes method for extracting meshes from scalar fields.
19
+ It maintains a series of lookup tables and indices to support the mesh extraction process.
20
+ FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
21
+ the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
22
+ the surface representation through gradient-based optimization.
23
+
24
+ During instantiation, the class loads DMC tables from a file and transforms them into
25
+ PyTorch tensors on the specified device.
26
+
27
+ Attributes:
28
+ device (str): Specifies the computational device (default is "cuda").
29
+ dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
30
+ associated with each dual vertex in 256 Marching Cubes (MC) configurations.
31
+ num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
32
+ the 256 MC configurations.
33
+ check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
34
+ of the DMC configurations.
35
+ tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
36
+ quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
37
+ along one diagonal.
38
+ quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
39
+ two triangles along the other diagonal.
40
+ quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
41
+ during training by connecting all edges to their midpoints.
42
+ cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
43
+ eight corners in 3D space, ordered starting from the origin (0,0,0),
44
+ moving along the x-axis, then y-axis, and finally z-axis.
45
+ Used as a blueprint for generating a voxel grid.
46
+ cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
47
+ to retrieve the case id.
48
+ cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
49
+ Used to retrieve edge vertices in DMC.
50
+ edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
51
+ their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
52
+ first edge is oriented along the x-axis.
53
+ dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
54
+ across four adjacent cubes to the shared faces of these cubes. For instance,
55
+ dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
56
+ the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
57
+ This tensor is only utilized during isosurface tetrahedralization.
58
+ adj_pairs (torch.Tensor):
59
+ A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
60
+ qef_reg_scale (float):
61
+ The scaling factor applied to the regularization loss to prevent issues with singularity
62
+ when solving the QEF. This parameter is only used when a 'grad_func' is specified.
63
+ weight_scale (float):
64
+ The scale of weights in FlexiCubes. Should be between 0 and 1.
65
+ """
66
+
67
+ def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
68
+
69
+ self.device = device
70
+ self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
71
+ self.num_vd_table = torch.tensor(num_vd_table,
72
+ dtype=torch.long, device=device, requires_grad=False)
73
+ self.check_table = torch.tensor(
74
+ check_table,
75
+ dtype=torch.long, device=device, requires_grad=False)
76
+
77
+ self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
78
+ self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
79
+ self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
80
+ self.quad_split_train = torch.tensor(
81
+ [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
82
+
83
+ self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
84
+ 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
85
+ self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
86
+ self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
87
+ 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
88
+
89
+ self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
90
+ dtype=torch.long, device=device)
91
+ self.dir_faces_table = torch.tensor([
92
+ [[5, 4], [3, 2], [4, 5], [2, 3]],
93
+ [[5, 4], [1, 0], [4, 5], [0, 1]],
94
+ [[3, 2], [1, 0], [2, 3], [0, 1]]
95
+ ], dtype=torch.long, device=device)
96
+ self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
97
+ self.qef_reg_scale = qef_reg_scale
98
+ self.weight_scale = weight_scale
99
+
100
+ def construct_voxel_grid(self, res):
101
+ """
102
+ Generates a voxel grid based on the specified resolution.
103
+
104
+ Args:
105
+ res (int or list[int]): The resolution of the voxel grid. If an integer
106
+ is provided, it is used for all three dimensions. If a list or tuple
107
+ of 3 integers is provided, they define the resolution for the x,
108
+ y, and z dimensions respectively.
109
+
110
+ Returns:
111
+ (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
112
+ cube corners (index into vertices) of the constructed voxel grid.
113
+ The vertices are centered at the origin, with the length of each
114
+ dimension in the grid being one.
115
+ """
116
+ base_cube_f = torch.arange(8).to(self.device)
117
+ if isinstance(res, int):
118
+ res = (res, res, res)
119
+ voxel_grid_template = torch.ones(res, device=self.device)
120
+
121
+ res = torch.tensor([res], dtype=torch.float, device=self.device)
122
+ coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
123
+ verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
124
+ cubes = (base_cube_f.unsqueeze(0) +
125
+ torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
126
+
127
+ verts_rounded = torch.round(verts * 10**5) / (10**5)
128
+ verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
129
+ cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
130
+
131
+ return verts_unique - 0.5, cubes
132
+
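A quick sketch of grid construction (assumes a CUDA device):

import torch
from src.models.geometry.rep_3d.flexicubes import FlexiCubes

fc = FlexiCubes(device='cuda')
x_nx3, cube_fx8 = fc.construct_voxel_grid(32)
print(x_nx3.shape, cube_fx8.shape)   # (33**3, 3) unique corner vertices, (32**3, 8) cubes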
133
+ def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
134
+ gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
135
+ r"""
136
+ Main function for mesh extraction from scalar field using FlexiCubes. This function converts
137
+ discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
138
+ to triangle or tetrahedral meshes using a differentiable operation as described in
139
+ `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
140
+ mesh quality and geometric fidelity by adjusting the surface representation based on gradient
141
+ optimization. The output surface is differentiable with respect to the input vertex positions,
142
+ scalar field values, and weight parameters.
143
+
144
+ If you intend to extract a surface mesh from a fixed Signed Distance Field without the
145
+ optimization of parameters, it is suggested to provide the "grad_func" which should
146
+ return the surface gradient at any given 3D position. When grad_func is provided, the process
147
+ to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
148
+ described in the `Manifold Dual Contouring`_ paper, and employs a smart splitting strategy.
149
+ Please note, this approach is non-differentiable.
150
+
151
+ For more details and example usage in optimization, refer to the
152
+ `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
153
+
154
+ Args:
155
+ x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
156
+ s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
157
+ denote that the corresponding vertex resides inside the isosurface. This affects
158
+ the directions of the extracted triangle faces and volume to be tetrahedralized.
159
+ cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
160
+ res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
161
+ is used for all three dimensions. If a list or tuple of 3 integers is provided, they
162
+ specify the resolution for the x, y, and z dimensions respectively.
163
+ beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
164
+ vertices positioning. Defaults to uniform value for all edges.
165
+ alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
166
+ vertices positioning. Defaults to uniform value for all vertices.
167
+ gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
168
+ quadrilaterals into triangles. Defaults to uniform value for all cubes.
169
+ training (bool, optional): If set to True, applies differentiable quad splitting for
170
+ training. Defaults to False.
171
+ output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
172
+ outputs a triangular mesh. Defaults to False.
173
+ grad_func (callable, optional): A function to compute the surface gradient at specified
174
+ 3D positions (input: Nx3 positions). The function should return gradients as an Nx3
175
+ tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
176
+
177
+ Returns:
178
+ (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
179
+ - Vertices for the extracted triangular/tetrahedral mesh.
180
+ - Faces for the extracted triangular/tetrahedral mesh.
181
+ - Regularizer L_dev, computed per dual vertex.
182
+
183
+ .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
184
+ https://research.nvidia.com/labs/toronto-ai/flexicubes/
185
+ .. _Manifold Dual Contouring:
186
+ https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
187
+ """
188
+
189
+ surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
190
+ if surf_cubes.sum() == 0:
191
+ return torch.zeros(
192
+ (0, 3),
193
+ device=self.device), torch.zeros(
194
+ (0, 4),
195
+ dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
196
+ (0, 3),
197
+ dtype=torch.long, device=self.device), torch.zeros(
198
+ (0),
199
+ device=self.device)
200
+ beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
201
+
202
+ case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
203
+
204
+ surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
205
+
206
+ vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
207
+ x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
208
+ vertices, faces, s_edges, edge_indices = self._triangulate(
209
+ s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
210
+ if not output_tetmesh:
211
+ return vertices, faces, L_dev
212
+ else:
213
+ vertices, tets = self._tetrahedralize(
214
+ x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
215
+ surf_cubes, training)
216
+ return vertices, tets, L_dev
217
+
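A minimal end-to-end sketch of the call above (assumes a CUDA device); note that, unlike the DMTet code earlier in this directory, negative SDF values mark the inside here:

import torch
from src.models.geometry.rep_3d.flexicubes import FlexiCubes

fc = FlexiCubes(device='cuda')
x_nx3, cube_fx8 = fc.construct_voxel_grid(32)
s_n = x_nx3.norm(dim=-1) - 0.35        # negative inside a sphere of radius 0.35
verts, faces, L_dev = fc(x_nx3, s_n, cube_fx8, 32)
print(verts.shape, faces.shape)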
218
+ def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
219
+ """
220
+ Regularizer L_dev as in Equation 8
221
+ """
222
+ dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
223
+ mean_l2 = torch.zeros_like(vd[:, 0])
224
+ mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
225
+ mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
226
+ return mad
227
+
228
+ def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
229
+ """
230
+ Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
231
+ """
232
+ n_cubes = surf_cubes.shape[0]
233
+
234
+ if beta_fx12 is not None:
235
+ beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
236
+ else:
237
+ beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
238
+
239
+ if alpha_fx8 is not None:
240
+ alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
241
+ else:
242
+ alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
243
+
244
+ if gamma_f is not None:
245
+ gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
246
+ else:
247
+ gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
248
+
249
+ return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
250
+
251
+ @torch.no_grad()
252
+ def _get_case_id(self, occ_fx8, surf_cubes, res):
253
+ """
254
+ Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
255
+ ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
256
+ supplementary material. It should be noted that this function assumes a regular grid.
257
+ """
258
+ case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
259
+
260
+ problem_config = self.check_table.to(self.device)[case_ids]
261
+ to_check = problem_config[..., 0] == 1
262
+ problem_config = problem_config[to_check]
263
+ if not isinstance(res, (list, tuple)):
264
+ res = [res, res, res]
265
+
266
+ # The 'problem_config' entries only contain configurations for surface cubes. Next, we construct a 3D array,
267
+ # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
268
+ # This allows efficient checking on adjacent cubes.
269
+ problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
270
+ vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
271
+ vol_idx_problem = vol_idx[surf_cubes][to_check]
272
+ problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
273
+ vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
274
+
275
+ within_range = (
276
+ vol_idx_problem_adj[..., 0] >= 0) & (
277
+ vol_idx_problem_adj[..., 0] < res[0]) & (
278
+ vol_idx_problem_adj[..., 1] >= 0) & (
279
+ vol_idx_problem_adj[..., 1] < res[1]) & (
280
+ vol_idx_problem_adj[..., 2] >= 0) & (
281
+ vol_idx_problem_adj[..., 2] < res[2])
282
+
283
+ vol_idx_problem = vol_idx_problem[within_range]
284
+ vol_idx_problem_adj = vol_idx_problem_adj[within_range]
285
+ problem_config = problem_config[within_range]
286
+ problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
287
+ vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
288
+ # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
289
+ to_invert = (problem_config_adj[..., 0] == 1)
290
+ idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
291
+ case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
292
+ return case_ids
293
+
294
+ @torch.no_grad()
295
+ def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
296
+ """
297
+ Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
298
+ can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
299
+ and marks the cube edges with this index.
300
+ """
301
+ occ_n = s_n < 0
302
+ all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
303
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
304
+
305
+ unique_edges = unique_edges.long()
306
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
307
+
308
+ surf_edges_mask = mask_edges[_idx_map]
309
+ counts = counts[_idx_map]
310
+
311
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
312
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
313
+ # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
314
+ # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
315
+ idx_map = mapping[_idx_map]
316
+ surf_edges = unique_edges[mask_edges]
317
+ return surf_edges, idx_map, counts, surf_edges_mask
318
+
319
+ @torch.no_grad()
320
+ def _identify_surf_cubes(self, s_n, cube_fx8):
321
+ """
322
+ Identifies grid cubes that intersect with the underlying surface by checking if the signs at
323
+ all corners are not identical.
324
+ """
325
+ occ_n = s_n < 0
326
+ occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
327
+ _occ_sum = torch.sum(occ_fx8, -1)
328
+ surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
329
+ return surf_cubes, occ_fx8
330
+
331
+ def _linear_interp(self, edges_weight, edges_x):
332
+ """
333
+ Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
334
+ """
335
+ edge_dim = edges_weight.dim() - 2
336
+ assert edges_weight.shape[edge_dim] == 2
337
+ edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
338
+ torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
339
+ denominator = edges_weight.sum(edge_dim)
340
+ ue = (edges_x * edges_weight).sum(edge_dim) / denominator
341
+ return ue
342
+
343
+ def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
344
+ p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
345
+ norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
346
+ c_bx3 = c_bx3.reshape(-1, 3)
347
+ A = norm_bxnx3
348
+ B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
349
+
350
+ A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
351
+ B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
352
+ A = torch.cat([A, A_reg], 1)
353
+ B = torch.cat([B, B_reg], 1)
354
+ dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
355
+ return dual_verts
356
+
357
+ def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
358
+ """
359
+ Computes the location of dual vertices as described in Section 4.2
360
+ """
361
+ alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
362
+ surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
363
+ surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
364
+ zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
365
+
366
+ idx_map = idx_map.reshape(-1, 12)
367
+ num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
368
+ edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
369
+
370
+ total_num_vd = 0
371
+ vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
372
+ if grad_func is not None:
373
+ normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
374
+ vd = []
375
+ for num in torch.unique(num_vd):
376
+ cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
377
+ curr_num_vd = cur_cubes.sum() * num
378
+ curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
379
+ curr_edge_group_to_vd = torch.arange(
380
+ curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
381
+ total_num_vd += curr_num_vd
382
+ curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
383
+ cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
384
+
385
+ curr_mask = (curr_edge_group != -1)
386
+ edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
387
+ edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
388
+ edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
389
+ vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
390
+ vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
391
+
392
+ if grad_func is not None:
393
+ with torch.no_grad():
394
+ cube_e_verts_idx = idx_map[cur_cubes]
395
+ curr_edge_group[~curr_mask] = 0
396
+
397
+ verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
398
+ verts_group_idx[verts_group_idx == -1] = 0
399
+ verts_group_pos = torch.index_select(
400
+ input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
401
+ v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
402
+ curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
403
+ verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
404
+
405
+ normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
406
+ -1, num.item(), 7,
407
+ 3)
408
+ curr_mask = curr_mask.squeeze(2)
409
+ vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
410
+ verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
411
+ edge_group = torch.cat(edge_group)
412
+ edge_group_to_vd = torch.cat(edge_group_to_vd)
413
+ edge_group_to_cube = torch.cat(edge_group_to_cube)
414
+ vd_num_edges = torch.cat(vd_num_edges)
415
+ vd_gamma = torch.cat(vd_gamma)
416
+
417
+ if grad_func is not None:
418
+ vd = torch.cat(vd)
419
+ L_dev = torch.zeros([1], device=self.device)
420
+ else:
421
+ vd = torch.zeros((total_num_vd, 3), device=self.device)
422
+ beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
423
+
424
+ idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
425
+
426
+ x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
427
+ s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
428
+
429
+ zero_crossing_group = torch.index_select(
430
+ input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
431
+
432
+ alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
433
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
434
+ ue_group = self._linear_interp(s_group * alpha_group, x_group)
435
+
436
+ beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
437
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
438
+ beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
439
+ vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
440
+ L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
441
+
442
+ v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
443
+
444
+ vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
445
+ 12 + edge_group, src=v_idx[edge_group_to_vd])
446
+
447
+ return vd, L_dev, vd_gamma, vd_idx_map
448
+
449
+ def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
450
+ """
451
+ Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
452
+ triangles based on the gamma parameter, as described in Section 4.3.
453
+ """
454
+ with torch.no_grad():
455
+ group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
456
+ group = idx_map.reshape(-1)[group_mask]
457
+ vd_idx = vd_idx_map[group_mask]
458
+ edge_indices, indices = torch.sort(group, stable=True)
459
+ quad_vd_idx = vd_idx[indices].reshape(-1, 4)
460
+
461
+ # Ensure all face directions point towards the positive SDF to maintain consistent winding.
462
+ s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
463
+ flip_mask = s_edges[:, 0] > 0
464
+ quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
465
+ quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
466
+ if grad_func is not None:
467
+ # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
468
+ with torch.no_grad():
469
+ vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
470
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
471
+ gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
472
+ gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
473
+ else:
474
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
475
+ gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
476
+ 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
477
+ gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
478
+ 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
479
+ if not training:
480
+ mask = (gamma_02 > gamma_13).squeeze(1)
481
+ faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
482
+ faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
483
+ faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
484
+ faces = faces.reshape(-1, 3)
485
+ else:
486
+ vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
487
+ vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
488
+ torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
489
+ vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
490
+ torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
491
+ weight_sum = (gamma_02 + gamma_13) + 1e-8
492
+ vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
493
+ weight_sum.unsqueeze(-1)).squeeze(1)
494
+ vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
495
+ vd = torch.cat([vd, vd_center])
496
+ faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
497
+ faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
498
+ return vd, faces, s_edges, edge_indices
499
+
500
+ def _tetrahedralize(
501
+ self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
502
+ surf_cubes, training):
503
+ """
504
+ Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
505
+ """
506
+ occ_n = s_n < 0
507
+ occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
508
+ occ_sum = torch.sum(occ_fx8, -1)
509
+
510
+ inside_verts = x_nx3[occ_n]
511
+ mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
512
+ mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
513
+ """
514
+ For each grid edge connecting two grid vertices with different
515
+ signs, we first form a four-sided pyramid by connecting one
516
+ of the grid vertices with four mesh vertices that correspond
517
+ to the grid edge and then subdivide the pyramid into two tetrahedra
518
+ """
519
+ inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
520
+ s_edges < 0]]
521
+ if not training:
522
+ inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
523
+ else:
524
+ inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
525
+
526
+ tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
527
+ """
528
+ For each grid edge connecting two grid vertices with the
529
+ same sign, the tetrahedron is formed by the two grid vertices
530
+ and two vertices in consecutive adjacent cells
531
+ """
532
+ inside_cubes = (occ_sum == 8)
533
+ inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
534
+ inside_cubes_center_idx = torch.arange(
535
+ inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
536
+
537
+ surface_n_inside_cubes = surf_cubes | inside_cubes
538
+ edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
539
+ dtype=torch.long, device=x_nx3.device) * -1
540
+ surf_cubes = surf_cubes[surface_n_inside_cubes]
541
+ inside_cubes = inside_cubes[surface_n_inside_cubes]
542
+ edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
543
+ edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
544
+
545
+ all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
546
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
547
+ unique_edges = unique_edges.long()
548
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
549
+ mask = mask_edges[_idx_map]
550
+ counts = counts[_idx_map]
551
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
552
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
553
+ idx_map = mapping[_idx_map]
554
+
555
+ group_mask = (counts == 4) & mask
556
+ group = idx_map.reshape(-1)[group_mask]
557
+ edge_indices, indices = torch.sort(group)
558
+ cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
559
+ device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
560
+ edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
561
+ 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
562
+ # Identify the face shared by the adjacent cells.
563
+ cube_idx_4 = cube_idx[indices].reshape(-1, 4)
564
+ edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
565
+ shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
566
+ cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
567
+ # Identify an edge of the face with different signs and
568
+ # select the mesh vertex corresponding to the identified edge.
569
+ case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
570
+ case_ids_expand[surf_cubes] = case_ids
571
+ cases = case_ids_expand[cube_idx_4x2]
572
+ quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
573
+ mask = (quad_edge == -1).sum(-1) == 0
574
+ inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
575
+ tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
576
+
577
+ tets = torch.cat([tets_surface, tets_inside])
578
+ vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
579
+ return vertices, tets
src/models/geometry/rep_3d/flexicubes_geometry.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import numpy as np
11
+ import os
12
+ from . import Geometry
13
+ from .flexicubes import FlexiCubes # replace later
14
+ from .dmtet import sdf_reg_loss_batch
15
+ import torch.nn.functional as F
16
+
17
+ def get_center_boundary_index(grid_res, device):
18
+ v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
19
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
20
+ center_indices = torch.nonzero(v.reshape(-1))
21
+
22
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
23
+ v[:2, ...] = True
24
+ v[-2:, ...] = True
25
+ v[:, :2, ...] = True
26
+ v[:, -2:, ...] = True
27
+ v[:, :, :2] = True
28
+ v[:, :, -2:] = True
29
+ boundary_indices = torch.nonzero(v.reshape(-1))
30
+ return center_indices, boundary_indices
31
+
32
+ ###############################################################################
33
+ # Geometry interface
34
+ ###############################################################################
35
+ class FlexiCubesGeometry(Geometry):
36
+ def __init__(
37
+ self, grid_res=64, scale=2.0, device='cuda', renderer=None,
38
+ render_type='neural_render', args=None):
39
+ super(FlexiCubesGeometry, self).__init__()
40
+ self.grid_res = grid_res
41
+ self.device = device
42
+ self.args = args
43
+ self.fc = FlexiCubes(device, weight_scale=0.5)
44
+ self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
45
+ if isinstance(scale, list):
46
+ self.verts[:, 0] = self.verts[:, 0] * scale[0]
47
+ self.verts[:, 1] = self.verts[:, 1] * scale[1]
48
+ self.verts[:, 2] = self.verts[:, 2] * scale[1]
49
+ else:
50
+ self.verts = self.verts * scale
51
+
52
+ all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
53
+ self.all_edges = torch.unique(all_edges, dim=0)
54
+
55
+ # Parameters used for fix boundary sdf
56
+ self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
57
+ self.renderer = renderer
58
+ self.render_type = render_type
59
+
60
+ def getAABB(self):
61
+ return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
62
+
63
+ def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
64
+ if indices is None:
65
+ indices = self.indices
66
+
67
+ verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
68
+ beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
69
+ gamma_f=weight_n[:, 20], training=is_training
70
+ )
71
+ return verts, faces, v_reg_loss
72
+
73
+
74
+ def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
75
+ return_value = dict()
76
+ if self.render_type == 'neural_render':
77
+ tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh(
78
+ mesh_v_nx3.unsqueeze(dim=0),
79
+ mesh_f_fx3.int(),
80
+ camera_mv_bx4x4,
81
+ mesh_v_nx3.unsqueeze(dim=0),
82
+ resolution=resolution,
83
+ device=self.device,
84
+ hierarchical_mask=hierarchical_mask
85
+ )
86
+
87
+ return_value['tex_pos'] = tex_pos
88
+ return_value['mask'] = mask
89
+ return_value['hard_mask'] = hard_mask
90
+ return_value['rast'] = rast
91
+ return_value['v_pos_clip'] = v_pos_clip
92
+ return_value['mask_pyramid'] = mask_pyramid
93
+ return_value['depth'] = depth
94
+ return_value['normal'] = normal
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ return return_value
99
+
100
+ def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
101
+ # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
102
+ v_list = []
103
+ f_list = []
104
+ n_batch = v_deformed_bxnx3.shape[0]
105
+ all_render_output = []
106
+ for i_batch in range(n_batch):
107
+ verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
108
+ v_list.append(verts_nx3)
109
+ f_list.append(faces_fx3)
110
+ render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
111
+ all_render_output.append(render_output)
112
+
113
+ # Concatenate all render output
114
+ return_keys = all_render_output[0].keys()
115
+ return_value = dict()
116
+ for k in return_keys:
117
+ value = [v[k] for v in all_render_output]
118
+ return_value[k] = value
119
+ # We can do concatenation outside of the render
120
+ return return_value
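A minimal sketch of the geometry wrapper above (assumptions: a CUDA device; the zero tensor below is only a neutral initialization of the 12 + 8 + 1 per-cube FlexiCubes weights):

import torch
from src.models.geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry

geo = FlexiCubesGeometry(grid_res=64, scale=2.0, device='cuda')
sdf = geo.verts.norm(dim=-1) - 0.4                                  # negative inside
weights = torch.zeros(geo.indices.shape[0], 21, device='cuda')      # 12 beta + 8 alpha + 1 gamma per cube
verts, faces, reg = geo.get_mesh(geo.verts, sdf, weight_n=weights)
print(verts.shape, faces.shape)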
src/models/geometry/rep_3d/tables.py ADDED
@@ -0,0 +1,791 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ dmc_table = [
9
+ [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
10
+ [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
11
+ [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
12
+ [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
13
+ [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
14
+ [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
15
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
16
+ [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
17
+ [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
18
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
19
+ [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
20
+ [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
21
+ [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
22
+ [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
23
+ [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
24
+ [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
25
+ [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
26
+ [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
27
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
28
+ [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
29
+ [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
30
+ [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
31
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
32
+ [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
33
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
34
+ [[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
35
+ [[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
36
+ [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
37
+ [[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
38
+ [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
39
+ [[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
40
+ [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
41
+ [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
42
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
43
+ [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
44
+ [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
45
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
46
+ [[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
47
+ [[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
48
+ [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
49
+ [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
50
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
51
+ [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
52
+ [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
53
+ [[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
54
+ [[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
55
+ [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
56
+ [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
57
+ [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
58
+ [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
59
+ [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
60
+ [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
61
+ [[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
62
+ [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
63
+ [[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
64
+ [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
65
+ [[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
66
+ [[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
67
+ [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
68
+ [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
69
+ [[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
70
+ [[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
71
+ [[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
72
+ [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
73
+ [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
74
+ [[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
75
+ [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
76
+ [[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
77
+ [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
78
+ [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
79
+ [[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
80
+ [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
81
+ [[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
82
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
83
+ [[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
84
+ [[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
85
+ [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
86
+ [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
87
+ [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
88
+ [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
89
+ [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
90
+ [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
91
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
92
+ [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
93
+ [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
94
+ [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
95
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
96
+ [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
97
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
98
+ [[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
99
+ [[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
100
+ [[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
101
+ [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
102
+ [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
103
+ [[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
104
+ [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
105
+ [[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
106
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
107
+ [[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
108
+ [[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
109
+ [[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
110
+ [[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
111
+ [[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
112
+ [[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
113
+ [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
114
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
115
+ [[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
116
+ [[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
117
+ [[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
118
+ [[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
119
+ [[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
120
+ [[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
121
+ [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
122
+ [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
123
+ [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
124
+ [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
125
+ [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
126
+ [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
127
+ [[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
128
+ [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
129
+ [[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
130
+ [[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
131
+ [[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
132
+ [[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
133
+ [[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
134
+ [[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
135
+ [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
136
+ [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
137
+ [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
138
+ [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
139
+ [[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
140
+ [[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
141
+ [[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
142
+ [[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
143
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
144
+ [[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
145
+ [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
146
+ [[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
147
+ [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
148
+ [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
149
+ [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
150
+ [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
151
+ [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
152
+ [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
153
+ [[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
154
+ [[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
155
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
156
+ [[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
157
+ [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
158
+ [[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
159
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
160
+ [[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
161
+ [[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
162
+ [[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
163
+ [[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
164
+ [[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
165
+ [[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
166
+ [[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
167
+ [[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
168
+ [[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
169
+ [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
170
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
171
+ [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
172
+ [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
173
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
174
+ [[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
175
+ [[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
176
+ [[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
177
+ [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
178
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
179
+ [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
180
+ [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
181
+ [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
182
+ [[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
183
+ [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
184
+ [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
185
+ [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
186
+ [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
187
+ [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
188
+ [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
189
+ [[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
190
+ [[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
191
+ [[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
192
+ [[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
193
+ [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
194
+ [[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
195
+ [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
196
+ [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
197
+ [[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
198
+ [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
199
+ [[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
200
+ [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
201
+ [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
202
+ [[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
203
+ [[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
204
+ [[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
205
+ [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
206
+ [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
207
+ [[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
208
+ [[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
209
+ [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
210
+ [[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
211
+ [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
212
+ [[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
213
+ [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
214
+ [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
215
+ [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
216
+ [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
217
+ [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
218
+ [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
219
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
220
+ [[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
221
+ [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
222
+ [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
223
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
224
+ [[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
225
+ [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
226
+ [[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
227
+ [[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
228
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
229
+ [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
230
+ [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
231
+ [[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
232
+ [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
233
+ [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
234
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
235
+ [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
236
+ [[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
237
+ [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
238
+ [[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
239
+ [[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
240
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
241
+ [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
242
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
243
+ [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
244
+ [[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
245
+ [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
246
+ [[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
247
+ [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
248
+ [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
249
+ [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
250
+ [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
251
+ [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
252
+ [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
253
+ [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
254
+ [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
255
+ [[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
256
+ [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
257
+ [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
258
+ [[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
259
+ [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
260
+ [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
261
+ [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
262
+ [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
263
+ [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
264
+ [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
265
+ ]
266
+ num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
267
+ 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
268
+ 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
269
+ 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
270
+ 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
271
+ 3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
272
+ 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
273
+ 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
274
+ 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
275
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
276
+ check_table = [
277
+ [0, 0, 0, 0, 0],
278
+ [0, 0, 0, 0, 0],
279
+ [0, 0, 0, 0, 0],
280
+ [0, 0, 0, 0, 0],
281
+ [0, 0, 0, 0, 0],
282
+ [0, 0, 0, 0, 0],
283
+ [0, 0, 0, 0, 0],
284
+ [0, 0, 0, 0, 0],
285
+ [0, 0, 0, 0, 0],
286
+ [0, 0, 0, 0, 0],
287
+ [0, 0, 0, 0, 0],
288
+ [0, 0, 0, 0, 0],
289
+ [0, 0, 0, 0, 0],
290
+ [0, 0, 0, 0, 0],
291
+ [0, 0, 0, 0, 0],
292
+ [0, 0, 0, 0, 0],
293
+ [0, 0, 0, 0, 0],
294
+ [0, 0, 0, 0, 0],
295
+ [0, 0, 0, 0, 0],
296
+ [0, 0, 0, 0, 0],
297
+ [0, 0, 0, 0, 0],
298
+ [0, 0, 0, 0, 0],
299
+ [0, 0, 0, 0, 0],
300
+ [0, 0, 0, 0, 0],
301
+ [0, 0, 0, 0, 0],
302
+ [0, 0, 0, 0, 0],
303
+ [0, 0, 0, 0, 0],
304
+ [0, 0, 0, 0, 0],
305
+ [0, 0, 0, 0, 0],
306
+ [0, 0, 0, 0, 0],
307
+ [0, 0, 0, 0, 0],
308
+ [0, 0, 0, 0, 0],
309
+ [0, 0, 0, 0, 0],
310
+ [0, 0, 0, 0, 0],
311
+ [0, 0, 0, 0, 0],
312
+ [0, 0, 0, 0, 0],
313
+ [0, 0, 0, 0, 0],
314
+ [0, 0, 0, 0, 0],
315
+ [0, 0, 0, 0, 0],
316
+ [0, 0, 0, 0, 0],
317
+ [0, 0, 0, 0, 0],
318
+ [0, 0, 0, 0, 0],
319
+ [0, 0, 0, 0, 0],
320
+ [0, 0, 0, 0, 0],
321
+ [0, 0, 0, 0, 0],
322
+ [0, 0, 0, 0, 0],
323
+ [0, 0, 0, 0, 0],
324
+ [0, 0, 0, 0, 0],
325
+ [0, 0, 0, 0, 0],
326
+ [0, 0, 0, 0, 0],
327
+ [0, 0, 0, 0, 0],
328
+ [0, 0, 0, 0, 0],
329
+ [0, 0, 0, 0, 0],
330
+ [0, 0, 0, 0, 0],
331
+ [0, 0, 0, 0, 0],
332
+ [0, 0, 0, 0, 0],
333
+ [0, 0, 0, 0, 0],
334
+ [0, 0, 0, 0, 0],
335
+ [0, 0, 0, 0, 0],
336
+ [0, 0, 0, 0, 0],
337
+ [0, 0, 0, 0, 0],
338
+ [1, 1, 0, 0, 194],
339
+ [1, -1, 0, 0, 193],
340
+ [0, 0, 0, 0, 0],
341
+ [0, 0, 0, 0, 0],
342
+ [0, 0, 0, 0, 0],
343
+ [0, 0, 0, 0, 0],
344
+ [0, 0, 0, 0, 0],
345
+ [0, 0, 0, 0, 0],
346
+ [0, 0, 0, 0, 0],
347
+ [0, 0, 0, 0, 0],
348
+ [0, 0, 0, 0, 0],
349
+ [0, 0, 0, 0, 0],
350
+ [0, 0, 0, 0, 0],
351
+ [0, 0, 0, 0, 0],
352
+ [0, 0, 0, 0, 0],
353
+ [0, 0, 0, 0, 0],
354
+ [0, 0, 0, 0, 0],
355
+ [0, 0, 0, 0, 0],
356
+ [0, 0, 0, 0, 0],
357
+ [0, 0, 0, 0, 0],
358
+ [0, 0, 0, 0, 0],
359
+ [0, 0, 0, 0, 0],
360
+ [0, 0, 0, 0, 0],
361
+ [0, 0, 0, 0, 0],
362
+ [0, 0, 0, 0, 0],
363
+ [0, 0, 0, 0, 0],
364
+ [0, 0, 0, 0, 0],
365
+ [0, 0, 0, 0, 0],
366
+ [0, 0, 0, 0, 0],
367
+ [0, 0, 0, 0, 0],
368
+ [1, 0, 1, 0, 164],
369
+ [0, 0, 0, 0, 0],
370
+ [0, 0, 0, 0, 0],
371
+ [1, 0, -1, 0, 161],
372
+ [0, 0, 0, 0, 0],
373
+ [0, 0, 0, 0, 0],
374
+ [0, 0, 0, 0, 0],
375
+ [0, 0, 0, 0, 0],
376
+ [0, 0, 0, 0, 0],
377
+ [0, 0, 0, 0, 0],
378
+ [0, 0, 0, 0, 0],
379
+ [0, 0, 0, 0, 0],
380
+ [1, 0, 0, 1, 152],
381
+ [0, 0, 0, 0, 0],
382
+ [0, 0, 0, 0, 0],
383
+ [0, 0, 0, 0, 0],
384
+ [0, 0, 0, 0, 0],
385
+ [0, 0, 0, 0, 0],
386
+ [0, 0, 0, 0, 0],
387
+ [1, 0, 0, 1, 145],
388
+ [1, 0, 0, 1, 144],
389
+ [0, 0, 0, 0, 0],
390
+ [0, 0, 0, 0, 0],
391
+ [0, 0, 0, 0, 0],
392
+ [0, 0, 0, 0, 0],
393
+ [0, 0, 0, 0, 0],
394
+ [0, 0, 0, 0, 0],
395
+ [1, 0, 0, -1, 137],
396
+ [0, 0, 0, 0, 0],
397
+ [0, 0, 0, 0, 0],
398
+ [0, 0, 0, 0, 0],
399
+ [1, 0, 1, 0, 133],
400
+ [1, 0, 1, 0, 132],
401
+ [1, 1, 0, 0, 131],
402
+ [1, 1, 0, 0, 130],
403
+ [0, 0, 0, 0, 0],
404
+ [0, 0, 0, 0, 0],
405
+ [0, 0, 0, 0, 0],
406
+ [0, 0, 0, 0, 0],
407
+ [0, 0, 0, 0, 0],
408
+ [0, 0, 0, 0, 0],
409
+ [0, 0, 0, 0, 0],
410
+ [0, 0, 0, 0, 0],
411
+ [0, 0, 0, 0, 0],
412
+ [0, 0, 0, 0, 0],
413
+ [0, 0, 0, 0, 0],
414
+ [0, 0, 0, 0, 0],
415
+ [0, 0, 0, 0, 0],
416
+ [0, 0, 0, 0, 0],
417
+ [0, 0, 0, 0, 0],
418
+ [0, 0, 0, 0, 0],
419
+ [0, 0, 0, 0, 0],
420
+ [0, 0, 0, 0, 0],
421
+ [0, 0, 0, 0, 0],
422
+ [0, 0, 0, 0, 0],
423
+ [0, 0, 0, 0, 0],
424
+ [0, 0, 0, 0, 0],
425
+ [0, 0, 0, 0, 0],
426
+ [0, 0, 0, 0, 0],
427
+ [0, 0, 0, 0, 0],
428
+ [0, 0, 0, 0, 0],
429
+ [0, 0, 0, 0, 0],
430
+ [0, 0, 0, 0, 0],
431
+ [0, 0, 0, 0, 0],
432
+ [1, 0, 0, 1, 100],
433
+ [0, 0, 0, 0, 0],
434
+ [1, 0, 0, 1, 98],
435
+ [0, 0, 0, 0, 0],
436
+ [1, 0, 0, 1, 96],
437
+ [0, 0, 0, 0, 0],
438
+ [0, 0, 0, 0, 0],
439
+ [0, 0, 0, 0, 0],
440
+ [0, 0, 0, 0, 0],
441
+ [0, 0, 0, 0, 0],
442
+ [0, 0, 0, 0, 0],
443
+ [0, 0, 0, 0, 0],
444
+ [1, 0, 1, 0, 88],
445
+ [0, 0, 0, 0, 0],
446
+ [0, 0, 0, 0, 0],
447
+ [0, 0, 0, 0, 0],
448
+ [0, 0, 0, 0, 0],
449
+ [0, 0, 0, 0, 0],
450
+ [1, 0, -1, 0, 82],
451
+ [0, 0, 0, 0, 0],
452
+ [0, 0, 0, 0, 0],
453
+ [0, 0, 0, 0, 0],
454
+ [0, 0, 0, 0, 0],
455
+ [0, 0, 0, 0, 0],
456
+ [0, 0, 0, 0, 0],
457
+ [0, 0, 0, 0, 0],
458
+ [1, 0, 1, 0, 74],
459
+ [0, 0, 0, 0, 0],
460
+ [1, 0, 1, 0, 72],
461
+ [0, 0, 0, 0, 0],
462
+ [1, 0, 0, -1, 70],
463
+ [0, 0, 0, 0, 0],
464
+ [0, 0, 0, 0, 0],
465
+ [1, -1, 0, 0, 67],
466
+ [0, 0, 0, 0, 0],
467
+ [1, -1, 0, 0, 65],
468
+ [0, 0, 0, 0, 0],
469
+ [0, 0, 0, 0, 0],
470
+ [0, 0, 0, 0, 0],
471
+ [0, 0, 0, 0, 0],
472
+ [0, 0, 0, 0, 0],
473
+ [0, 0, 0, 0, 0],
474
+ [0, 0, 0, 0, 0],
475
+ [0, 0, 0, 0, 0],
476
+ [1, 1, 0, 0, 56],
477
+ [0, 0, 0, 0, 0],
478
+ [0, 0, 0, 0, 0],
479
+ [0, 0, 0, 0, 0],
480
+ [1, -1, 0, 0, 52],
481
+ [0, 0, 0, 0, 0],
482
+ [0, 0, 0, 0, 0],
483
+ [0, 0, 0, 0, 0],
484
+ [0, 0, 0, 0, 0],
485
+ [0, 0, 0, 0, 0],
486
+ [0, 0, 0, 0, 0],
487
+ [0, 0, 0, 0, 0],
488
+ [1, 1, 0, 0, 44],
489
+ [0, 0, 0, 0, 0],
490
+ [0, 0, 0, 0, 0],
491
+ [0, 0, 0, 0, 0],
492
+ [1, 1, 0, 0, 40],
493
+ [0, 0, 0, 0, 0],
494
+ [1, 0, 0, -1, 38],
495
+ [1, 0, -1, 0, 37],
496
+ [0, 0, 0, 0, 0],
497
+ [0, 0, 0, 0, 0],
498
+ [0, 0, 0, 0, 0],
499
+ [1, 0, -1, 0, 33],
500
+ [0, 0, 0, 0, 0],
501
+ [0, 0, 0, 0, 0],
502
+ [0, 0, 0, 0, 0],
503
+ [0, 0, 0, 0, 0],
504
+ [1, -1, 0, 0, 28],
505
+ [0, 0, 0, 0, 0],
506
+ [1, 0, -1, 0, 26],
507
+ [1, 0, 0, -1, 25],
508
+ [0, 0, 0, 0, 0],
509
+ [0, 0, 0, 0, 0],
510
+ [0, 0, 0, 0, 0],
511
+ [0, 0, 0, 0, 0],
512
+ [1, -1, 0, 0, 20],
513
+ [0, 0, 0, 0, 0],
514
+ [1, 0, -1, 0, 18],
515
+ [0, 0, 0, 0, 0],
516
+ [0, 0, 0, 0, 0],
517
+ [0, 0, 0, 0, 0],
518
+ [0, 0, 0, 0, 0],
519
+ [0, 0, 0, 0, 0],
520
+ [0, 0, 0, 0, 0],
521
+ [0, 0, 0, 0, 0],
522
+ [0, 0, 0, 0, 0],
523
+ [1, 0, 0, -1, 9],
524
+ [0, 0, 0, 0, 0],
525
+ [0, 0, 0, 0, 0],
526
+ [1, 0, 0, -1, 6],
527
+ [0, 0, 0, 0, 0],
528
+ [0, 0, 0, 0, 0],
529
+ [0, 0, 0, 0, 0],
530
+ [0, 0, 0, 0, 0],
531
+ [0, 0, 0, 0, 0],
532
+ [0, 0, 0, 0, 0]
533
+ ]
534
+ tet_table = [
535
+ [-1, -1, -1, -1, -1, -1],
536
+ [0, 0, 0, 0, 0, 0],
537
+ [0, 0, 0, 0, 0, 0],
538
+ [1, 1, 1, 1, 1, 1],
539
+ [4, 4, 4, 4, 4, 4],
540
+ [0, 0, 0, 0, 0, 0],
541
+ [4, 0, 0, 4, 4, -1],
542
+ [1, 1, 1, 1, 1, 1],
543
+ [4, 4, 4, 4, 4, 4],
544
+ [0, 4, 0, 4, 4, -1],
545
+ [0, 0, 0, 0, 0, 0],
546
+ [1, 1, 1, 1, 1, 1],
547
+ [5, 5, 5, 5, 5, 5],
548
+ [0, 0, 0, 0, 0, 0],
549
+ [0, 0, 0, 0, 0, 0],
550
+ [1, 1, 1, 1, 1, 1],
551
+ [2, 2, 2, 2, 2, 2],
552
+ [0, 0, 0, 0, 0, 0],
553
+ [2, 0, 2, -1, 0, 2],
554
+ [1, 1, 1, 1, 1, 1],
555
+ [2, -1, 2, 4, 4, 2],
556
+ [0, 0, 0, 0, 0, 0],
557
+ [2, 0, 2, 4, 4, 2],
558
+ [1, 1, 1, 1, 1, 1],
559
+ [2, 4, 2, 4, 4, 2],
560
+ [0, 4, 0, 4, 4, 0],
561
+ [2, 0, 2, 0, 0, 2],
562
+ [1, 1, 1, 1, 1, 1],
563
+ [2, 5, 2, 5, 5, 2],
564
+ [0, 0, 0, 0, 0, 0],
565
+ [2, 0, 2, 0, 0, 2],
566
+ [1, 1, 1, 1, 1, 1],
567
+ [1, 1, 1, 1, 1, 1],
568
+ [0, 1, 1, -1, 0, 1],
569
+ [0, 0, 0, 0, 0, 0],
570
+ [2, 2, 2, 2, 2, 2],
571
+ [4, 1, 1, 4, 4, 1],
572
+ [0, 1, 1, 0, 0, 1],
573
+ [4, 0, 0, 4, 4, 0],
574
+ [2, 2, 2, 2, 2, 2],
575
+ [-1, 1, 1, 4, 4, 1],
576
+ [0, 1, 1, 4, 4, 1],
577
+ [0, 0, 0, 0, 0, 0],
578
+ [2, 2, 2, 2, 2, 2],
579
+ [5, 1, 1, 5, 5, 1],
580
+ [0, 1, 1, 0, 0, 1],
581
+ [0, 0, 0, 0, 0, 0],
582
+ [2, 2, 2, 2, 2, 2],
583
+ [1, 1, 1, 1, 1, 1],
584
+ [0, 0, 0, 0, 0, 0],
585
+ [0, 0, 0, 0, 0, 0],
586
+ [8, 8, 8, 8, 8, 8],
587
+ [1, 1, 1, 4, 4, 1],
588
+ [0, 0, 0, 0, 0, 0],
589
+ [4, 0, 0, 4, 4, 0],
590
+ [4, 4, 4, 4, 4, 4],
591
+ [1, 1, 1, 4, 4, 1],
592
+ [0, 4, 0, 4, 4, 0],
593
+ [0, 0, 0, 0, 0, 0],
594
+ [4, 4, 4, 4, 4, 4],
595
+ [1, 1, 1, 5, 5, 1],
596
+ [0, 0, 0, 0, 0, 0],
597
+ [0, 0, 0, 0, 0, 0],
598
+ [5, 5, 5, 5, 5, 5],
599
+ [6, 6, 6, 6, 6, 6],
600
+ [6, -1, 0, 6, 0, 6],
601
+ [6, 0, 0, 6, 0, 6],
602
+ [6, 1, 1, 6, 1, 6],
603
+ [4, 4, 4, 4, 4, 4],
604
+ [0, 0, 0, 0, 0, 0],
605
+ [4, 0, 0, 4, 4, 4],
606
+ [1, 1, 1, 1, 1, 1],
607
+ [6, 4, -1, 6, 4, 6],
608
+ [6, 4, 0, 6, 4, 6],
609
+ [6, 0, 0, 6, 0, 6],
610
+ [6, 1, 1, 6, 1, 6],
611
+ [5, 5, 5, 5, 5, 5],
612
+ [0, 0, 0, 0, 0, 0],
613
+ [0, 0, 0, 0, 0, 0],
614
+ [1, 1, 1, 1, 1, 1],
615
+ [2, 2, 2, 2, 2, 2],
616
+ [0, 0, 0, 0, 0, 0],
617
+ [2, 0, 2, 2, 0, 2],
618
+ [1, 1, 1, 1, 1, 1],
619
+ [2, 2, 2, 2, 2, 2],
620
+ [0, 0, 0, 0, 0, 0],
621
+ [2, 0, 2, 2, 2, 2],
622
+ [1, 1, 1, 1, 1, 1],
623
+ [2, 4, 2, 2, 4, 2],
624
+ [0, 4, 0, 4, 4, 0],
625
+ [2, 0, 2, 2, 0, 2],
626
+ [1, 1, 1, 1, 1, 1],
627
+ [2, 2, 2, 2, 2, 2],
628
+ [0, 0, 0, 0, 0, 0],
629
+ [0, 0, 0, 0, 0, 0],
630
+ [1, 1, 1, 1, 1, 1],
631
+ [6, 1, 1, 6, -1, 6],
632
+ [6, 1, 1, 6, 0, 6],
633
+ [6, 0, 0, 6, 0, 6],
634
+ [6, 2, 2, 6, 2, 6],
635
+ [4, 1, 1, 4, 4, 1],
636
+ [0, 1, 1, 0, 0, 1],
637
+ [4, 0, 0, 4, 4, 4],
638
+ [2, 2, 2, 2, 2, 2],
639
+ [6, 1, 1, 6, 4, 6],
640
+ [6, 1, 1, 6, 4, 6],
641
+ [6, 0, 0, 6, 0, 6],
642
+ [6, 2, 2, 6, 2, 6],
643
+ [5, 1, 1, 5, 5, 1],
644
+ [0, 1, 1, 0, 0, 1],
645
+ [0, 0, 0, 0, 0, 0],
646
+ [2, 2, 2, 2, 2, 2],
647
+ [1, 1, 1, 1, 1, 1],
648
+ [0, 0, 0, 0, 0, 0],
649
+ [0, 0, 0, 0, 0, 0],
650
+ [6, 6, 6, 6, 6, 6],
651
+ [1, 1, 1, 1, 1, 1],
652
+ [0, 0, 0, 0, 0, 0],
653
+ [0, 0, 0, 0, 0, 0],
654
+ [4, 4, 4, 4, 4, 4],
655
+ [1, 1, 1, 1, 4, 1],
656
+ [0, 4, 0, 4, 4, 0],
657
+ [0, 0, 0, 0, 0, 0],
658
+ [4, 4, 4, 4, 4, 4],
659
+ [1, 1, 1, 1, 1, 1],
660
+ [0, 0, 0, 0, 0, 0],
661
+ [0, 5, 0, 5, 0, 5],
662
+ [5, 5, 5, 5, 5, 5],
663
+ [5, 5, 5, 5, 5, 5],
664
+ [0, 5, 0, 5, 0, 5],
665
+ [-1, 5, 0, 5, 0, 5],
666
+ [1, 5, 1, 5, 1, 5],
667
+ [4, 5, -1, 5, 4, 5],
668
+ [0, 5, 0, 5, 0, 5],
669
+ [4, 5, 0, 5, 4, 5],
670
+ [1, 5, 1, 5, 1, 5],
671
+ [4, 4, 4, 4, 4, 4],
672
+ [0, 4, 0, 4, 4, 4],
673
+ [0, 0, 0, 0, 0, 0],
674
+ [1, 1, 1, 1, 1, 1],
675
+ [6, 6, 6, 6, 6, 6],
676
+ [0, 0, 0, 0, 0, 0],
677
+ [0, 0, 0, 0, 0, 0],
678
+ [1, 1, 1, 1, 1, 1],
679
+ [2, 5, 2, 5, -1, 5],
680
+ [0, 5, 0, 5, 0, 5],
681
+ [2, 5, 2, 5, 0, 5],
682
+ [1, 5, 1, 5, 1, 5],
683
+ [2, 5, 2, 5, 4, 5],
684
+ [0, 5, 0, 5, 0, 5],
685
+ [2, 5, 2, 5, 4, 5],
686
+ [1, 5, 1, 5, 1, 5],
687
+ [2, 4, 2, 4, 4, 2],
688
+ [0, 4, 0, 4, 4, 4],
689
+ [2, 0, 2, 0, 0, 2],
690
+ [1, 1, 1, 1, 1, 1],
691
+ [2, 6, 2, 6, 6, 2],
692
+ [0, 0, 0, 0, 0, 0],
693
+ [2, 0, 2, 0, 0, 2],
694
+ [1, 1, 1, 1, 1, 1],
695
+ [1, 1, 1, 1, 1, 1],
696
+ [0, 1, 1, 1, 0, 1],
697
+ [0, 0, 0, 0, 0, 0],
698
+ [2, 2, 2, 2, 2, 2],
699
+ [4, 1, 1, 1, 4, 1],
700
+ [0, 1, 1, 1, 0, 1],
701
+ [4, 0, 0, 4, 4, 0],
702
+ [2, 2, 2, 2, 2, 2],
703
+ [1, 1, 1, 1, 1, 1],
704
+ [0, 1, 1, 1, 1, 1],
705
+ [0, 0, 0, 0, 0, 0],
706
+ [2, 2, 2, 2, 2, 2],
707
+ [1, 1, 1, 1, 1, 1],
708
+ [0, 0, 0, 0, 0, 0],
709
+ [0, 0, 0, 0, 0, 0],
710
+ [2, 2, 2, 2, 2, 2],
711
+ [1, 1, 1, 1, 1, 1],
712
+ [0, 0, 0, 0, 0, 0],
713
+ [0, 0, 0, 0, 0, 0],
714
+ [5, 5, 5, 5, 5, 5],
715
+ [1, 1, 1, 1, 4, 1],
716
+ [0, 0, 0, 0, 0, 0],
717
+ [4, 0, 0, 4, 4, 0],
718
+ [4, 4, 4, 4, 4, 4],
719
+ [1, 1, 1, 1, 1, 1],
720
+ [0, 0, 0, 0, 0, 0],
721
+ [0, 0, 0, 0, 0, 0],
722
+ [4, 4, 4, 4, 4, 4],
723
+ [1, 1, 1, 1, 1, 1],
724
+ [6, 0, 0, 6, 0, 6],
725
+ [0, 0, 0, 0, 0, 0],
726
+ [6, 6, 6, 6, 6, 6],
727
+ [5, 5, 5, 5, 5, 5],
728
+ [5, 5, 0, 5, 0, 5],
729
+ [5, 5, 0, 5, 0, 5],
730
+ [5, 5, 1, 5, 1, 5],
731
+ [4, 4, 4, 4, 4, 4],
732
+ [0, 0, 0, 0, 0, 0],
733
+ [4, 4, 0, 4, 4, 4],
734
+ [1, 1, 1, 1, 1, 1],
735
+ [4, 4, 4, 4, 4, 4],
736
+ [4, 4, 0, 4, 4, 4],
737
+ [0, 0, 0, 0, 0, 0],
738
+ [1, 1, 1, 1, 1, 1],
739
+ [8, 8, 8, 8, 8, 8],
740
+ [0, 0, 0, 0, 0, 0],
741
+ [0, 0, 0, 0, 0, 0],
742
+ [1, 1, 1, 1, 1, 1],
743
+ [2, 2, 2, 2, 2, 2],
744
+ [0, 0, 0, 0, 0, 0],
745
+ [2, 2, 2, 2, 0, 2],
746
+ [1, 1, 1, 1, 1, 1],
747
+ [2, 2, 2, 2, 2, 2],
748
+ [0, 0, 0, 0, 0, 0],
749
+ [2, 2, 2, 2, 2, 2],
750
+ [1, 1, 1, 1, 1, 1],
751
+ [2, 2, 2, 2, 2, 2],
752
+ [0, 0, 0, 0, 0, 0],
753
+ [0, 0, 0, 0, 0, 0],
754
+ [4, 1, 1, 4, 4, 1],
755
+ [2, 2, 2, 2, 2, 2],
756
+ [0, 0, 0, 0, 0, 0],
757
+ [0, 0, 0, 0, 0, 0],
758
+ [1, 1, 1, 1, 1, 1],
759
+ [1, 1, 1, 1, 1, 1],
760
+ [1, 1, 1, 1, 0, 1],
761
+ [0, 0, 0, 0, 0, 0],
762
+ [2, 2, 2, 2, 2, 2],
763
+ [1, 1, 1, 1, 1, 1],
764
+ [0, 0, 0, 0, 0, 0],
765
+ [0, 0, 0, 0, 0, 0],
766
+ [2, 4, 2, 4, 4, 2],
767
+ [1, 1, 1, 1, 1, 1],
768
+ [1, 1, 1, 1, 1, 1],
769
+ [0, 0, 0, 0, 0, 0],
770
+ [2, 2, 2, 2, 2, 2],
771
+ [1, 1, 1, 1, 1, 1],
772
+ [0, 0, 0, 0, 0, 0],
773
+ [0, 0, 0, 0, 0, 0],
774
+ [2, 2, 2, 2, 2, 2],
775
+ [1, 1, 1, 1, 1, 1],
776
+ [0, 0, 0, 0, 0, 0],
777
+ [0, 0, 0, 0, 0, 0],
778
+ [5, 5, 5, 5, 5, 5],
779
+ [1, 1, 1, 1, 1, 1],
780
+ [0, 0, 0, 0, 0, 0],
781
+ [0, 0, 0, 0, 0, 0],
782
+ [4, 4, 4, 4, 4, 4],
783
+ [1, 1, 1, 1, 1, 1],
784
+ [0, 0, 0, 0, 0, 0],
785
+ [0, 0, 0, 0, 0, 0],
786
+ [4, 4, 4, 4, 4, 4],
787
+ [1, 1, 1, 1, 1, 1],
788
+ [0, 0, 0, 0, 0, 0],
789
+ [0, 0, 0, 0, 0, 0],
790
+ [12, 12, 12, 12, 12, 12]
791
+ ]
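The tables above (dmc_table, num_vd_table, check_table, tet_table) are lookup tables consumed by flexicubes.py: the signs of a cell's eight corner SDF values form an 8-bit case index; dmc_table lists, per case, the edge IDs each dual vertex interpolates (padded with -1), and num_vd_table gives the number of dual vertices for that case. A rough sketch of the indexing, under the assumption that corner i contributes bit 2**i (the exact bit ordering is fixed in flexicubes.py, not shown here):

import torch

def cube_case_index(corner_sdf: torch.Tensor) -> torch.Tensor:
    # corner_sdf: (..., 8) signed distance at each cube corner
    bits = (corner_sdf < 0).long()                            # inside/outside flag per corner
    weights = 2 ** torch.arange(8, device=corner_sdf.device)  # assumed ordering: corner i -> bit i
    return (bits * weights).sum(dim=-1)                       # case index in [0, 255]

# hypothetical usage:
# case = cube_case_index(sdf_at_corners)                # sdf_at_corners: (num_cells, 8)
# n_dual_vertices = torch.tensor(num_vd_table)[case]    # dual vertices emitted per cell
# edges_per_vertex = [dmc_table[int(c)] for c in case]  # edge ids per dual vertex, -1 padded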
src/models/lrm.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import mcubes
19
+ import nvdiffrast.torch as dr
20
+ from einops import rearrange, repeat
21
+
22
+ from .encoder.dino_wrapper import DinoWrapper
23
+ from .decoder.transformer import TriplaneTransformer
24
+ from .renderer.synthesizer import TriplaneSynthesizer
25
+ from ..utils.mesh_util import xatlas_uvmap
26
+
27
+
28
+ class InstantNeRF(nn.Module):
29
+ """
30
+ Full model of the large reconstruction model.
31
+ """
32
+ def __init__(
33
+ self,
34
+ encoder_freeze: bool = False,
35
+ encoder_model_name: str = 'facebook/dino-vitb16',
36
+ encoder_feat_dim: int = 768,
37
+ transformer_dim: int = 1024,
38
+ transformer_layers: int = 16,
39
+ transformer_heads: int = 16,
40
+ triplane_low_res: int = 32,
41
+ triplane_high_res: int = 64,
42
+ triplane_dim: int = 80,
43
+ rendering_samples_per_ray: int = 128,
44
+ ):
45
+ super().__init__()
46
+
47
+ # modules
48
+ self.encoder = DinoWrapper(
49
+ model_name=encoder_model_name,
50
+ freeze=encoder_freeze,
51
+ )
52
+
53
+ self.transformer = TriplaneTransformer(
54
+ inner_dim=transformer_dim,
55
+ num_layers=transformer_layers,
56
+ num_heads=transformer_heads,
57
+ image_feat_dim=encoder_feat_dim,
58
+ triplane_low_res=triplane_low_res,
59
+ triplane_high_res=triplane_high_res,
60
+ triplane_dim=triplane_dim,
61
+ )
62
+
63
+ self.synthesizer = TriplaneSynthesizer(
64
+ triplane_dim=triplane_dim,
65
+ samples_per_ray=rendering_samples_per_ray,
66
+ )
67
+
68
+ def forward_planes(self, images, cameras):
69
+ # images: [B, V, C_img, H_img, W_img]
70
+ # cameras: [B, V, 16]
71
+ B = images.shape[0]
72
+
73
+ # encode images
74
+ image_feats = self.encoder(images, cameras)
75
+ image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
76
+
77
+ # transformer generating planes
78
+ planes = self.transformer(image_feats)
79
+
80
+ return planes
81
+
82
+ def forward_synthesizer(self, planes, render_cameras, render_size: int):
83
+ render_results = self.synthesizer(
84
+ planes,
85
+ render_cameras,
86
+ render_size,
87
+ )
88
+ return render_results
89
+
90
+ def forward(self, images, cameras, render_cameras, render_size: int):
91
+ # images: [B, V, C_img, H_img, W_img]
92
+ # cameras: [B, V, 16]
93
+ # render_cameras: [B, M, D_cam_render]
94
+ # render_size: int
95
+ B, M = render_cameras.shape[:2]
96
+
97
+ planes = self.forward_planes(images, cameras)
98
+
99
+ # render target views
100
+ render_results = self.synthesizer(planes, render_cameras, render_size)
101
+
102
+ return {
103
+ 'planes': planes,
104
+ **render_results,
105
+ }
106
+
107
+ def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
108
+ '''
109
+ Predict Texture given triplanes
110
+ :param planes: the triplane feature map
111
+ :param tex_pos: Position we want to query the texture field
112
+ :param hard_mask: 2D silhouette of the rendered image
113
+ '''
114
+ tex_pos = torch.cat(tex_pos, dim=0)
115
+ if hard_mask is not None:
116
+ tex_pos = tex_pos * hard_mask.float()
117
+ batch_size = tex_pos.shape[0]
118
+ tex_pos = tex_pos.reshape(batch_size, -1, 3)
119
+ ###################
120
+ # We use the mask to get the texture locations (to save memory)
121
+ if hard_mask is not None:
122
+ n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
123
+ sample_tex_pose_list = []
124
+ max_point = n_point_list.max()
125
+ expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
126
+ for i in range(tex_pos.shape[0]):
127
+ tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
128
+ if tex_pos_one_shape.shape[1] < max_point:
129
+ tex_pos_one_shape = torch.cat(
130
+ [tex_pos_one_shape, torch.zeros(
131
+ 1, max_point - tex_pos_one_shape.shape[1], 3,
132
+ device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
133
+ sample_tex_pose_list.append(tex_pos_one_shape)
134
+ tex_pos = torch.cat(sample_tex_pose_list, dim=0)
135
+
136
+ tex_feat = torch.utils.checkpoint.checkpoint(
137
+ self.synthesizer.forward_points,
138
+ planes,
139
+ tex_pos,
140
+ use_reentrant=False,
141
+ )['rgb']
142
+
143
+ if hard_mask is not None:
144
+ final_tex_feat = torch.zeros(
145
+ planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
146
+ expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
147
+ for i in range(planes.shape[0]):
148
+ final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
149
+ tex_feat = final_tex_feat
150
+
151
+ return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
152
+
153
+ def extract_mesh(
154
+ self,
155
+ planes: torch.Tensor,
156
+ mesh_resolution: int = 256,
157
+ mesh_threshold: float = 10.0,
158
+ use_texture_map: bool = False,
159
+ texture_resolution: int = 1024,
160
+ **kwargs,
161
+ ):
162
+ '''
163
+ Extract a 3D mesh from the triplane NeRF. Only supports batch_size 1.
164
+ :param planes: triplane features
165
+ :param mesh_resolution: marching cubes resolution
166
+ :param mesh_threshold: iso-surface threshold
167
+ :param use_texture_map: use texture map or vertex color
168
+ :param texture_resolution: the resolution of texture map
169
+ '''
170
+ assert planes.shape[0] == 1
171
+ device = planes.device
172
+
173
+ grid_out = self.synthesizer.forward_grid(
174
+ planes=planes,
175
+ grid_size=mesh_resolution,
176
+ )
177
+
178
+ vertices, faces = mcubes.marching_cubes(
179
+ grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
180
+ mesh_threshold,
181
+ )
182
+ vertices = vertices / (mesh_resolution - 1) * 2 - 1
183
+
184
+ if not use_texture_map:
185
+ # query vertex colors
186
+ vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
187
+ vertices_colors = self.synthesizer.forward_points(
188
+ planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy()
189
+ vertices_colors = (vertices_colors * 255).astype(np.uint8)
190
+
191
+ return vertices, faces, vertices_colors
192
+
193
+ # use xatlas to get a UV mapping for the mesh
194
+ vertices = torch.tensor(vertices, dtype=torch.float32, device=device)
195
+ faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device)
196
+
197
+ ctx = dr.RasterizeCudaContext(device=device)
198
+ uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
199
+ ctx, vertices, faces, resolution=texture_resolution)
200
+ tex_hard_mask = tex_hard_mask.float()
201
+
202
+ # query the texture field to get the RGB color for texture map
203
+ tex_feat = self.get_texture_prediction(
204
+ planes, [gb_pos], tex_hard_mask)
205
+ background_feature = torch.zeros_like(tex_feat)
206
+ img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
207
+ texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
208
+
209
+ return vertices, faces, uvs, mesh_tex_idx, texture_map
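For orientation, a minimal smoke-test sketch of the vertex-color path of extract_mesh; the random inputs, the two views, and the 224x224 resolution are assumptions for illustration (real usage loads pretrained weights and properly posed input views):

import torch
from src.models.lrm import InstantNeRF  # module path as added in this commit

model = InstantNeRF().eval()              # weights would normally come from a checkpoint
images = torch.rand(1, 2, 3, 224, 224)    # [B, V, C_img, H_img, W_img] dummy input views
cameras = torch.rand(1, 2, 16)            # [B, V, 16] flattened camera parameters
with torch.no_grad():
    planes = model.forward_planes(images, cameras)
    verts, faces, vert_colors = model.extract_mesh(planes, use_texture_map=False)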
src/models/lrm_mesh.py ADDED
@@ -0,0 +1,382 @@
1
+ # Copyright (c) 2023, Tencent Inc
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import nvdiffrast.torch as dr
19
+ from einops import rearrange, repeat
20
+
21
+ from .encoder.dino_wrapper import DinoWrapper
22
+ from .decoder.transformer import TriplaneTransformer
23
+ from .renderer.synthesizer_mesh import TriplaneSynthesizer
24
+ from .geometry.camera.perspective_camera import PerspectiveCamera
25
+ from .geometry.render.neural_render import NeuralRender
26
+ from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
27
+ from ..utils.mesh_util import xatlas_uvmap
28
+
29
+
30
+ class InstantMesh(nn.Module):
31
+ """
32
+ Full model of the large reconstruction model.
33
+ """
34
+ def __init__(
35
+ self,
36
+ encoder_freeze: bool = False,
37
+ encoder_model_name: str = 'facebook/dino-vitb16',
38
+ encoder_feat_dim: int = 768,
39
+ transformer_dim: int = 1024,
40
+ transformer_layers: int = 16,
41
+ transformer_heads: int = 16,
42
+ triplane_low_res: int = 32,
43
+ triplane_high_res: int = 64,
44
+ triplane_dim: int = 80,
45
+ rendering_samples_per_ray: int = 128,
46
+ grid_res: int = 128,
47
+ grid_scale: float = 2.0,
48
+ ):
49
+ super().__init__()
50
+
51
+ # attributes
52
+ self.grid_res = grid_res
53
+ self.grid_scale = grid_scale
54
+ self.deformation_multiplier = 4.0
55
+
56
+ # modules
57
+ self.encoder = DinoWrapper(
58
+ model_name=encoder_model_name,
59
+ freeze=encoder_freeze,
60
+ )
61
+
62
+ self.transformer = TriplaneTransformer(
63
+ inner_dim=transformer_dim,
64
+ num_layers=transformer_layers,
65
+ num_heads=transformer_heads,
66
+ image_feat_dim=encoder_feat_dim,
67
+ triplane_low_res=triplane_low_res,
68
+ triplane_high_res=triplane_high_res,
69
+ triplane_dim=triplane_dim,
70
+ )
71
+
72
+ self.synthesizer = TriplaneSynthesizer(
73
+ triplane_dim=triplane_dim,
74
+ samples_per_ray=rendering_samples_per_ray,
75
+ )
76
+
77
+ def init_flexicubes_geometry(self, device, fovy=50.0):
78
+ camera = PerspectiveCamera(fovy=fovy, device=device)
79
+ renderer = NeuralRender(device, camera_model=camera)
80
+ self.geometry = FlexiCubesGeometry(
81
+ grid_res=self.grid_res,
82
+ scale=self.grid_scale,
83
+ renderer=renderer,
84
+ render_type='neural_render',
85
+ device=device,
86
+ )
87
+
88
+ def forward_planes(self, images, cameras):
89
+ # images: [B, V, C_img, H_img, W_img]
90
+ # cameras: [B, V, 16]
91
+ B = images.shape[0]
92
+
93
+ # encode images
94
+ image_feats = self.encoder(images, cameras)
95
+ image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
96
+
97
+ # decode triplanes
98
+ planes = self.transformer(image_feats)
99
+
100
+ return planes
101
+
102
+ def get_sdf_deformation_prediction(self, planes):
103
+ '''
104
+ Predict SDF and deformation for tetrahedron vertices
105
+ :param planes: triplane feature map for the geometry
106
+ '''
107
+ init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1)
108
+
109
+ # Step 1: predict the SDF and deformation
110
+ sdf, deformation, weight = torch.utils.checkpoint.checkpoint(
111
+ self.synthesizer.get_geometry_prediction,
112
+ planes,
113
+ init_position,
114
+ self.geometry.indices,
115
+ use_reentrant=False,
116
+ )
117
+
118
+ # Step 2: Normalize the deformation to avoid the flipped triangles.
119
+ deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)
120
+ sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)
121
+
122
+ ####
123
+ # Step 3: Fix the sdf if we observe an empty shape (all positive or all negative)
124
+ sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1))
125
+ sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
126
+ pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1)
127
+ neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1)
128
+ zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
129
+ if torch.sum(zero_surface).item() > 0:
130
+ update_sdf = torch.zeros_like(sdf[0:1])
131
+ max_sdf = sdf.max()
132
+ min_sdf = sdf.min()
133
+ update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero
134
+ update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero
135
+ new_sdf = torch.zeros_like(sdf)
136
+ for i_batch in range(zero_surface.shape[0]):
137
+ if zero_surface[i_batch]:
138
+ new_sdf[i_batch:i_batch + 1] += update_sdf
139
+ update_mask = (new_sdf == 0).float()
140
+ # Regularization here is used to push the sdf toward a sign change (so it is not all positive or all negative)
141
+ sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
142
+ sdf_reg_loss = sdf_reg_loss * zero_surface.float()
143
+ sdf = sdf * update_mask + new_sdf * (1 - update_mask)
144
+
145
+ # Step 4: Remove the gradient for bad sdf values (all positive or all negative)
146
+ final_sdf = []
147
+ final_def = []
148
+ for i_batch in range(zero_surface.shape[0]):
149
+ if zero_surface[i_batch]:
150
+ final_sdf.append(sdf[i_batch: i_batch + 1].detach())
151
+ final_def.append(deformation[i_batch: i_batch + 1].detach())
152
+ else:
153
+ final_sdf.append(sdf[i_batch: i_batch + 1])
154
+ final_def.append(deformation[i_batch: i_batch + 1])
155
+ sdf = torch.cat(final_sdf, dim=0)
156
+ deformation = torch.cat(final_def, dim=0)
157
+ return sdf, deformation, sdf_reg_loss, weight
158
+
159
+ def get_geometry_prediction(self, planes=None):
160
+ '''
161
+ Function to generate a mesh from the given triplanes
162
+ :param planes: triplane features
163
+ '''
164
+ # Step 1: first get the sdf and deformation value for each vertex in the tetrahedron grid.
165
+ sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)
166
+ v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation
167
+ tets = self.geometry.indices
168
+ n_batch = planes.shape[0]
169
+ v_list = []
170
+ f_list = []
171
+ flexicubes_surface_reg_list = []
172
+
173
+ # Step 2: Use marching tetrahedra to obtain the mesh
174
+ for i_batch in range(n_batch):
175
+ verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
176
+ v_deformed[i_batch],
177
+ sdf[i_batch].squeeze(dim=-1),
178
+ with_uv=False,
179
+ indices=tets,
180
+ weight_n=weight[i_batch].squeeze(dim=-1),
181
+ is_training=self.training,
182
+ )
183
+ flexicubes_surface_reg_list.append(flexicubes_surface_reg)
184
+ v_list.append(verts)
185
+ f_list.append(faces)
186
+
187
+ flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
188
+ flexicubes_weight_reg = (weight ** 2).mean()
189
+
190
+ return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
191
+
192
+ def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
193
+ '''
194
+ Predict Texture given triplanes
195
+ :param planes: the triplane feature map
196
+ :param tex_pos: Position we want to query the texture field
197
+ :param hard_mask: 2D silhouette of the rendered image
198
+ '''
199
+ tex_pos = torch.cat(tex_pos, dim=0)
200
+ if hard_mask is not None:
201
+ tex_pos = tex_pos * hard_mask.float()
202
+ batch_size = tex_pos.shape[0]
203
+ tex_pos = tex_pos.reshape(batch_size, -1, 3)
204
+ ###################
205
+ # Use the hard mask to gather only visible texture query locations (to save memory)
206
+ if hard_mask is not None:
207
+ n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
208
+ sample_tex_pose_list = []
209
+ max_point = n_point_list.max()
210
+ expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
211
+ for i in range(tex_pos.shape[0]):
212
+ tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
213
+ if tex_pos_one_shape.shape[1] < max_point:
214
+ tex_pos_one_shape = torch.cat(
215
+ [tex_pos_one_shape, torch.zeros(
216
+ 1, max_point - tex_pos_one_shape.shape[1], 3,
217
+ device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
218
+ sample_tex_pose_list.append(tex_pos_one_shape)
219
+ tex_pos = torch.cat(sample_tex_pose_list, dim=0)
220
+
221
+ tex_feat = torch.utils.checkpoint.checkpoint(
222
+ self.synthesizer.get_texture_prediction,
223
+ planes,
224
+ tex_pos,
225
+ use_reentrant=False,
226
+ )
227
+
228
+ if hard_mask is not None:
229
+ final_tex_feat = torch.zeros(
230
+ planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
231
+ expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
232
+ for i in range(planes.shape[0]):
233
+ final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
234
+ tex_feat = final_tex_feat
235
+
236
+ return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
237
+
238
+ def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
239
+ '''
240
+ Function to render a generated mesh with nvdiffrast
241
+ :param mesh_v: List of vertices for the mesh
242
+ :param mesh_f: List of faces for the mesh
243
+ :param cam_mv: 4x4 camera matrices, one per view
244
+ :return:
245
+ '''
246
+ return_value_list = []
247
+ for i_mesh in range(len(mesh_v)):
248
+ return_value = self.geometry.render_mesh(
249
+ mesh_v[i_mesh],
250
+ mesh_f[i_mesh].int(),
251
+ cam_mv[i_mesh],
252
+ resolution=render_size,
253
+ hierarchical_mask=False
254
+ )
255
+ return_value_list.append(return_value)
256
+
257
+ return_keys = return_value_list[0].keys()
258
+ return_value = dict()
259
+ for k in return_keys:
260
+ value = [v[k] for v in return_value_list]
261
+ return_value[k] = value
262
+
263
+ mask = torch.cat(return_value['mask'], dim=0)
264
+ hard_mask = torch.cat(return_value['hard_mask'], dim=0)
265
+ tex_pos = return_value['tex_pos']
266
+ depth = torch.cat(return_value['depth'], dim=0)
267
+ normal = torch.cat(return_value['normal'], dim=0)
268
+ return mask, hard_mask, tex_pos, depth, normal
269
+
270
+ def forward_geometry(self, planes, render_cameras, render_size=256):
271
+ '''
272
+ Main function of our generator. It first generates a 3D mesh, then renders it into 2D images
273
+ with given `render_cameras`.
274
+ :param planes: triplane features
275
+ :param render_cameras: cameras to render generated 3D shape
276
+ '''
277
+ B, NV = render_cameras.shape[:2]
278
+
279
+ # Generate 3D mesh first
280
+ mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
281
+
282
+ # Render the mesh into 2D images (and get the 3D position for each pixel)
283
+ cam_mv = render_cameras
284
+ run_n_view = cam_mv.shape[1]
285
+ antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size)
286
+
287
+ tex_hard_mask = hard_mask
288
+ tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos]
289
+ tex_hard_mask = torch.cat(
290
+ [torch.cat(
291
+ [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1]
292
+ for i_view in range(run_n_view)], dim=2)
293
+ for i in range(planes.shape[0])], dim=0)
294
+
295
+ # Querying the texture field to predict the texture feature for each pixel on the image
296
+ tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
297
+ background_feature = torch.ones_like(tex_feat) # white background
298
+
299
+ # Merge them together
300
+ img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
301
+
302
+ # Split the concatenated views back to the original per-view image shape
303
+ img_feat = torch.cat(
304
+ [torch.cat(
305
+ [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)]
306
+ for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0)
307
+
308
+ img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))
309
+ antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
310
+ depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive
311
+ normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))
312
+
313
+ out = {
314
+ 'img': img,
315
+ 'mask': antilias_mask,
316
+ 'depth': depth,
317
+ 'normal': normal,
318
+ 'sdf': sdf,
319
+ 'mesh_v': mesh_v,
320
+ 'mesh_f': mesh_f,
321
+ 'sdf_reg_loss': sdf_reg_loss,
322
+ }
323
+ return out
324
+
325
+ def forward(self, images, cameras, render_cameras, render_size: int):
326
+ # images: [B, V, C_img, H_img, W_img]
327
+ # cameras: [B, V, 16]
328
+ # render_cameras: [B, M, D_cam_render]
329
+ # render_size: int
330
+ B, M = render_cameras.shape[:2]
331
+
332
+ planes = self.forward_planes(images, cameras)
333
+ out = self.forward_geometry(planes, render_cameras, render_size=render_size)
334
+
335
+ return {
336
+ 'planes': planes,
337
+ **out
338
+ }
339
+
340
+ def extract_mesh(
341
+ self,
342
+ planes: torch.Tensor,
343
+ use_texture_map: bool = False,
344
+ texture_resolution: int = 1024,
345
+ **kwargs,
346
+ ):
347
+ '''
348
+ Extract a 3D mesh from FlexiCubes. Only supports batch_size 1.
349
+ :param planes: triplane features
350
+ :param use_texture_map: use texture map or vertex color
351
+ :param texture_resolution: the resolution of the texture map
352
+ '''
353
+ assert planes.shape[0] == 1
354
+ device = planes.device
355
+
356
+ # predict geometry first
357
+ mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
358
+ vertices, faces = mesh_v[0], mesh_f[0]
359
+
360
+ if not use_texture_map:
361
+ # query vertex colors
362
+ vertices_tensor = vertices.unsqueeze(0)
363
+ vertices_colors = self.synthesizer.get_texture_prediction(
364
+ planes, vertices_tensor).clamp(0, 1).squeeze(0).cpu().numpy()
365
+ vertices_colors = (vertices_colors * 255).astype(np.uint8)
366
+
367
+ return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors
368
+
369
+ # use xatlas to get a UV mapping for the mesh
370
+ ctx = dr.RasterizeCudaContext(device=device)
371
+ uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
372
+ self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution)
373
+ tex_hard_mask = tex_hard_mask.float()
374
+
375
+ # query the texture field to get the RGB color for texture map
376
+ tex_feat = self.get_texture_prediction(
377
+ planes, [gb_pos], tex_hard_mask)
378
+ background_feature = torch.zeros_like(tex_feat)
379
+ img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
380
+ texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
381
+
382
+ return vertices, faces, uvs, mesh_tex_idx, texture_map
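
A minimal usage sketch (not part of this commit) of the `InstantMesh` class added above. The module path, image resolution, and random camera values are illustrative assumptions; only the constructor defaults, `init_flexicubes_geometry`, `forward_planes`, and `extract_mesh` come from the code in this diff.

import torch
from src.models.lrm_mesh import InstantMesh  # assumed module path

device = torch.device('cuda')

# Defaults defined above: DINO ViT-B/16 encoder, 16-layer triplane transformer, 128^3 FlexiCubes grid.
model = InstantMesh().to(device).eval()

# FlexiCubes geometry must be initialized before any geometry/mesh call.
model.init_flexicubes_geometry(device, fovy=30.0)

# Dummy inputs: V posed views per batch; each camera is a flattened 4x4 pose (16 values).
images = torch.rand(1, 6, 3, 320, 320, device=device)   # [B, V, C_img, H_img, W_img] (resolution assumed)
cameras = torch.rand(1, 6, 16, device=device)            # [B, V, 16]

with torch.no_grad():
    planes = model.forward_planes(images, cameras)        # triplane features
    verts, faces, vert_colors = model.extract_mesh(planes, use_texture_map=False)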
src/models/renderer/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ # property and proprietary rights in and to this material, related
6
+ # documentation and any modifications thereto. Any use, reproduction,
7
+ # disclosure or distribution of this material and related documentation
8
+ # without an express license agreement from NVIDIA CORPORATION or
9
+ # its affiliates is strictly prohibited.
src/models/renderer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (155 Bytes).
 
src/models/renderer/__pycache__/synthesizer_mesh.cpython-312.pyc ADDED
Binary file (7.41 kB).
 
src/models/renderer/synthesizer.py ADDED
@@ -0,0 +1,203 @@
1
+ # ORIGINAL LICENSE
2
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4
+ #
5
+ # Modified by Jiale Xu
6
+ # The modifications are subject to the same license as the original.
7
+
8
+
9
+ import itertools
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from .utils.renderer import ImportanceRenderer
14
+ from .utils.ray_sampler import RaySampler
15
+
16
+
17
+ class OSGDecoder(nn.Module):
18
+ """
19
+ Triplane decoder that gives RGB and sigma values from sampled features.
20
+ Using ReLU here instead of Softplus in the original implementation.
21
+
22
+ Reference:
23
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
24
+ """
25
+ def __init__(self, n_features: int,
26
+ hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
27
+ super().__init__()
28
+ self.net = nn.Sequential(
29
+ nn.Linear(3 * n_features, hidden_dim),
30
+ activation(),
31
+ *itertools.chain(*[[
32
+ nn.Linear(hidden_dim, hidden_dim),
33
+ activation(),
34
+ ] for _ in range(num_layers - 2)]),
35
+ nn.Linear(hidden_dim, 1 + 3),
36
+ )
37
+ # init all bias to zero
38
+ for m in self.modules():
39
+ if isinstance(m, nn.Linear):
40
+ nn.init.zeros_(m.bias)
41
+
42
+ def forward(self, sampled_features, ray_directions):
43
+ # Aggregate features by mean
44
+ # sampled_features = sampled_features.mean(1)
45
+ # Aggregate features by concatenation
46
+ _N, n_planes, _M, _C = sampled_features.shape
47
+ sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
48
+ x = sampled_features
49
+
50
+ N, M, C = x.shape
51
+ x = x.contiguous().view(N*M, C)
52
+
53
+ x = self.net(x)
54
+ x = x.view(N, M, -1)
55
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
56
+ sigma = x[..., 0:1]
57
+
58
+ return {'rgb': rgb, 'sigma': sigma}
59
+
60
+
61
+ class TriplaneSynthesizer(nn.Module):
62
+ """
63
+ Synthesizer that renders a triplane volume with planes and a camera.
64
+
65
+ Reference:
66
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
67
+ """
68
+
69
+ DEFAULT_RENDERING_KWARGS = {
70
+ 'ray_start': 'auto',
71
+ 'ray_end': 'auto',
72
+ 'box_warp': 2.,
73
+ 'white_back': True,
74
+ 'disparity_space_sampling': False,
75
+ 'clamp_mode': 'softplus',
76
+ 'sampler_bbox_min': -1.,
77
+ 'sampler_bbox_max': 1.,
78
+ }
79
+
80
+ def __init__(self, triplane_dim: int, samples_per_ray: int):
81
+ super().__init__()
82
+
83
+ # attributes
84
+ self.triplane_dim = triplane_dim
85
+ self.rendering_kwargs = {
86
+ **self.DEFAULT_RENDERING_KWARGS,
87
+ 'depth_resolution': samples_per_ray // 2,
88
+ 'depth_resolution_importance': samples_per_ray // 2,
89
+ }
90
+
91
+ # renderings
92
+ self.renderer = ImportanceRenderer()
93
+ self.ray_sampler = RaySampler()
94
+
95
+ # modules
96
+ self.decoder = OSGDecoder(n_features=triplane_dim)
97
+
98
+ def forward(self, planes, cameras, render_size=128, crop_params=None):
99
+ # planes: (N, 3, D', H', W')
100
+ # cameras: (N, M, D_cam)
101
+ # render_size: int
102
+ assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
103
+ N, M = cameras.shape[:2]
104
+
105
+ cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
106
+ intrinsics = cameras[..., 16:25].view(N, M, 3, 3)
107
+
108
+ # Create a batch of rays for volume rendering
109
+ ray_origins, ray_directions = self.ray_sampler(
110
+ cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
111
+ intrinsics=intrinsics.reshape(-1, 3, 3),
112
+ render_size=render_size,
113
+ )
114
+ assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
115
+ assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
116
+
117
+ # Crop rays if crop_params is available
118
+ if crop_params is not None:
119
+ ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3)
120
+ ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3)
121
+ i, j, h, w = crop_params
122
+ ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3)
123
+ ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3)
124
+
125
+ # Perform volume rendering
126
+ rgb_samples, depth_samples, weights_samples = self.renderer(
127
+ planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
128
+ )
129
+
130
+ # Reshape into 'raw' neural-rendered image
131
+ if crop_params is not None:
132
+ Himg, Wimg = crop_params[2:]
133
+ else:
134
+ Himg = Wimg = render_size
135
+ rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
136
+ depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
137
+ weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
138
+
139
+ out = {
140
+ 'images_rgb': rgb_images,
141
+ 'images_depth': depth_images,
142
+ 'images_weight': weight_images,
143
+ }
144
+ return out
145
+
146
+ def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
147
+ # planes: (N, 3, D', H', W')
148
+ # grid_size: int
149
+ # aabb: (N, 2, 3)
150
+ if aabb is None:
151
+ aabb = torch.tensor([
152
+ [self.rendering_kwargs['sampler_bbox_min']] * 3,
153
+ [self.rendering_kwargs['sampler_bbox_max']] * 3,
154
+ ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
155
+ assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
156
+ N = planes.shape[0]
157
+
158
+ # create grid points for triplane query
159
+ grid_points = []
160
+ for i in range(N):
161
+ grid_points.append(torch.stack(torch.meshgrid(
162
+ torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
163
+ torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
164
+ torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
165
+ indexing='ij',
166
+ ), dim=-1).reshape(-1, 3))
167
+ cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
168
+
169
+ features = self.forward_points(planes, cube_grid)
170
+
171
+ # reshape into grid
172
+ features = {
173
+ k: v.reshape(N, grid_size, grid_size, grid_size, -1)
174
+ for k, v in features.items()
175
+ }
176
+ return features
177
+
178
+ def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
179
+ # planes: (N, 3, D', H', W')
180
+ # points: (N, P, 3)
181
+ N, P = points.shape[:2]
182
+
183
+ # query triplane in chunks
184
+ outs = []
185
+ for i in range(0, points.shape[1], chunk_size):
186
+ chunk_points = points[:, i:i+chunk_size]
187
+
188
+ # query triplane
189
+ chunk_out = self.renderer.run_model_activated(
190
+ planes=planes,
191
+ decoder=self.decoder,
192
+ sample_coordinates=chunk_points,
193
+ sample_directions=torch.zeros_like(chunk_points),
194
+ options=self.rendering_kwargs,
195
+ )
196
+ outs.append(chunk_out)
197
+
198
+ # concatenate the outputs
199
+ point_features = {
200
+ k: torch.cat([out[k] for out in outs], dim=1)
201
+ for k in outs[0].keys()
202
+ }
203
+ return point_features
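
A minimal sketch (not part of this commit) of querying the volumetric `TriplaneSynthesizer` above outside of rendering. The triplane spatial resolution and point count are assumptions, and the output keys are expected to follow the `OSGDecoder` defined in this file ('rgb' and 'sigma').

import torch
from src.models.renderer.synthesizer import TriplaneSynthesizer

synthesizer = TriplaneSynthesizer(triplane_dim=80, samples_per_ray=128).cuda()

# Triplane features: (N, 3 planes, C, H', W'); the 64x64 spatial resolution is an assumption.
planes = torch.rand(1, 3, 80, 64, 64, device='cuda')

# Query arbitrary 3D points inside the default [-1, 1] sampler bounding box in one chunk.
points = torch.rand(1, 4096, 3, device='cuda') * 2.0 - 1.0
features = synthesizer.forward_points(planes, points)
print(features['rgb'].shape, features['sigma'].shape)  # expected: (1, 4096, 3) (1, 4096, 1)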
src/models/renderer/synthesizer_mesh.py ADDED
@@ -0,0 +1,141 @@
1
+ # ORIGINAL LICENSE
2
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4
+ #
5
+ # Modified by Jiale Xu
6
+ # The modifications are subject to the same license as the original.
7
+
8
+ import itertools
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes
13
+
14
+
15
+ class OSGDecoder(nn.Module):
16
+ """
17
+ Triplane decoder that gives RGB and sigma values from sampled features.
18
+ Using ReLU here instead of Softplus in the original implementation.
19
+
20
+ Reference:
21
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
22
+ """
23
+ def __init__(self, n_features: int,
24
+ hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
25
+ super().__init__()
26
+
27
+ self.net_sdf = nn.Sequential(
28
+ nn.Linear(3 * n_features, hidden_dim),
29
+ activation(),
30
+ *itertools.chain(*[[
31
+ nn.Linear(hidden_dim, hidden_dim),
32
+ activation(),
33
+ ] for _ in range(num_layers - 2)]),
34
+ nn.Linear(hidden_dim, 1),
35
+ )
36
+ self.net_rgb = nn.Sequential(
37
+ nn.Linear(3 * n_features, hidden_dim),
38
+ activation(),
39
+ *itertools.chain(*[[
40
+ nn.Linear(hidden_dim, hidden_dim),
41
+ activation(),
42
+ ] for _ in range(num_layers - 2)]),
43
+ nn.Linear(hidden_dim, 3),
44
+ )
45
+ self.net_deformation = nn.Sequential(
46
+ nn.Linear(3 * n_features, hidden_dim),
47
+ activation(),
48
+ *itertools.chain(*[[
49
+ nn.Linear(hidden_dim, hidden_dim),
50
+ activation(),
51
+ ] for _ in range(num_layers - 2)]),
52
+ nn.Linear(hidden_dim, 3),
53
+ )
54
+ self.net_weight = nn.Sequential(
55
+ nn.Linear(8 * 3 * n_features, hidden_dim),
56
+ activation(),
57
+ *itertools.chain(*[[
58
+ nn.Linear(hidden_dim, hidden_dim),
59
+ activation(),
60
+ ] for _ in range(num_layers - 2)]),
61
+ nn.Linear(hidden_dim, 21),
62
+ )
63
+
64
+ # init all bias to zero
65
+ for m in self.modules():
66
+ if isinstance(m, nn.Linear):
67
+ nn.init.zeros_(m.bias)
68
+
69
+ def get_geometry_prediction(self, sampled_features, flexicubes_indices):
70
+ _N, n_planes, _M, _C = sampled_features.shape
71
+ sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
72
+
73
+ sdf = self.net_sdf(sampled_features)
74
+ deformation = self.net_deformation(sampled_features)
75
+
76
+ grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1)
77
+ grid_features = grid_features.reshape(
78
+ sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1])
79
+ weight = self.net_weight(grid_features) * 0.1
80
+
81
+ return sdf, deformation, weight
82
+
83
+ def get_texture_prediction(self, sampled_features):
84
+ _N, n_planes, _M, _C = sampled_features.shape
85
+ sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
86
+
87
+ rgb = self.net_rgb(sampled_features)
88
+ rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
89
+
90
+ return rgb
91
+
92
+
93
+ class TriplaneSynthesizer(nn.Module):
94
+ """
95
+ Synthesizer that queries geometry (SDF, deformation, FlexiCubes weights) and texture colors from triplane features.
96
+
97
+ Reference:
98
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
99
+ """
100
+
101
+ DEFAULT_RENDERING_KWARGS = {
102
+ 'ray_start': 'auto',
103
+ 'ray_end': 'auto',
104
+ 'box_warp': 2.,
105
+ 'white_back': True,
106
+ 'disparity_space_sampling': False,
107
+ 'clamp_mode': 'softplus',
108
+ 'sampler_bbox_min': -1.,
109
+ 'sampler_bbox_max': 1.,
110
+ }
111
+
112
+ def __init__(self, triplane_dim: int, samples_per_ray: int):
113
+ super().__init__()
114
+
115
+ # attributes
116
+ self.triplane_dim = triplane_dim
117
+ self.rendering_kwargs = {
118
+ **self.DEFAULT_RENDERING_KWARGS,
119
+ 'depth_resolution': samples_per_ray // 2,
120
+ 'depth_resolution_importance': samples_per_ray // 2,
121
+ }
122
+
123
+ # modules
124
+ self.plane_axes = generate_planes()
125
+ self.decoder = OSGDecoder(n_features=triplane_dim)
126
+
127
+ def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices):
128
+ plane_axes = self.plane_axes.to(planes.device)
129
+ sampled_features = sample_from_planes(
130
+ plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
131
+
132
+ sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices)
133
+ return sdf, deformation, weight
134
+
135
+ def get_texture_prediction(self, planes, sample_coordinates):
136
+ plane_axes = self.plane_axes.to(planes.device)
137
+ sampled_features = sample_from_planes(
138
+ plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
139
+
140
+ rgb = self.decoder.get_texture_prediction(sampled_features)
141
+ return rgb
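
A minimal sketch (not part of this commit) of the mesh-branch `TriplaneSynthesizer` above: it queries SDF, per-vertex deformation, and FlexiCubes weights for grid vertices, plus RGB for texture positions. The tensor sizes below (number of grid vertices, number of cubes, triplane resolution) are illustrative assumptions.

import torch
from src.models.renderer.synthesizer_mesh import TriplaneSynthesizer

synthesizer = TriplaneSynthesizer(triplane_dim=80, samples_per_ray=128).cuda()

planes = torch.rand(1, 3, 80, 64, 64, device='cuda')            # (N, 3, C, H', W'), resolution assumed
grid_verts = torch.rand(1, 2048, 3, device='cuda') * 2.0 - 1.0  # grid vertex positions in [-1, 1]^3
cube_indices = torch.randint(0, 2048, (256, 8), device='cuda')  # 8 vertex indices per FlexiCubes cell

sdf, deformation, weight = synthesizer.get_geometry_prediction(planes, grid_verts, cube_indices)
rgb = synthesizer.get_texture_prediction(planes, grid_verts)
# sdf: (1, 2048, 1), deformation: (1, 2048, 3), weight: (1, 256, 21), rgb: (1, 2048, 3)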
src/models/renderer/utils/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ # property and proprietary rights in and to this material, related
6
+ # documentation and any modifications thereto. Any use, reproduction,
7
+ # disclosure or distribution of this material and related documentation
8
+ # without an express license agreement from NVIDIA CORPORATION or
9
+ # its affiliates is strictly prohibited.
src/models/renderer/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (161 Bytes).
 
src/models/renderer/utils/__pycache__/math_utils.cpython-312.pyc ADDED
Binary file (5.1 kB).