jadechoghari (HF Staff) committed 5413a8c (verified) · 1 parent: 4dbdea8

add export_mesh and export_video

Files changed (1): modeling.py (+131 -13)
modeling.py CHANGED
@@ -3,12 +3,12 @@
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
 import torch
-
-# import dinowrapper
-from .dino_wrapper2 import DinoWrapper
-from .transformer import TriplaneTransformer
-from .synthesizer_part import TriplaneSynthesizer
-# from .processor import LRMImageProcessor
+import numpy as np
+import math
+
+from dino_wrapper2 import DinoWrapper
+from transformer import TriplaneTransformer
+from synthesizer_part import TriplaneSynthesizer
 
 class CameraEmbedder(nn.Module):
     def __init__(self, raw_dim: int, embed_dim: int):
@@ -67,14 +67,8 @@ class LRMGenerator(PreTrainedModel):
             triplane_dim=config.triplane_dim, samples_per_ray=config.rendering_samples_per_ray,
         )
 
-    def forward(self, image, camera):
+    def forward(self, image, camera, export_mesh=False, mesh_size=256, render_size=256, export_video=False, fps=30):
 
-        # we use image processor directly in the forward pass
-        #TODO: we should have the following:
-        # processor = AutoProcessor.from_pretrained("jadechoghari/vfusion3d")
-        # processed_image, source_camera = processor(image)
-        #
-
         assert image.shape[0] == camera.shape[0], "Batch size mismatch"
         N = image.shape[0]
@@ -92,4 +86,128 @@ class LRMGenerator(PreTrainedModel):
         planes = self.transformer(image_feats, camera_embeddings)
         assert planes.shape[0] == N, "Batch size mismatch for planes"
         assert planes.shape[1] == 3, "Planes should have 3 channels"
+
+        # Generate the mesh
+        if export_mesh:
+            import mcubes
+            import trimesh
+            grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
+            vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
+            vtx = vtx / (mesh_size - 1) * 2 - 1
+            vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
+            vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
+            vtx_colors = (vtx_colors * 255).astype(np.uint8)
+            mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+
+            mesh_path = "awesome_mesh.obj"
+            mesh.export(mesh_path, 'obj')
+
+            return planes, mesh_path
+
+        # Generate video
+        if export_video:
+            import imageio
+            render_cameras = self._default_render_cameras(batch_size=N).to(image.device)
+
+            frames = []
+            chunk_size = 1  # Adjust chunk size as needed
+            for i in range(0, render_cameras.shape[1], chunk_size):
+                frame_chunk = self.synthesizer(
+                    planes,
+                    render_cameras[:, i:i + chunk_size],
+                    render_size,
+                    render_size,
+                    0,
+                    0
+                )
+                frames.append(frame_chunk['images_rgb'])
+
+            frames = torch.cat(frames, dim=1)
+            frames = frames.squeeze(0)  # drop the batch dimension (assumes a single sample) before writing frames
+            frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
+
+            # Save video
+            video_path = "awesome_video.mp4"
+            imageio.mimwrite(video_path, frames, fps=fps)
+
+            return planes, video_path
+
         return planes
+
+    # Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
+    # and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
+    def _default_intrinsics(self):
+        fx = fy = 384
+        cx = cy = 256
+        w = h = 512
+        intrinsics = torch.tensor([
+            [fx, fy],
+            [cx, cy],
+            [w, h],
+        ], dtype=torch.float32)
+        return intrinsics
+
+    def _default_render_cameras(self, batch_size=1):
+        M = 160  # Number of views
+        radius = 1.5
+        elevation = 0
+
+        camera_positions = []
+        rand_theta = np.random.uniform(0, np.pi / 180)
+        elevation = math.radians(elevation)
+        for i in range(M):
+            theta = 2 * math.pi * i / M + rand_theta
+            x = radius * math.cos(theta) * math.cos(elevation)
+            y = radius * math.sin(theta) * math.cos(elevation)
+            z = radius * math.sin(elevation)
+            camera_positions.append([x, y, z])
+
+        camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
+        extrinsics = self.center_looking_at_camera_pose(camera_positions)
+
+        intrinsics = self._default_intrinsics().unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
+        render_cameras = self.build_camera_standard(extrinsics, intrinsics)
+
+        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+
+    def center_looking_at_camera_pose(self, camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
+        if look_at is None:
+            look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
+        if up_world is None:
+            up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
+        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+        z_axis = camera_position - look_at
+        z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
+        x_axis = torch.cross(up_world, z_axis, dim=-1)
+        x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
+        y_axis = torch.cross(z_axis, x_axis, dim=-1)
+        y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
+        extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+        return extrinsics
+
+    def get_normalized_camera_intrinsics(self, intrinsics: torch.Tensor):
+        fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
+        cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
+        width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
+        fx, fy = fx / width, fy / height
+        cx, cy = cx / width, cy / height
+        return fx, fy, cx, cy
+
+    def build_camera_standard(self, RT: torch.Tensor, intrinsics: torch.Tensor):
+        E = self.compose_extrinsic_RT(RT)
+        fx, fy, cx, cy = self.get_normalized_camera_intrinsics(intrinsics)
+        I = torch.stack([
+            torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
+            torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
+            torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
+        ], dim=1)
+        return torch.cat([
+            E.reshape(-1, 16),
+            I.reshape(-1, 9),
+        ], dim=-1)
+
+    def compose_extrinsic_RT(self, RT: torch.Tensor):
+        return torch.cat([
+            RT,
+            torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(RT.shape[0], 1, 1).to(RT.device)
+        ], dim=1)
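
For context, below is a minimal usage sketch of the flags this commit adds to forward(). Only the parameter names (export_mesh, mesh_size, export_video, render_size, fps), the 160 default render views, and the returned (planes, path) pairs come from the diff above; the AutoModel loading call, the repo id "jadechoghari/vfusion3d", and the input shapes are illustrative assumptions (the removed TODO suggests a processor should eventually supply the processed image and source camera).

# Hypothetical usage sketch -- loading path and input shapes are assumptions;
# only the forward() flags and return values come from this commit.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("jadechoghari/vfusion3d", trust_remote_code=True)
model.eval()

image = torch.randn(1, 3, 512, 512)   # placeholder processed image (N, C, H, W)
camera = torch.randn(1, 16)           # placeholder raw source camera; raw_dim depends on the config

with torch.no_grad():
    # Marching-cubes mesh export: returns the triplanes and the path to "awesome_mesh.obj".
    planes, mesh_path = model(image, camera, export_mesh=True, mesh_size=256)

    # Turntable video export: renders the 160 default views and writes "awesome_video.mp4".
    planes, video_path = model(image, camera, export_video=True, render_size=256, fps=30)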