mboss jammmmm committed on
Commit 64fccd8 · 1 Parent(s): c2f384d

Update demo with latest changes

Co-authored-by: Aaryaman Vasishta <[email protected]>

gradio_app.py CHANGED
@@ -2,10 +2,12 @@ import os
 import random
 import tempfile
 import time
+import zipfile
 from contextlib import nullcontext
 from functools import lru_cache
 from typing import Any
 
+import cv2
 import gradio as gr
 import numpy as np
 import torch
@@ -62,6 +64,23 @@ example_files = [
 ]
 
 
+def create_zip_file(glb_file, pc_file, illumination_file):
+    if not all([glb_file, pc_file, illumination_file]):
+        return None
+
+    # Create a temporary zip file
+    temp_dir = tempfile.mkdtemp()
+    zip_path = os.path.join(temp_dir, "spar3d_output.zip")
+
+    with zipfile.ZipFile(zip_path, "w") as zipf:
+        zipf.write(glb_file, "mesh.glb")
+        zipf.write(pc_file, "points.ply")
+        zipf.write(illumination_file, "illumination.hdr")
+
+    generated_files.append(zip_path)
+    return zip_path
+
+
 def forward_model(
     batch,
     system,
@@ -105,11 +124,16 @@ def forward_model(
 
     # forward for the final mesh
     trimesh_mesh, _glob_dict = model.generate_mesh(
-        batch, texture_resolution, remesh=remesh_option, vertex_count=vertex_count
+        batch,
+        texture_resolution,
+        remesh=remesh_option,
+        vertex_count=vertex_count,
+        estimate_illumination=True,
     )
     trimesh_mesh = trimesh_mesh[0]
+    illumination = _glob_dict["illumination"]
 
-    return trimesh_mesh, pc_rgb_trimesh
+    return trimesh_mesh, pc_rgb_trimesh, illumination.cpu().detach().numpy()[0]
 
 
 def run_model(
@@ -169,7 +193,7 @@ def run_model(
         dim=1,
     )
 
-    trimesh_mesh, trimesh_pc = forward_model(
+    trimesh_mesh, trimesh_pc, illumination_map = forward_model(
         model_batch,
         model,
         guidance_scale=guidance_scale,
@@ -191,9 +215,13 @@ def run_model(
     trimesh_pc.export(tmp_file_pc)
     generated_files.append(tmp_file_pc)
 
+    tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
+    cv2.imwrite(tmp_file_illumination, illumination_map)
+    generated_files.append(tmp_file_illumination)
+
     print("Generation took:", time.time() - start, "s")
 
-    return tmp_file, tmp_file_pc, trimesh_pc
+    return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc
 
 
 def create_batch(input_image: Image) -> dict[str, Any]:
@@ -272,7 +300,7 @@ def process_model_run(
         f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
     )
 
-    glb_file, pc_file, pc_plot = run_model(
+    glb_file, pc_file, illumination_file, pc_plot = run_model(
         background_state,
         guidance_scale,
         random_seed,
@@ -295,7 +323,7 @@ def process_model_run(
         ]
     )
 
-    return glb_file, pc_file, point_list
+    return glb_file, pc_file, illumination_file, point_list
 
 
 def regenerate_run(
@@ -308,7 +336,7 @@ def regenerate_run(
     vertex_count,
     texture_resolution,
 ):
-    glb_file, pc_file, point_list = process_model_run(
+    glb_file, pc_file, illumination_file, point_list = process_model_run(
         background_state,
         guidance_scale,
         random_seed,
@@ -318,6 +346,8 @@ def regenerate_run(
         vertex_count,
         texture_resolution,
     )
+    zip_file = create_zip_file(glb_file, pc_file, illumination_file)
+
     return (
         gr.update(),  # run_btn
         gr.update(),  # img_proc_state
@@ -325,10 +355,12 @@ def regenerate_run(
         gr.update(),  # preview_removal
         gr.update(value=glb_file, visible=True),  # output_3d
         gr.update(visible=True),  # hdr_row
+        illumination_file,  # hdr_file
         gr.update(visible=True),  # point_cloud_row
         gr.update(value=point_list),  # point_cloud_editor
         gr.update(value=pc_file),  # pc_download
         gr.update(visible=False),  # regenerate_btn
+        gr.update(value=zip_file, visible=True),  # download_all_btn
     )
 
 
@@ -362,7 +394,7 @@ def run_button(
         else:
             pc_cond = None
 
-        glb_file, pc_file, pc_list = process_model_run(
+        glb_file, pc_file, illumination_file, pc_list = process_model_run(
             background_state,
             guidance_scale,
             random_seed,
@@ -373,6 +405,8 @@ def run_button(
             texture_resolution,
         )
 
+        zip_file = create_zip_file(glb_file, pc_file, illumination_file)
+
         if torch.cuda.is_available():
             print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
         elif torch.backends.mps.is_available():
@@ -387,10 +421,12 @@ def run_button(
             gr.update(),  # preview_removal
             gr.update(value=glb_file, visible=True),  # output_3d
             gr.update(visible=True),  # hdr_row
+            illumination_file,  # hdr_file
             gr.update(visible=True),  # point_cloud_row
             gr.update(value=pc_list),  # point_cloud_editor
             gr.update(value=pc_file),  # pc_download
             gr.update(visible=False),  # regenerate_btn
+            gr.update(value=zip_file, visible=True),  # download_all_btn
         )
 
     elif run_btn == "Remove Background":
@@ -410,10 +446,12 @@ def run_button(
             gr.update(value=show_mask_img(fr_res), visible=True),  # preview_removal
             gr.update(value=None, visible=False),  # output_3d
             gr.update(visible=False),  # hdr_row
+            None,  # hdr_file
             gr.update(visible=False),  # point_cloud_row
             gr.update(value=None),  # point_cloud_editor
             gr.update(value=None),  # pc_download
             gr.update(visible=False),  # regenerate_btn
+            gr.update(value=None, visible=False),  # download_all_btn
         )
 
 
@@ -425,11 +463,13 @@ def requires_bg_remove(image, fr, no_crop):
             None,  # background_remove_state
             gr.update(value=None, visible=False),  # preview_removal
             gr.update(value=None, visible=False),  # output_3d
-            gr.update(visible=False),  # hdr_row
+            gr.update(value=None, visible=False),  # hdr_row
+            None,  # hdr_file
             gr.update(visible=False),  # point_cloud_row
             gr.update(value=None),  # point_cloud_editor
             gr.update(value=None),  # pc_download
             gr.update(visible=False),  # regenerate_btn
+            gr.update(value=None, visible=False),  # download_all_btn
         )
     alpha_channel = np.array(image.getchannel("A"))
     min_alpha = alpha_channel.min()
@@ -446,10 +486,12 @@ def requires_bg_remove(image, fr, no_crop):
             gr.update(value=show_mask_img(fr_res), visible=True),  # preview_removal
             gr.update(value=None, visible=False),  # output_3d
             gr.update(visible=False),  # hdr_row
+            None,  # hdr_file
             gr.update(visible=False),  # point_cloud_row
             gr.update(value=None),  # point_cloud_editor
             gr.update(value=None),  # pc_download
             gr.update(visible=False),  # regenerate_btn
+            gr.update(value=None, visible=False),  # download_all_btn
         )
     return (
         gr.update(value="Remove Background", visible=True),  # run_Btn
@@ -458,10 +500,12 @@ def requires_bg_remove(image, fr, no_crop):
         gr.update(value=None, visible=False),  # preview_removal
         gr.update(value=None, visible=False),  # output_3d
         gr.update(visible=False),  # hdr_row
+        None,  # hdr_file
         gr.update(visible=False),  # point_cloud_row
         gr.update(value=None),  # point_cloud_editor
         gr.update(value=None),  # pc_download
         gr.update(visible=False),  # regenerate_btn
+        gr.update(value=None, visible=False),  # download_all_btn
     )
 
 
@@ -487,6 +531,7 @@ def update_resolution_controls(remesh_choice, vertex_count_type):
 with gr.Blocks() as demo:
     img_proc_state = gr.State()
     background_remove_state = gr.State()
+    hdr_illumination_file_state = gr.State()
     gr.Markdown(
         """
     # SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
@@ -699,12 +744,46 @@ with gr.Blocks() as demo:
             inputs=hdr_illumination_file,
         )
 
+        def update_hdr_illumination_file(state, cur_update):
+            # If the current value of hdr_illumination_file is the same as cur_update, then we don't need to update
+            if (
+                hdr_illumination_file.value is not None
+                and hdr_illumination_file.value == cur_update
+            ):
+                return (
+                    gr.update(),
+                    gr.update(),
+                )
+            update_value = cur_update if cur_update is not None else state
+            if update_value is not None:
+                return (
+                    gr.update(value=update_value),
+                    gr.update(
+                        env_map=(
+                            update_value.name
+                            if isinstance(update_value, gr.File)
+                            else update_value
+                        )
+                    ),
+                )
+            return (gr.update(value=None), gr.update(env_map=None))
+
         hdr_illumination_file.change(
-            lambda x: gr.update(env_map=x.name if x is not None else None),
-            inputs=hdr_illumination_file,
-            outputs=[output_3d],
+            update_hdr_illumination_file,
+            inputs=[hdr_illumination_file_state, hdr_illumination_file],
+            outputs=[hdr_illumination_file, output_3d],
         )
 
+        download_all_btn = gr.File(
+            label="Download All Files (ZIP)", file_count="single", visible=False
+        )
+
+        hdr_illumination_file_state.change(
+            fn=lambda x: gr.update(value=x),
+            inputs=hdr_illumination_file_state,
+            outputs=hdr_illumination_file,
+        )
+
     examples = gr.Examples(
         examples=example_files, inputs=input_img, examples_per_page=11
     )
@@ -719,10 +798,12 @@ with gr.Blocks() as demo:
            preview_removal,
            output_3d,
            hdr_row,
+           hdr_illumination_file_state,
            point_cloud_row,
            point_cloud_editor,
            pc_download,
            regenerate_btn,
+           download_all_btn,
        ],
    )
 
@@ -751,10 +832,12 @@ with gr.Blocks() as demo:
            preview_removal,
            output_3d,
            hdr_row,
+           hdr_illumination_file_state,
            point_cloud_row,
            point_cloud_editor,
            pc_download,
            regenerate_btn,
+           download_all_btn,
        ],
    )
 
@@ -782,11 +865,13 @@ with gr.Blocks() as demo:
            preview_removal,
            output_3d,
            hdr_row,
+           hdr_illumination_file_state,
            point_cloud_row,
            point_cloud_editor,
            pc_download,
            regenerate_btn,
+           download_all_btn,
        ],
    )
 
-demo.queue().launch()
+demo.queue().launch(share=False)
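
Note on the new download path above: the demo writes the estimated illumination to a Radiance .hdr file with OpenCV and then bundles mesh, point cloud, and HDR into one archive via create_zip_file. A minimal standalone sketch of that flow, assuming a float32 HxWx3 illumination array and placeholder files in place of the demo's real exports:

import os
import tempfile
import zipfile

import cv2
import numpy as np

# Placeholder outputs; in the demo these come from run_model().
illumination_map = np.random.rand(128, 256, 3).astype(np.float32)
temp_dir = tempfile.mkdtemp()
glb_path = os.path.join(temp_dir, "mesh.glb")
pc_path = os.path.join(temp_dir, "points.ply")
open(glb_path, "wb").close()
open(pc_path, "wb").close()

# cv2.imwrite encodes Radiance HDR when the path ends in .hdr and the data
# is float32; note that OpenCV assumes BGR channel order.
hdr_path = os.path.join(temp_dir, "illumination.hdr")
cv2.imwrite(hdr_path, illumination_map)

# Bundle everything into a single archive, mirroring create_zip_file().
zip_path = os.path.join(temp_dir, "spar3d_output.zip")
with zipfile.ZipFile(zip_path, "w") as zipf:
    zipf.write(glb_path, "mesh.glb")
    zipf.write(pc_path, "points.ply")
    zipf.write(hdr_path, "illumination.hdr")
print("wrote", zip_path)
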
requirements.txt CHANGED
@@ -16,6 +16,7 @@ transparent-background==1.3.3
 gradio==4.43.0
 gradio-litmodel3d==0.0.1
 gradio-pointcloudeditor==0.0.9
+opencv-python==4.10.0.84
 gpytoolbox==0.2.0
 # ./texture_baker/
 # ./uv_unwrapper/
run.py CHANGED
@@ -32,9 +32,9 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--pretrained-model",
-        default="stabilityai/spar3d",
+        default="stabilityai/stable-point-aware-3d",
         type=str,
-        help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/spar3d'",
+        help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/stable-point-aware-3d'",
     )
     parser.add_argument(
         "--foreground-ratio",
spar3d/models/global_estimator/reni_estimator.py CHANGED
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
@@ -95,6 +95,7 @@ class ReniLatentCodeEstimator(BaseModule):
     def forward(
         self,
         triplane: Float[Tensor, "B 3 F Ht Wt"],
+        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
     ) -> dict[str, Any]:
         x = self.layers(
             triplane.reshape(
@@ -104,9 +105,12 @@ class ReniLatentCodeEstimator(BaseModule):
         x = x.mean(dim=[-2, -1])
 
         latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
-        rotations = self.fc_rotations(x)
+        rotations = rotation_6d_to_matrix(self.fc_rotations(x))
         scale = self.fc_scale(x)
 
-        env_map = self.reni_env_map(latents, rotation_6d_to_matrix(rotations), scale)
+        if rotation is not None:
+            rotations = rotations @ rotation.to(dtype=rotations.dtype)
+
+        env_map = self.reni_env_map(latents, rotations, scale)
 
         return {"illumination": env_map["rgb"]}
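
The forward() change above accepts an optional extra rotation that is composed with the per-sample rotations decoded from the 6D prediction, apparently so the estimated environment map ends up in the same frame as the exported mesh. A small sketch of that composition in plain torch (the rot_x helper is illustrative, not from the repo):

import math

import torch


def rot_x(deg: float) -> torch.Tensor:
    # 3x3 rotation about the x-axis; illustrative helper only.
    r = math.radians(deg)
    c, s = math.cos(r), math.sin(r)
    return torch.tensor([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


# A batch of predicted rotation matrices (identity here for simplicity).
predicted = torch.eye(3).expand(4, 3, 3)  # (B, 3, 3)
extra = rot_x(-90.0).unsqueeze(0)  # (1, 3, 3) fixed output-domain rotation

# Batched matmul broadcasts the single extra rotation across the batch,
# matching rotations = rotations @ rotation.to(dtype=rotations.dtype).
combined = predicted @ extra.to(dtype=predicted.dtype)
print(combined.shape)  # torch.Size([4, 3, 3])
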
spar3d/system.py CHANGED
@@ -506,6 +506,11 @@ class SPAR3D(BaseModule):
 
         scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
 
+        # Create a rotation matrix for the final output domain
+        rotation = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
+        rotation2 = trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0])
+        output_rotation = rotation2 @ rotation
+
         global_dict = {}
         if self.image_estimator is not None:
             global_dict.update(
@@ -514,7 +519,14 @@ class SPAR3D(BaseModule):
                 )
             )
         if self.global_estimator is not None and estimate_illumination:
-            global_dict.update(self.global_estimator(non_postprocessed_codes))
+            rotation_torch = (
+                torch.tensor(output_rotation)
+                .to(self.device, dtype=torch.float32)[:3, :3]
+                .unsqueeze(0)
+            )
+            global_dict.update(
+                self.global_estimator(non_postprocessed_codes, rotation=rotation_torch)
+            )
 
         global_dict["pointcloud"] = batch["pc_cond"]
 
@@ -700,15 +712,7 @@ class SPAR3D(BaseModule):
                     uv=uvs, material=material
                 ),
             )
-            rot = trimesh.transformations.rotation_matrix(
-                np.radians(-90), [1, 0, 0]
-            )
-            tmesh.apply_transform(rot)
-            tmesh.apply_transform(
-                trimesh.transformations.rotation_matrix(
-                    np.radians(90), [0, 1, 0]
-                )
-            )
+            tmesh.apply_transform(output_rotation)
 
             tmesh.invert()
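
For reference, the system.py hunks replace two sequential apply_transform calls with one precomputed transform and reuse its rotational part for the illumination estimator. A standalone restatement of that construction, runnable outside the class (device placement from the diff is omitted):

import numpy as np
import torch
import trimesh

# -90 deg about x followed by +90 deg about y, composed into one 4x4 transform.
# rotation2 @ rotation applies the x-axis rotation first, matching the two
# apply_transform calls this commit removes.
rotation = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
rotation2 = trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0])
output_rotation = rotation2 @ rotation

# The global estimator only needs the 3x3 rotational block, batched as (1, 3, 3).
rotation_torch = torch.tensor(output_rotation, dtype=torch.float32)[:3, :3].unsqueeze(0)
print(output_rotation.shape, rotation_torch.shape)  # (4, 4) torch.Size([1, 3, 3])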