add export_mesh and export_video
modeling.py  +131 -13
@@ -3,12 +3,12 @@
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
 import torch
-
-
-
-from
-from
-
+import numpy as np
+import math
+
+from dino_wrapper2 import DinoWrapper
+from transformer import TriplaneTransformer
+from synthesizer_part import TriplaneSynthesizer

 class CameraEmbedder(nn.Module):
     def __init__(self, raw_dim: int, embed_dim: int):
@@ -67,14 +67,8 @@ class LRMGenerator(PreTrainedModel):
             triplane_dim=config.triplane_dim, samples_per_ray=config.rendering_samples_per_ray,
         )

-    def forward(self, image, camera):
+    def forward(self, image, camera, export_mesh=False, mesh_size=256, render_size=256, export_video=False, fps=30):

-        # we use image processor directly in the forward pass
-        #TODO: we should have the following:
-        # processor = AutoProcessor.from_pretrained("jadechoghari/vfusion3d")
-        # processed_image, source_camera = processor(image)
-        #
-
         assert image.shape[0] == camera.shape[0], "Batch size mismatch"
         N = image.shape[0]

@@ -92,4 +86,128 @@ class LRMGenerator(PreTrainedModel):
         planes = self.transformer(image_feats, camera_embeddings)
         assert planes.shape[0] == N, "Batch size mismatch for planes"
         assert planes.shape[1] == 3, "Planes should have 3 channels"
+
+        # Generate the mesh
+        if export_mesh:
+            import mcubes
+            import trimesh
+            grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
+            vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
+            vtx = vtx / (mesh_size - 1) * 2 - 1
+            vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
+            vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
+            vtx_colors = (vtx_colors * 255).astype(np.uint8)
+            mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+
+            mesh_path = "awesome_mesh.obj"
+            mesh.export(mesh_path, 'obj')
+
+            return planes, mesh_path
+
+        # Generate video
+        if export_video:
+            render_cameras = self._default_render_cameras(batch_size=N).to(image.device)
+
+            frames = []
+            chunk_size = 1  # Adjust chunk size as needed
+            for i in range(0, render_cameras.shape[1], chunk_size):
+                frame_chunk = self.synthesizer(
+                    planes,
+                    render_cameras[:, i:i + chunk_size],
+                    render_size,
+                    render_size,
+                    0,
+                    0
+                )
+                frames.append(frame_chunk['images_rgb'])
+
+            frames = torch.cat(frames, dim=1)
+            frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
+
+            # Save video
+            video_path = "awesome_video.mp4"
+            imageio.mimwrite(video_path, frames, fps=fps)
+
+            return planes, video_path
+
         return planes
+
+    # Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
+    # and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
+    def _default_intrinsics(self):
+        fx = fy = 384
+        cx = cy = 256
+        w = h = 512
+        intrinsics = torch.tensor([
+            [fx, fy],
+            [cx, cy],
+            [w, h],
+        ], dtype=torch.float32)
+        return intrinsics
+
+    def _default_render_cameras(self, batch_size=1):
+        M = 160  # Number of views
+        radius = 1.5
+        elevation = 0
+
+        camera_positions = []
+        rand_theta = np.random.uniform(0, np.pi / 180)
+        elevation = math.radians(elevation)
+        for i in range(M):
+            theta = 2 * math.pi * i / M + rand_theta
+            x = radius * math.cos(theta) * math.cos(elevation)
+            y = radius * math.sin(theta) * math.cos(elevation)
+            z = radius * math.sin(elevation)
+            camera_positions.append([x, y, z])
+
+        camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
+        extrinsics = self.center_looking_at_camera_pose(camera_positions)
+
+        intrinsics = self._default_intrinsics().unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
+        render_cameras = self.build_camera_standard(extrinsics, intrinsics)
+
+        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+
+    def center_looking_at_camera_pose(self, camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
+        if look_at is None:
+            look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
+        if up_world is None:
+            up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
+        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+        z_axis = camera_position - look_at
+        z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
+        x_axis = torch.cross(up_world, z_axis)
+        x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
+        y_axis = torch.cross(z_axis, x_axis)
+        y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
+        extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+        return extrinsics
+
+    def get_normalized_camera_intrinsics(self, intrinsics: torch.Tensor):
+        fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
+        cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
+        width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
+        fx, fy = fx / width, fy / height
+        cx, cy = cx / width, cy / height
+        return fx, fy, cx, cy
+
+    def build_camera_standard(self, RT: torch.Tensor, intrinsics: torch.Tensor):
+        E = self.compose_extrinsic_RT(RT)
+        fx, fy, cx, cy = self.get_normalized_camera_intrinsics(intrinsics)
+        I = torch.stack([
+            torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
+            torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
+            torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
+        ], dim=1)
+        return torch.cat([
+            E.reshape(-1, 16),
+            I.reshape(-1, 9),
+        ], dim=-1)
+
+    def compose_extrinsic_RT(self, RT: torch.Tensor):
+        return torch.cat([
+            RT,
+            torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(RT.shape[0], 1, 1).to(RT.device)
+        ], dim=1)
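For reference, a minimal usage sketch of the two new export paths. The loading call and the input shapes below are assumptions for illustration, not part of this commit; the mesh branch imports mcubes and trimesh lazily, so both must be installed, and the video branch calls imageio.mimwrite, so imageio must be importable even though no import for it appears in the hunks above.

# Illustrative sketch only: the repo id comes from the TODO comment removed above,
# and the input shapes are placeholders, not taken from this diff.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("jadechoghari/vfusion3d", trust_remote_code=True)
model.eval()

image = torch.randn(1, 3, 448, 448)   # placeholder preprocessed image batch
camera = torch.randn(1, 16)           # placeholder raw source-camera vector

with torch.no_grad():
    # Mesh branch: marching cubes over the triplane density grid, saved as awesome_mesh.obj.
    planes, mesh_path = model(image, camera, export_mesh=True, mesh_size=256)

    # Video branch: a 160-view orbit rendered by the synthesizer, saved as awesome_video.mp4.
    planes, video_path = model(image, camera, export_video=True, render_size=256, fps=30)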
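Each render camera produced by build_camera_standard is a flat 25-dimensional vector: the first 16 entries are the row-major 4x4 pose matrix from compose_extrinsic_RT, the last 9 the normalized 3x3 intrinsics. A small sketch that unpacks such a tensor back into matrices; the helper name is hypothetical and not part of this commit.

import torch

def unpack_standard_camera(cam: torch.Tensor):
    # cam: (..., 25) as returned by build_camera_standard / _default_render_cameras.
    E = cam[..., :16].reshape(*cam.shape[:-1], 4, 4)  # camera pose, last row [0, 0, 0, 1]
    K = cam[..., 16:].reshape(*cam.shape[:-1], 3, 3)  # intrinsics: fx, cx over width; fy, cy over height
    return E, K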