YiftachEde committed
Commit b8fee6a · 1 Parent(s): f40ef65
Files changed (39)
  1. configs/instant-mesh-base.yaml +22 -0
  2. configs/instant-mesh-large-train.yaml +67 -0
  3. configs/instant-mesh-large.yaml +22 -0
  4. configs/instant-mesh-large_refine.yaml +22 -0
  5. configs/instant-nerf-base.yaml +21 -0
  6. configs/instant-nerf-large-best.yaml +21 -0
  7. configs/instant-nerf-large-train.yaml +65 -0
  8. configs/instant-nerf-large.yaml +21 -0
  9. configs/instant-nerf-sdedit.yaml +21 -0
  10. configs/zero123plus-finetune.yaml +47 -0
  11. configs/zero123plus-refine_finetune.yaml +54 -0
  12. configs/zero123plus-refine_finetune_2.yaml +51 -0
  13. configs/zero123plus-refine_finetune_relit.yaml +52 -0
  14. configs/zero123plus-refine_finetune_single_light.yaml +56 -0
  15. configs/zero123plus-refine_finetune_single_view.yaml +55 -0
  16. src/__pycache__/__init__.cpython-310.pyc +0 -0
  17. src/models/__pycache__/__init__.cpython-310.pyc +0 -0
  18. src/models/__pycache__/lrm.cpython-310.pyc +0 -0
  19. src/models/decoder/__pycache__/__init__.cpython-310.pyc +0 -0
  20. src/models/decoder/__pycache__/transformer.cpython-310.pyc +0 -0
  21. src/models/encoder/__pycache__/__init__.cpython-310.pyc +0 -0
  22. src/models/encoder/__pycache__/dino.cpython-310.pyc +0 -0
  23. src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc +0 -0
  24. src/models/renderer/__pycache__/__init__.cpython-310.pyc +0 -0
  25. src/models/renderer/__pycache__/synthesizer.cpython-310.pyc +0 -0
  26. src/models/renderer/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  27. src/models/renderer/utils/__pycache__/math_utils.cpython-310.pyc +0 -0
  28. src/models/renderer/utils/__pycache__/ray_marcher.cpython-310.pyc +0 -0
  29. src/models/renderer/utils/__pycache__/ray_sampler.cpython-310.pyc +0 -0
  30. src/models/renderer/utils/__pycache__/renderer.cpython-310.pyc +0 -0
  31. src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  32. src/utils/__pycache__/camera_util.cpython-310.pyc +0 -0
  33. src/utils/__pycache__/infer_util.cpython-310.pyc +0 -0
  34. src/utils/__pycache__/mesh_util.cpython-310.pyc +0 -0
  35. src/utils/__pycache__/train_util.cpython-310.pyc +0 -0
  36. zero123plus/__pycache__/model.cpython-310.pyc +0 -0
  37. zero123plus/__pycache__/pipeline.cpython-310.pyc +0 -0
  38. zero123plus/model.py +547 -0
  39. zero123plus/pipeline.py +1125 -0
configs/instant-mesh-base.yaml ADDED
@@ -0,0 +1,22 @@
+model_config:
+  target: src.models.lrm_mesh.InstantMesh
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 12
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 40
+    rendering_samples_per_ray: 96
+    grid_res: 128
+    grid_scale: 2.1
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_mesh_base.ckpt
+  texture_resolution: 1024
+  render_resolution: 512
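Note: a config of this shape is normally consumed by the repo's inference code. The following is a minimal sketch (not part of the commit), assuming OmegaConf is available and that the target/params pair is resolved by src.utils.train_util.instantiate_from_config, the helper already imported in zero123plus/model.py below; the exact entry point may differ.

from omegaconf import OmegaConf
from src.utils.train_util import instantiate_from_config

config = OmegaConf.load('configs/instant-mesh-base.yaml')
model = instantiate_from_config(config.model_config)   # builds src.models.lrm_mesh.InstantMesh(**params)
infer_config = config.infer_config                     # checkpoint paths plus texture/render resolutions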
configs/instant-mesh-large-train.yaml ADDED
@@ -0,0 +1,67 @@
+model:
+  base_learning_rate: 4.0e-05
+  target: src.model_mesh.MVRecon
+  params:
+    init_ckpt: logs/instant-nerf-large-train/checkpoints/last.ckpt
+    input_size: 320
+    render_size: 512
+
+    lrm_generator_config:
+      target: src.models.lrm_mesh.InstantMesh
+      params:
+        encoder_feat_dim: 768
+        encoder_freeze: false
+        encoder_model_name: facebook/dino-vitb16
+        transformer_dim: 1024
+        transformer_layers: 16
+        transformer_heads: 16
+        triplane_low_res: 32
+        triplane_high_res: 64
+        triplane_dim: 80
+        rendering_samples_per_ray: 128
+        grid_res: 128
+        grid_scale: 2.1
+
+
+data:
+  target: src.data.objaverse.DataModuleFromConfig
+  params:
+    batch_size: 2
+    num_workers: 8
+    train:
+      target: src.data.objaverse.ObjaverseData
+      params:
+        root_dir: data/objaverse
+        meta_fname: filtered_obj_name.json
+        input_image_dir: rendering_random_32views
+        target_image_dir: rendering_random_32views
+        input_view_num: 6
+        target_view_num: 4
+        total_view_n: 32
+        fov: 50
+        camera_rotation: true
+        validation: false
+    validation:
+      target: src.data.objaverse.ValidationData
+      params:
+        root_dir: data/valid_samples
+        input_view_num: 6
+        input_image_size: 320
+        fov: 30
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 2000
+      save_top_k: -1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    val_check_interval: 1000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
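Note: the model/data/lightning split above follows the usual OmegaConf + PyTorch Lightning layout. A hedged sketch of how such a file is typically turned into a training run (assumed wiring for illustration, not the commit's actual train script):

import pytorch_lightning as pl
from omegaconf import OmegaConf
from src.utils.train_util import instantiate_from_config

cfg = OmegaConf.load('configs/instant-mesh-large-train.yaml')
model = instantiate_from_config(cfg.model)            # src.model_mesh.MVRecon wrapping the LRM generator
model.learning_rate = cfg.model.base_learning_rate    # assumption: configure_optimizers() reads self.learning_rate
data = instantiate_from_config(cfg.data)              # DataModuleFromConfig with train/validation datasets
trainer = pl.Trainer(**OmegaConf.to_container(cfg.lightning.trainer))
# trainer.fit(model, datamodule=data)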
configs/instant-mesh-large.yaml ADDED
@@ -0,0 +1,22 @@
+model_config:
+  target: src.models.lrm_mesh.InstantMesh
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+    grid_res: 128
+    grid_scale: 2.1
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_mesh_large.ckpt
+  texture_resolution: 1024
+  render_resolution: 512
configs/instant-mesh-large_refine.yaml ADDED
@@ -0,0 +1,22 @@
+model_config:
+  target: src.models.lrm_mesh.InstantMesh
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+    grid_res: 128
+    grid_scale: 2.1
+
+
+infer_config:
+  unet_path: step=00260000.ckpt
+  model_path: ckpts/instant_mesh_large.ckpt
+  texture_resolution: 8192
+  render_resolution: 1536
configs/instant-nerf-base.yaml ADDED
@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 12
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 40
+    rendering_samples_per_ray: 96
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_nerf_base.ckpt
+  mesh_threshold: 10.0
+  mesh_resolution: 256
+  render_resolution: 384
configs/instant-nerf-large-best.yaml ADDED
@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+
+
+infer_config:
+  unet_path: best_21.ckpt
+  model_path: ckpts/instant_nerf_large.ckpt
+  mesh_threshold: 5.0
+  mesh_resolution: 256
+  render_resolution: 512
configs/instant-nerf-large-train.yaml ADDED
@@ -0,0 +1,65 @@
+model:
+  base_learning_rate: 4.0e-04
+  target: src.model.MVRecon
+  params:
+    input_size: 320
+    render_size: 192
+
+    lrm_generator_config:
+      target: src.models.lrm.InstantNeRF
+      params:
+        encoder_feat_dim: 768
+        encoder_freeze: false
+        encoder_model_name: facebook/dino-vitb16
+        transformer_dim: 1024
+        transformer_layers: 16
+        transformer_heads: 16
+        triplane_low_res: 32
+        triplane_high_res: 64
+        triplane_dim: 80
+        rendering_samples_per_ray: 128
+
+
+data:
+  target: src.data.objaverse.DataModuleFromConfig
+  params:
+    batch_size: 2
+    num_workers: 8
+    train:
+      target: src.data.objaverse.ObjaverseData
+      params:
+        root_dir: data/objaverse
+        meta_fname: filtered_obj_name.json
+        input_image_dir: rendering_random_32views
+        target_image_dir: rendering_random_32views
+        input_view_num: 6
+        target_view_num: 4
+        total_view_n: 32
+        fov: 50
+        camera_rotation: true
+        validation: false
+    validation:
+      target: src.data.objaverse.ValidationData
+      params:
+        root_dir: data/valid_samples
+        input_view_num: 6
+        input_image_size: 320
+        fov: 30
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 1000
+      save_top_k: -1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 1000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
configs/instant-nerf-large.yaml ADDED
@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+
+
+infer_config:
+  unet_path: logs/zero123plus-refine_finetune_single_view/checkpoints/step=00210000.ckpt
+  model_path: ckpts/instant_nerf_large.ckpt
+  mesh_threshold: 10.0
+  mesh_resolution: 256
+  render_resolution: 320
configs/instant-nerf-sdedit.yaml ADDED
@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_nerf_large.ckpt
+  mesh_threshold: 10.0
+  mesh_resolution: 256
+  render_resolution: 512
configs/zero123plus-finetune.yaml ADDED
@@ -0,0 +1,47 @@
+model:
+  base_learning_rate: 1.0e-05
+  target: zero123plus.model.MVDiffusion
+  params:
+    drop_cond_prob: 0.1
+
+    stable_diffusion_config:
+      pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2
+      custom_pipeline: ./zero123plus
+
+data:
+  target: src.data.objaverse_zero123plus.DataModuleFromConfig
+  params:
+    batch_size: 1
+    num_workers: 1
+    train:
+      target: src.data.objaverse_zero123plus.ObjaverseData
+      params:
+        root_dir: data/objaverse
+        meta_fname: lvis-annotations.json
+        image_dir: rendering_zero123plus
+        validation: false
+    validation:
+      target: src.data.objaverse_zero123plus.ObjaverseData
+      params:
+        root_dir: data/objaverse
+        meta_fname: lvis-annotations.json
+        image_dir: rendering_zero123plus
+        validation: true
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 1000
+      save_top_k: -1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 1000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
configs/zero123plus-refine_finetune.yaml ADDED
@@ -0,0 +1,54 @@
+model:
+  base_learning_rate: 1.0e-05
+  target: zero123plus.model.MVDiffusionRefinement
+  params:
+    drop_cond_prob: 0.1
+    refinement: true
+    stable_diffusion_config:
+      pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2
+      custom_pipeline: ./zero123plus
+
+
+
+data:
+  target: src.data.objaverse_zero123plus.DataModuleFromConfig
+  params:
+    batch_size: 3
+    num_workers: 1
+    train:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: false
+        caption_path: captions.json
+        split_path: dataset_splits.json
+    validation:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: true
+        caption_path: captions.json
+        split_path: dataset_splits.json
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 10000
+      save_top_k: 1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 10000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
+
configs/zero123plus-refine_finetune_2.yaml ADDED
@@ -0,0 +1,51 @@
+model:
+  base_learning_rate: 1.0e-05
+  target: zero123plus.model.MVDiffusionRefinement
+  params:
+    drop_cond_prob: 0.1
+    refinement: true
+    stable_diffusion_config:
+      pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2
+      custom_pipeline: ./zero123plus
+
+
+
+data:
+  target: src.data.objaverse_zero123plus.DataModuleFromConfig
+  params:
+    batch_size: 3
+    num_workers: 1
+    train:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: false
+        overfit: true
+    validation:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: true
+        overfit: true
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 1000
+      save_top_k: -1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 100
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
configs/zero123plus-refine_finetune_relit.yaml ADDED
@@ -0,0 +1,52 @@
+model:
+  base_learning_rate: 1.0e-05
+  target: zero123plus.model.MVDiffusionRefinement
+  params:
+    drop_cond_prob: 0.1
+    refinement: true
+    stable_diffusion_config:
+      pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2
+      custom_pipeline: ./zero123plus
+
+
+
+data:
+  target: src.data.objaverse_zero123plus.DataModuleFromConfig
+  params:
+    batch_size: 3
+    num_workers: 1
+    train:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset_subset_relighted/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: false
+        caption_path: captions.json
+    validation:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset_subset_relighted/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: true
+        caption_path: captions.json
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 1000
+      save_top_k: 1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 1000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
+
configs/zero123plus-refine_finetune_single_light.yaml ADDED
@@ -0,0 +1,56 @@
+model:
+  base_learning_rate: 1.0e-05
+  target: zero123plus.model.MVDiffusionRefinement
+  params:
+    drop_cond_prob: 0.1
+    refinement: true
+    stable_diffusion_config:
+      pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2
+      custom_pipeline: ./zero123plus
+
+
+
+data:
+  target: src.data.objaverse_zero123plus.DataModuleFromConfig
+  params:
+    batch_size: 3
+    num_workers: 1
+    train:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: false
+        caption_path: captions.json
+        split_path: dataset_splits_fixed.json
+        single_view: false
+        single_light: true
+    validation:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: true
+        caption_path: captions.json
+        split_path: dataset_splits_fixed.json
+        single_view: false
+        single_light: true
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 10000
+      save_top_k: 1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 10000000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
+
configs/zero123plus-refine_finetune_single_view.yaml ADDED
@@ -0,0 +1,55 @@
+model:
+  base_learning_rate: 1.0e-05
+  target: zero123plus.model.MVDiffusionRefinement
+  params:
+    drop_cond_prob: 0.1
+    refinement: true
+    stable_diffusion_config:
+      pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2
+      custom_pipeline: ./zero123plus
+
+
+
+data:
+  target: src.data.objaverse_zero123plus.DataModuleFromConfig
+  params:
+    batch_size: 18
+    num_workers: 1
+    train:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: false
+        caption_path: captions.json
+        split_path: dataset_splits_fixed.json
+        single_view: true
+    validation:
+      target: src.data.objaverse_zero123plus.RefinementData
+      params:
+        root_dir: refinement_dataset/
+        gt_subpath: gt
+        pred_subpath: shap_e
+        validation: true
+        caption_path: captions.json
+        split_path: dataset_splits_fixed.json
+        single_view: true
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 10000
+      save_top_k: 1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 10000000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
+
+
src/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/__init__.cpython-310.pyc and b/src/__pycache__/__init__.cpython-310.pyc differ
 
src/models/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/models/__pycache__/__init__.cpython-310.pyc and b/src/models/__pycache__/__init__.cpython-310.pyc differ
 
src/models/__pycache__/lrm.cpython-310.pyc CHANGED
Binary files a/src/models/__pycache__/lrm.cpython-310.pyc and b/src/models/__pycache__/lrm.cpython-310.pyc differ
 
src/models/decoder/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/models/decoder/__pycache__/__init__.cpython-310.pyc and b/src/models/decoder/__pycache__/__init__.cpython-310.pyc differ
 
src/models/decoder/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/src/models/decoder/__pycache__/transformer.cpython-310.pyc and b/src/models/decoder/__pycache__/transformer.cpython-310.pyc differ
 
src/models/encoder/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/models/encoder/__pycache__/__init__.cpython-310.pyc and b/src/models/encoder/__pycache__/__init__.cpython-310.pyc differ
 
src/models/encoder/__pycache__/dino.cpython-310.pyc CHANGED
Binary files a/src/models/encoder/__pycache__/dino.cpython-310.pyc and b/src/models/encoder/__pycache__/dino.cpython-310.pyc differ
 
src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc CHANGED
Binary files a/src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc and b/src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc differ
 
src/models/renderer/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/models/renderer/__pycache__/__init__.cpython-310.pyc and b/src/models/renderer/__pycache__/__init__.cpython-310.pyc differ
 
src/models/renderer/__pycache__/synthesizer.cpython-310.pyc CHANGED
Binary files a/src/models/renderer/__pycache__/synthesizer.cpython-310.pyc and b/src/models/renderer/__pycache__/synthesizer.cpython-310.pyc differ
 
src/models/renderer/utils/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/models/renderer/utils/__pycache__/__init__.cpython-310.pyc and b/src/models/renderer/utils/__pycache__/__init__.cpython-310.pyc differ
 
src/models/renderer/utils/__pycache__/math_utils.cpython-310.pyc CHANGED
Binary files a/src/models/renderer/utils/__pycache__/math_utils.cpython-310.pyc and b/src/models/renderer/utils/__pycache__/math_utils.cpython-310.pyc differ
 
src/models/renderer/utils/__pycache__/ray_marcher.cpython-310.pyc CHANGED
Binary files a/src/models/renderer/utils/__pycache__/ray_marcher.cpython-310.pyc and b/src/models/renderer/utils/__pycache__/ray_marcher.cpython-310.pyc differ
 
src/models/renderer/utils/__pycache__/ray_sampler.cpython-310.pyc CHANGED
Binary files a/src/models/renderer/utils/__pycache__/ray_sampler.cpython-310.pyc and b/src/models/renderer/utils/__pycache__/ray_sampler.cpython-310.pyc differ
 
src/models/renderer/utils/__pycache__/renderer.cpython-310.pyc CHANGED
Binary files a/src/models/renderer/utils/__pycache__/renderer.cpython-310.pyc and b/src/models/renderer/utils/__pycache__/renderer.cpython-310.pyc differ
 
src/utils/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/utils/__pycache__/__init__.cpython-310.pyc and b/src/utils/__pycache__/__init__.cpython-310.pyc differ
 
src/utils/__pycache__/camera_util.cpython-310.pyc CHANGED
Binary files a/src/utils/__pycache__/camera_util.cpython-310.pyc and b/src/utils/__pycache__/camera_util.cpython-310.pyc differ
 
src/utils/__pycache__/infer_util.cpython-310.pyc CHANGED
Binary files a/src/utils/__pycache__/infer_util.cpython-310.pyc and b/src/utils/__pycache__/infer_util.cpython-310.pyc differ
 
src/utils/__pycache__/mesh_util.cpython-310.pyc CHANGED
Binary files a/src/utils/__pycache__/mesh_util.cpython-310.pyc and b/src/utils/__pycache__/mesh_util.cpython-310.pyc differ
 
src/utils/__pycache__/train_util.cpython-310.pyc CHANGED
Binary files a/src/utils/__pycache__/train_util.cpython-310.pyc and b/src/utils/__pycache__/train_util.cpython-310.pyc differ
 
zero123plus/__pycache__/model.cpython-310.pyc ADDED
Binary file (16.5 kB)
 
zero123plus/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (22.7 kB)
 
zero123plus/model.py ADDED
@@ -0,0 +1,547 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from tqdm import tqdm
+from torchvision.transforms import v2
+from torchvision.utils import make_grid, save_image
+from einops import rearrange
+
+from src.utils.train_util import instantiate_from_config
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel
+from .pipeline import RefOnlyNoisedUNet
+
+
+def scale_latents(latents):
+    latents = (latents - 0.22) * 0.75
+    return latents
+
+
+def unscale_latents(latents):
+    latents = latents / 0.75 + 0.22
+    return latents
+
+
+def scale_image(image):
+    image = image * 0.5 / 0.8
+    return image
+
+
+def unscale_image(image):
+    image = image / 0.5 * 0.8
+    return image
+
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+class MVDiffusionRefinement(pl.LightningModule):
+    def __init__(
+        self,
+        stable_diffusion_config,
+        refinement,
+        drop_cond_prob=0.1,
+    ):
+        super(MVDiffusionRefinement, self).__init__()
+
+        self.drop_cond_prob = drop_cond_prob
+        self.refinement = refinement
+        self.register_schedule()
+
+        # init modules
+
+        pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config, low_cpu_mem_usage=False)
+
+        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+            pipeline.scheduler.config, timestep_spacing='trailing'
+        )
+
+        self.pipeline = pipeline
+        if refinement:
+            from huggingface_hub import hf_hub_download
+            unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+            state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+            self.pipeline.unet.load_state_dict(state_dict, strict=False)
+            pipeline.unet.load_state_dict(state_dict, strict=False)
+        train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config)
+
+        in_channels = 8
+        out_channels = self.pipeline.unet.conv_in.out_channels
+        self.pipeline.unet.register_to_config(in_channels=in_channels)
+        with torch.no_grad():
+            new_conv_in = nn.Conv2d(
+                in_channels, out_channels, self.pipeline.unet.conv_in.kernel_size, self.pipeline.unet.conv_in.stride, self.pipeline.unet.conv_in.padding
+            )
+            new_conv_in.weight.zero_()
+            new_conv_in.weight[:, :4, :, :].copy_(self.pipeline.unet.conv_in.weight)
+            self.pipeline.unet.conv_in = new_conv_in
+
+        if isinstance(self.pipeline.unet, UNet2DConditionModel):
+            self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler)
+
+
+        self.train_scheduler = train_sched  # use ddpm scheduler during training
+
+        self.unet = pipeline.unet
+
+        # validation output buffer
+        self.validation_step_outputs = []
+        with torch.no_grad():
+            self.cond_latents_zero = self.encode_condition_image(torch.zeros(1, 3, 320, 320)).to(self.device)
+            self.prompt_latents_zero = self.pipeline._encode_prompt([""], self.device, 1, False)
+
+
+    def register_schedule(self):
+        self.num_timesteps = 1000
+
+        # replace scaled_linear schedule with linear schedule as Zero123++
+        beta_start = 0.00085
+        beta_end = 0.0120
+        betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32)
+
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+        alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)
+
+        self.register_buffer('betas', betas.float())
+        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
+        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())
+
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float())
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float())
+
+    def on_fit_start(self):
+        device = torch.device(f'cuda:{self.global_rank}')
+        self.pipeline.to(device)
+        if self.global_rank == 0:
+            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
+            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
+
+    def prepare_batch_data(self, batch):
+        unrefined_imgs = batch['unrefined_imgs']  # (B, 6, C, H, W)
+        unrefined_imgs = v2.functional.resize(unrefined_imgs, 320, interpolation=3, antialias=True).clamp(0, 1)
+        unrefined_imgs = rearrange(unrefined_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2)  # (B, C, 3H, 2W)
+        unrefined_imgs = unrefined_imgs.to(self.device)
+
+        target_imgs = batch['refined_imgs']  # (B, 6, C, H, W)
+        target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1)
+        target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2)  # (B, C, 3H, 2W)
+        target_imgs = target_imgs.to(self.device)
+        return unrefined_imgs, target_imgs
+
+
+    @torch.no_grad()
+    def forward_vision_encoder(self, images):
+        dtype = next(self.pipeline.vision_encoder.parameters()).dtype
+        image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
+        image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values
+        image_pt = image_pt.to(device=self.device, dtype=dtype)
+        global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds
+        global_embeds = global_embeds.unsqueeze(-2)
+
+        encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0]
+        ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1)
+        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
+
+        return encoder_hidden_states
+
+    @torch.no_grad()
+    def encode_condition_image(self, images):
+        dtype = next(self.pipeline.vae.parameters()).dtype
+        image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
+        image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values
+        image_pt = image_pt.to(device=self.device, dtype=dtype)
+        latents = self.pipeline.vae.encode(image_pt).latent_dist.sample()
+        return latents
+
+    @torch.no_grad()
+    def encode_target_images(self, images):
+        dtype = next(self.pipeline.vae.parameters()).dtype
+        # equals to scaling images to [-1, 1] first and then call scale_image
+
+        images = (images - 0.5) / 0.8  # [-0.625, 0.625]
+        posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist
+        latents = posterior.sample() * self.pipeline.vae.config.scaling_factor
+        latents = scale_latents(latents)
+        return latents
+
+    def forward_unet(self, latents, t, prompt_embeds, cond_latents, cross_attention_kwargs=None):
+        dtype = next(self.pipeline.unet.parameters()).dtype
+        latents = latents.to(dtype)
+        prompt_embeds = prompt_embeds.to(dtype)
+        cond_latents = cond_latents.to(dtype)
+        if cross_attention_kwargs is None:
+            cross_attention_kwargs = dict()
+        cross_attention_kwargs.update(cond_lat=cond_latents)
+        # cross_attention_kwargs = dict(cond_lat=cond_latents)
+
+        pred_noise = self.pipeline.unet(
+            latents,
+            t,
+            encoder_hidden_states=prompt_embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
+            return_dict=False,
+        )[0]
+        return pred_noise
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def get_v(self, x, noise, t):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def decode_latents(self, latents_pred):
+        latents = unscale_latents(latents_pred)
+        images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0])  # [-1, 1]
+        images = (images * 0.5 + 0.5).clamp(0, 1)
+        return images
+
+    def training_step(self, batch, batch_idx):
+        # get input
+        latents_source, latents_target = batch['unrefined_imgs'], batch['refined_imgs']
+        captions = batch['caption']
+        # sample random timestep
+        B = latents_source.shape[0]
+
+        t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device)
+
+        # classifier-free guidance
+        if np.random.rand() < self.drop_cond_prob:
+            prompt_embeds = self.prompt_latents_zero.to(self.device).expand(B, -1, -1)
+        else:
+            prompt_embeds = self.pipeline._encode_prompt(captions, self.device, 1, False)
+        cond_latents = self.cond_latents_zero.to(self.device)
+
+        # with torch.no_grad():
+        #     latents_source = self.pipeline.vae.encode(source_imgs).latent_dist.mode()
+        noise = torch.randn_like(latents_target)
+        latents_noisy = self.train_scheduler.add_noise(latents_target, noise, t)
+        latents_noisy_unet = torch.cat([latents_noisy, latents_source], dim=1)
+        cak = dict(dont_forward_cond_state=True)
+        v_pred = self.forward_unet(latents_noisy_unet, t, prompt_embeds, cond_latents, cross_attention_kwargs=cak)
+        v_target = self.get_v(latents_target, noise, t)
+
+        loss, loss_dict = self.compute_loss(v_pred, v_target)
+
+        # logging
+        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+        lr = self.optimizers().param_groups[0]['lr']
+        self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        if self.global_step % 5000000 == 0 and self.global_rank == 0:
+            with torch.no_grad():
+                latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred)
+                images = self.decode_latents(latents_pred)
+                target_imgs = self.decode_latents(latents_target)
+
+                images = torch.cat([target_imgs, images], dim=-2)
+
+                grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1))
+                save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
+
+        return loss
+
+    def compute_loss(self, noise_pred, noise_gt):
+        loss = F.mse_loss(noise_pred, noise_gt)
+
+        prefix = 'train'
+        loss_dict = {}
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        # get input
+        latents_source, latents_target = batch['unrefined_imgs'], batch['refined_imgs']
+        prompts = batch['caption']
+        source_imgs = self.decode_latents(latents_source)
+        target_imgs = self.decode_latents(latents_target)
+
+        images_pil = [v2.functional.to_pil_image(source_imgs[i]) for i in range(source_imgs.shape[0])]
+
+        outputs = []
+        for source_img, prompt in zip(images_pil, prompts):
+            latent = self.pipeline.refine(source_img, prompt=prompt, num_inference_steps=75, output_type='latent').images
+            image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0])  # [-1, 1]
+            image = (image * 0.5 + 0.5).clamp(0, 1)
+            outputs.append(image)
+        outputs = torch.cat(outputs, dim=0).to(self.device)
+        images = torch.cat([target_imgs, outputs, source_imgs], dim=-2)
+
+        self.validation_step_outputs.append(images)
+
+    @torch.no_grad()
+    def on_validation_epoch_end(self):
+        images = torch.cat(self.validation_step_outputs, dim=0)
+        all_images = self.all_gather(images)
+        # all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
+        imgs = all_images.chunk(all_images.shape[0], dim=0)
+
+        if self.global_rank == 0:
+            os.makedirs(os.path.join(self.logdir, 'images_val', f'{self.global_step:07d}'), exist_ok=True)
+            grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1))
+            save_image(grid, os.path.join(self.logdir, 'images_val', f'{self.global_step:07d}', 'all.png'))
+            for idx, img in enumerate(imgs):
+                target, output, source = img.chunk(3, dim=-2)
+                img = torch.cat([source, target, output], dim=-1)
+                save_image(img, os.path.join(self.logdir, 'images_val', f'{self.global_step:07d}', f'comparison_img_{idx}.png'))
+                source_outputs = torch.cat([source, output], dim=-1)
+                save_image(source_outputs, os.path.join(self.logdir, 'images_val', f'{self.global_step:07d}', f'comparison_source_output_img_{idx}.png'))
+        self.validation_step_outputs.clear()  # free memory
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+
+        optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
+
+        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
+
+class MVDiffusion(pl.LightningModule):
+    def __init__(
+        self,
+        stable_diffusion_config,
+        drop_cond_prob=0.2,
+    ):
+        super(MVDiffusion, self).__init__()
+
+        self.drop_cond_prob = drop_cond_prob
+        self.register_schedule()
+
+        # init modules
+
+        pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config)
+        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+            pipeline.scheduler.config, timestep_spacing='trailing'
+        )
+
+        self.pipeline = pipeline
+        train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config)
+        if isinstance(self.pipeline.unet, UNet2DConditionModel):
+            self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler)
+
+
+        self.train_scheduler = train_sched  # use ddpm scheduler during training
+
+        self.unet = pipeline.unet
+
+        # validation output buffer
+        self.validation_step_outputs = []
+
+    def register_schedule(self):
+        self.num_timesteps = 1000
+
+        # replace scaled_linear schedule with linear schedule as Zero123++
+        beta_start = 0.00085
+        beta_end = 0.0120
+        betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32)
+
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+        alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)
+
+        self.register_buffer('betas', betas.float())
+        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
+        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())
+
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float())
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float())
+
+    def on_fit_start(self):
+        device = torch.device(f'cuda:{self.global_rank}')
+        self.pipeline.to(device)
+        if self.global_rank == 0:
+            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
+            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
+
+
+    def prepare_batch_data(self, batch):
+        cond_imgs = batch['cond_imgs']  # (B, C, H, W)
+        cond_imgs = cond_imgs.to(self.device)
+
+        # random resize the condition image
+        cond_size = np.random.randint(128, 513)
+        cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1)
+
+        target_imgs = batch['target_imgs']  # (B, 6, C, H, W)
+        target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1)
+        target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2)  # (B, C, 3H, 2W)
+        target_imgs = target_imgs.to(self.device)
+
+        return cond_imgs, target_imgs
+
+
+    @torch.no_grad()
+    def forward_vision_encoder(self, images):
+        dtype = next(self.pipeline.vision_encoder.parameters()).dtype
+        image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
+        image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values
+        image_pt = image_pt.to(device=self.device, dtype=dtype)
+        global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds
+        global_embeds = global_embeds.unsqueeze(-2)
+
+        encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0]
+        ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1)
+        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
+
+        return encoder_hidden_states
+
+    @torch.no_grad()
+    def encode_condition_image(self, images):
+        dtype = next(self.pipeline.vae.parameters()).dtype
+        image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
+        image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values
+        image_pt = image_pt.to(device=self.device, dtype=dtype)
+        latents = self.pipeline.vae.encode(image_pt).latent_dist.sample()
+        return latents
+
+    @torch.no_grad()
+    def encode_target_images(self, images):
+        dtype = next(self.pipeline.vae.parameters()).dtype
+        # equals to scaling images to [-1, 1] first and then call scale_image
+        images = (images - 0.5) / 0.8  # [-0.625, 0.625]
+        posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist
+        latents = posterior.sample() * self.pipeline.vae.config.scaling_factor
+        latents = scale_latents(latents)
+        return latents
+
+    def forward_unet(self, latents, t, prompt_embeds, cond_latents):
+        dtype = next(self.pipeline.unet.parameters()).dtype
+        latents = latents.to(dtype)
+        prompt_embeds = prompt_embeds.to(dtype)
+        cond_latents = cond_latents.to(dtype)
+        cross_attention_kwargs = dict(cond_lat=cond_latents)
+        pred_noise = self.pipeline.unet(
+            latents,
+            t,
+            encoder_hidden_states=prompt_embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
+            return_dict=False,
+        )[0]
+        return pred_noise
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def get_v(self, x, noise, t):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def training_step(self, batch, batch_idx):
+        # get input
+        cond_imgs, target_imgs = self.prepare_batch_data(batch)
+
+        # sample random timestep
+        B = cond_imgs.shape[0]
+
+        t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device)
+
+        # classifier-free guidance
+        if np.random.rand() < self.drop_cond_prob:
+            prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False)
+            cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs))
+        else:
+            prompt_embeds = self.forward_vision_encoder(cond_imgs)
+            cond_latents = self.encode_condition_image(cond_imgs)
+
+        latents = self.encode_target_images(target_imgs)
+        noise = torch.randn_like(latents)
+        latents_noisy = self.train_scheduler.add_noise(latents, noise, t)
+
+        v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents)
+        v_target = self.get_v(latents, noise, t)
+
+        loss, loss_dict = self.compute_loss(v_pred, v_target)
+
+        # logging
+        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+        lr = self.optimizers().param_groups[0]['lr']
+        self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        if self.global_step % 50 == 0 and self.global_rank == 0:
+            with torch.no_grad():
+                latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred)
+
+                latents = unscale_latents(latents_pred)
+                images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0])  # [-1, 1]
+                images = (images * 0.5 + 0.5).clamp(0, 1)
+                images = torch.cat([target_imgs, images], dim=-2)
+
+                grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1))
+                save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
+
+        return loss
+
+    def compute_loss(self, noise_pred, noise_gt):
+        loss = F.mse_loss(noise_pred, noise_gt)
+
+        prefix = 'train'
+        loss_dict = {}
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        # get input
+        cond_imgs, target_imgs = self.prepare_batch_data(batch)
+
+        images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])]
+
+        outputs = []
+        for cond_img in images_pil:
+            latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images
+            image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0])  # [-1, 1]
+            image = (image * 0.5 + 0.5).clamp(0, 1)
+            outputs.append(image)
+        outputs = torch.cat(outputs, dim=0).to(self.device)
+        images = torch.cat([target_imgs, outputs], dim=-2)
+
+        self.validation_step_outputs.append(images)
+
+    @torch.no_grad()
+    def on_validation_epoch_end(self):
+        images = torch.cat(self.validation_step_outputs, dim=0)
+
+        all_images = self.all_gather(images)
+        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
+
+        if self.global_rank == 0:
+            grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1))
+            save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png'))
+
+        self.validation_step_outputs.clear()  # free memory
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+
+        optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
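Note on the schedule math used by get_v() and predict_start_from_z_and_v() above: with x_t = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise, the v-prediction target v = sqrt(alpha_cumprod_t) * noise - sqrt(1 - alpha_cumprod_t) * x0 lets x0 be recovered exactly from (x_t, v). A small self-contained check (illustrative only, not part of the commit):

import torch

# linear beta schedule, as in register_schedule()
betas = torch.linspace(0.00085, 0.0120, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_ac, sqrt_omac = alphas_cumprod.sqrt(), (1.0 - alphas_cumprod).sqrt()

def bcast(a, t, shape):                        # same role as extract_into_tensor()
    return a.gather(-1, t).reshape(-1, *([1] * (len(shape) - 1)))

x0 = torch.randn(2, 4, 120, 80)                # clean latents (shape is illustrative)
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (2,))

x_t = bcast(sqrt_ac, t, x0.shape) * x0 + bcast(sqrt_omac, t, x0.shape) * noise
v = bcast(sqrt_ac, t, x0.shape) * noise - bcast(sqrt_omac, t, x0.shape) * x0        # get_v()
x0_rec = bcast(sqrt_ac, t, x_t.shape) * x_t - bcast(sqrt_omac, t, x_t.shape) * v    # predict_start_from_z_and_v()
assert torch.allclose(x0_rec, x0, atol=1e-4)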
zero123plus/pipeline.py ADDED
@@ -0,0 +1,1125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
3
+ from diffusers.schedulers import KarrasDiffusionSchedulers
4
+
5
+ import numpy
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.utils.checkpoint
9
+ import torch.distributed
10
+ import transformers
11
+ from collections import OrderedDict
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
15
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
16
+ from diffusers.utils import randn_tensor
17
+ import diffusers
18
+ from diffusers import (
19
+ AutoencoderKL,
20
+ DDPMScheduler,
21
+ DiffusionPipeline,
22
+ EulerAncestralDiscreteScheduler,
23
+ UNet2DConditionModel,
24
+ ImagePipelineOutput,
25
+ )
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.models.attention_processor import (
28
+ Attention,
29
+ AttnProcessor,
30
+ XFormersAttnProcessor,
31
+ AttnProcessor2_0,
32
+ )
33
+ from diffusers.utils.import_utils import is_xformers_available
34
+
35
+
36
+ def extract_into_tensor(a, t, x_shape):
37
+ b, *_ = t.shape
38
+ out = a.gather(-1, t)
39
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
40
+
41
+
42
+ def to_rgb_image(maybe_rgba: Image.Image):
43
+ if maybe_rgba.mode == "RGB":
44
+ return maybe_rgba
45
+ elif maybe_rgba.mode == "RGBA":
46
+ rgba = maybe_rgba
47
+ img = numpy.random.randint(
48
+ 255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8
49
+ )
50
+ img = Image.fromarray(img, "RGB")
51
+ img.paste(rgba, mask=rgba.getchannel("A"))
52
+ return img
53
+ else:
54
+ raise ValueError("Unsupported image type.", maybe_rgba.mode)
55
+
56
+
57
+ class ReferenceOnlyAttnProc(torch.nn.Module):
58
+ def __init__(self, chained_proc, enabled=False, name=None) -> None:
59
+ super().__init__()
60
+ self.enabled = enabled
61
+ self.chained_proc = chained_proc
62
+ self.name = name
63
+
64
+ def __call__(
65
+ self,
66
+ attn: Attention,
67
+ hidden_states,
68
+ encoder_hidden_states=None,
69
+ attention_mask=None,
70
+ mode="w",
71
+ ref_dict: dict = None,
72
+ is_cfg_guidance=False,
73
+ ) -> Any:
74
+ if encoder_hidden_states is None:
75
+ encoder_hidden_states = hidden_states
76
+ if self.enabled and is_cfg_guidance:
77
+ res0 = self.chained_proc(
78
+ attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask
79
+ )
80
+ hidden_states = hidden_states[1:]
81
+ encoder_hidden_states = encoder_hidden_states[1:]
82
+ if self.enabled:
83
+ if mode == "w":
84
+ ref_dict[self.name] = encoder_hidden_states
85
+ elif mode == "r":
86
+ encoder_hidden_states = torch.cat(
87
+ [encoder_hidden_states, ref_dict.pop(self.name)], dim=1
88
+ )
89
+ elif mode == "m":
90
+ encoder_hidden_states = torch.cat(
91
+ [encoder_hidden_states, ref_dict[self.name]], dim=1
92
+ )
93
+ elif mode == "c":
94
+ encoder_hidden_states = torch.cat(
95
+ [encoder_hidden_states, encoder_hidden_states], dim=1
96
+ )
97
+ else:
98
+ assert False, mode
99
+ res = self.chained_proc(
100
+ attn, hidden_states, encoder_hidden_states, attention_mask
101
+ )
102
+ if self.enabled and is_cfg_guidance:
103
+ res = torch.cat([res0, res])
104
+ return res
105
+
106
+
107
+ class RefOnlyNoisedUNet(torch.nn.Module):
108
+ def __init__(
109
+ self,
110
+ unet: UNet2DConditionModel,
111
+ train_sched: DDPMScheduler,
112
+ val_sched: EulerAncestralDiscreteScheduler,
113
+ ) -> None:
114
+ super().__init__()
115
+ self.unet = unet
116
+ self.train_sched = train_sched
117
+ self.val_sched = val_sched
118
+
119
+ unet_lora_attn_procs = dict()
120
+ for name, _ in unet.attn_processors.items():
121
+ if torch.__version__ >= "2.0":
122
+ default_attn_proc = AttnProcessor2_0()
123
+ elif is_xformers_available():
124
+ default_attn_proc = XFormersAttnProcessor()
125
+ else:
126
+ default_attn_proc = AttnProcessor()
127
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
128
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
129
+ )
130
+ unet.set_attn_processor(unet_lora_attn_procs)
131
+
132
+ def __getattr__(self, name: str):
133
+ try:
134
+ return super().__getattr__(name)
135
+ except AttributeError:
136
+ return getattr(self.unet, name)
137
+
138
+ def forward_cond(
139
+ self,
140
+ noisy_cond_lat,
141
+ timestep,
142
+ encoder_hidden_states,
143
+ class_labels,
144
+ ref_dict,
145
+ is_cfg_guidance,
146
+ **kwargs,
147
+ ):
148
+ if is_cfg_guidance:
149
+ encoder_hidden_states = encoder_hidden_states[1:]
150
+ class_labels = class_labels[1:]
151
+ self.unet(
152
+ noisy_cond_lat,
153
+ timestep,
154
+ encoder_hidden_states=encoder_hidden_states,
155
+ class_labels=class_labels,
156
+ cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
157
+ **kwargs,
158
+ )
159
+
160
+ def forward(
161
+ self,
162
+ sample,
163
+ timestep,
164
+ encoder_hidden_states,
165
+ class_labels=None,
166
+ *args,
167
+ cross_attention_kwargs,
168
+ down_block_res_samples=None,
169
+ mid_block_res_sample=None,
170
+ forward_cond_state=True,
171
+ **kwargs,
172
+ ):
173
+ cond_lat = cross_attention_kwargs["cond_lat"]
174
+ is_cfg_guidance = cross_attention_kwargs.get("is_cfg_guidance", False)
175
+ noise = torch.randn_like(cond_lat)
176
+ if self.training:
177
+ noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
178
+ noisy_cond_lat = self.train_sched.scale_model_input(
179
+ noisy_cond_lat, timestep
180
+ )
181
+ else:
182
+ noisy_cond_lat = self.val_sched.add_noise(
183
+ cond_lat, noise, timestep.reshape(-1)
184
+ )
185
+ noisy_cond_lat = self.val_sched.scale_model_input(
186
+ noisy_cond_lat, timestep.reshape(-1)
187
+ )
188
+ ref_dict = {}
189
+ if "dont_forward_cond_state" not in cross_attention_kwargs.keys():
190
+ self.forward_cond(
191
+ noisy_cond_lat,
192
+ timestep,
193
+ encoder_hidden_states,
194
+ class_labels,
195
+ ref_dict,
196
+ is_cfg_guidance,
197
+ **kwargs,
198
+ )
199
+ mode = "r"
200
+ else:
201
+ mode = "c"
202
+ weight_dtype = self.unet.dtype
203
+ return self.unet(
204
+ sample,
205
+ timestep,
206
+ encoder_hidden_states,
207
+ *args,
208
+ class_labels=class_labels,
209
+ cross_attention_kwargs=dict(
210
+ mode=mode, ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance
211
+ ),
212
+ down_block_additional_residuals=[
213
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
214
+ ]
215
+ if down_block_res_samples is not None
216
+ else None,
217
+ mid_block_additional_residual=(
218
+ mid_block_res_sample.to(dtype=weight_dtype)
219
+ if mid_block_res_sample is not None
220
+ else None
221
+ ),
222
+ **kwargs,
223
+ )
224
+
225
+
226
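+ # Latent/image (un)scaling helpers: latents are shifted by 0.22 and scaled by
+ # 0.75, images rescaled by a factor of 0.5 / 0.8.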
+ def scale_latents(latents):
227
+ latents = (latents - 0.22) * 0.75
228
+ return latents
229
+
230
+
231
+ def unscale_latents(latents):
232
+ latents = latents / 0.75 + 0.22
233
+ return latents
234
+
235
+
236
+ def scale_image(image):
237
+ image = image * 0.5 / 0.8
238
+ return image
239
+
240
+
241
+ def unscale_image(image):
242
+ image = image / 0.5 * 0.8
243
+ return image
244
+
245
+
246
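+ # Adds a depth-conditioned ControlNet on top of the reference-only UNet; the
+ # depth map is passed via cross_attention_kwargs["control_depth"] and the
+ # resulting residuals are fed to the wrapped UNet.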
+ class DepthControlUNet(torch.nn.Module):
247
+ def __init__(
248
+ self,
249
+ unet: RefOnlyNoisedUNet,
250
+ controlnet: Optional[diffusers.ControlNetModel] = None,
251
+ conditioning_scale=1.0,
252
+ ) -> None:
253
+ super().__init__()
254
+ self.unet = unet
255
+ if controlnet is None:
256
+ self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
257
+ else:
258
+ self.controlnet = controlnet
259
+ DefaultAttnProc = AttnProcessor2_0
260
+ if is_xformers_available():
261
+ DefaultAttnProc = XFormersAttnProcessor
262
+ self.controlnet.set_attn_processor(DefaultAttnProc())
263
+ self.conditioning_scale = conditioning_scale
264
+
265
+ def __getattr__(self, name: str):
266
+ try:
267
+ return super().__getattr__(name)
268
+ except AttributeError:
269
+ return getattr(self.unet, name)
270
+
271
+ def forward(
272
+ self,
273
+ sample,
274
+ timestep,
275
+ encoder_hidden_states,
276
+ class_labels=None,
277
+ *args,
278
+ cross_attention_kwargs: dict,
279
+ **kwargs,
280
+ ):
281
+ cross_attention_kwargs = dict(cross_attention_kwargs)
282
+ control_depth = cross_attention_kwargs.pop("control_depth")
283
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
284
+ sample,
285
+ timestep,
286
+ encoder_hidden_states=encoder_hidden_states,
287
+ controlnet_cond=control_depth,
288
+ conditioning_scale=self.conditioning_scale,
289
+ return_dict=False,
290
+ )
291
+ return self.unet(
292
+ sample,
293
+ timestep,
294
+ encoder_hidden_states=encoder_hidden_states,
295
+ down_block_res_samples=down_block_res_samples,
296
+ mid_block_res_sample=mid_block_res_sample,
297
+ cross_attention_kwargs=cross_attention_kwargs,
298
+ )
299
+
300
+
301
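+ # Exposes a dict of processors as a ModuleList indexed by sorted keys, so the
+ # values are registered as proper submodules.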
+ class ModuleListDict(torch.nn.Module):
302
+ def __init__(self, procs: dict) -> None:
303
+ super().__init__()
304
+ self.keys = sorted(procs.keys())
305
+ self.values = torch.nn.ModuleList(procs[k] for k in self.keys)
306
+
307
+ def __getitem__(self, key):
308
+ return self.values[self.keys.index(key)]
309
+
310
+
311
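+ # Container whose state_dict hooks remap "layers.{i}" prefixes back to the
+ # original module names (e.g. "controlnet"), mirroring the naming convention of
+ # `unet.attn_processors` when saving and loading.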
+ class SuperNet(torch.nn.Module):
312
+ def __init__(self, state_dict: Dict[str, torch.Tensor]):
313
+ super().__init__()
314
+ state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
315
+ self.layers = torch.nn.ModuleList(state_dict.values())
316
+ self.mapping = dict(enumerate(state_dict.keys()))
317
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
318
+
319
+ # .processor for unet, .self_attn for text encoder
320
+ self.split_keys = [".processor", ".self_attn"]
321
+
322
+ # we add a hook to state_dict() and load_state_dict() so that the
323
+ # naming fits with `unet.attn_processors`
324
+ def map_to(module, state_dict, *args, **kwargs):
325
+ new_state_dict = {}
326
+ for key, value in state_dict.items():
327
+ num = int(key.split(".")[1]) # 0 is always "layers"
328
+ new_key = key.replace(f"layers.{num}", module.mapping[num])
329
+ new_state_dict[new_key] = value
330
+
331
+ return new_state_dict
332
+
333
+ def remap_key(key, state_dict):
334
+ for k in self.split_keys:
335
+ if k in key:
336
+ return key.split(k)[0] + k
337
+ return key.split(".")[0]
338
+
339
+ def map_from(module, state_dict, *args, **kwargs):
340
+ all_keys = list(state_dict.keys())
341
+ for key in all_keys:
342
+ replace_key = remap_key(key, state_dict)
343
+ new_key = key.replace(
344
+ replace_key, f"layers.{module.rev_mapping[replace_key]}"
345
+ )
346
+ state_dict[new_key] = state_dict[key]
347
+ del state_dict[key]
348
+
349
+ self._register_state_dict_hook(map_to)
350
+ self._register_load_state_dict_pre_hook(map_from, with_module=True)
351
+
352
+
353
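+ # Zero123++ pipeline extended with latent-editing entry points (edit_latents,
+ # sdedit, refine) on top of the standard reference-conditioned __call__.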
+ class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
354
+ tokenizer: transformers.CLIPTokenizer
355
+ text_encoder: transformers.CLIPTextModel
356
+ vision_encoder: transformers.CLIPVisionModelWithProjection
357
+
358
+ feature_extractor_clip: transformers.CLIPImageProcessor
359
+ unet: UNet2DConditionModel
360
+ scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
361
+
362
+ vae: AutoencoderKL
363
+ ramping: nn.Linear
364
+
365
+ feature_extractor_vae: transformers.CLIPImageProcessor
366
+
367
+ depth_transforms_multi = transforms.Compose(
368
+ [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
369
+ )
370
+
371
+ def __init__(
372
+ self,
373
+ vae: AutoencoderKL,
374
+ text_encoder: CLIPTextModel,
375
+ tokenizer: CLIPTokenizer,
376
+ unet: UNet2DConditionModel,
377
+ scheduler: KarrasDiffusionSchedulers,
378
+ vision_encoder: transformers.CLIPVisionModelWithProjection,
379
+ feature_extractor_clip: CLIPImageProcessor,
380
+ feature_extractor_vae: CLIPImageProcessor,
381
+ ramping_coefficients: Optional[list] = None,
382
+ safety_checker=None,
383
+ ):
384
+ DiffusionPipeline.__init__(self)
385
+
386
+ self.register_modules(
387
+ vae=vae,
388
+ text_encoder=text_encoder,
389
+ tokenizer=tokenizer,
390
+ unet=unet,
391
+ scheduler=scheduler,
392
+ safety_checker=None,
393
+ vision_encoder=vision_encoder,
394
+ feature_extractor_clip=feature_extractor_clip,
395
+ feature_extractor_vae=feature_extractor_vae,
396
+ )
397
+ self.register_to_config(ramping_coefficients=ramping_coefficients)
398
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
399
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
400
+
401
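+ # Wrap the raw UNet in RefOnlyNoisedUNet (only if it is still a plain
+ # UNet2DConditionModel, so calling prepare() twice is harmless).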
+ def prepare(self):
402
+ train_sched = DDPMScheduler.from_config(self.scheduler.config)
403
+ if isinstance(self.unet, UNet2DConditionModel):
404
+ self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()
405
+
406
+ def add_controlnet(
407
+ self,
408
+ controlnet: Optional[diffusers.ControlNetModel] = None,
409
+ conditioning_scale=1.0,
410
+ ):
411
+ self.prepare()
412
+ self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
413
+ return SuperNet(OrderedDict([("controlnet", self.unet.controlnet)]))
414
+
415
+ def encode_condition_image(self, image: torch.Tensor):
416
+ image = self.vae.encode(image).latent_dist.sample()
417
+ return image
418
+
419
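+ # Generate the multi-view grid as in __call__, but start the diffusion from the
+ # encoded multiview_source_image (passed as `latents`) and condition the
+ # reference branch on image_guidance; edit_strength and the remaining kwargs are
+ # forwarded to the underlying __call__.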
+ @torch.no_grad()
420
+ def edit_latents(
421
+ self,
422
+ image_guidance: Image.Image,
423
+ multiview_source_image: Image.Image = None,
424
+ edit_strength: float = 1.0,
425
+ prompt="",
426
+ *args,
427
+ guidance_scale=0.0,
428
+ output_type: Optional[str] = "pil",
429
+ width=640,
430
+ height=960,
431
+ num_inference_steps=28,
432
+ return_dict=True,
433
+ **kwargs,
434
+ ):
435
+ self.prepare()
436
+ if image_guidance is None:
437
+ raise ValueError(
438
+ "Inputting embeddings not supported for this pipeline. Please pass an image."
439
+ )
440
+ if multiview_source_image is None:
441
+ raise ValueError("Multiview source image is required for this pipeline.")
442
+ assert not isinstance(image_guidance, torch.Tensor)
443
+ assert not isinstance(multiview_source_image, torch.Tensor)
444
+ image_guidance = to_rgb_image(image_guidance)
445
+ image_source = to_rgb_image(multiview_source_image)
446
+ image_guidance_1 = self.feature_extractor_vae(
447
+ images=image_guidance, return_tensors="pt"
448
+ ).pixel_values
449
+ image_guidance_2 = self.feature_extractor_clip(
450
+ images=image_source, return_tensors="pt"
451
+ ).pixel_values
452
+ image_guidance = image_guidance_1.to(
453
+ device=self.vae.device, dtype=self.vae.dtype
454
+ )
455
+ image_guidance_2 = image_guidance_2.to(
456
+ device=self.vae.device, dtype=self.vae.dtype
457
+ )
458
+
459
+ cond_lat = self.encode_condition_image(image_guidance)
460
+ # if guidance_scale > 1:
461
+ negative_lat = self.encode_condition_image(torch.zeros_like(image_guidance))
462
+ cond_lat = torch.cat([negative_lat, cond_lat])
463
+ encoded = self.vision_encoder(image_guidance_2, output_hidden_states=False)
464
+
465
+ global_embeds = encoded.image_embeds
466
+ global_embeds = global_embeds.unsqueeze(-2)
467
+ if hasattr(self, "encode_prompt"):
468
+ encoder_hidden_states = self.encode_prompt(prompt, self.device, 1, False)[0]
469
+ else:
470
+ encoder_hidden_states = self._encode_prompt(prompt, self.device, 1, False)
471
+ ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
472
+ encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
473
+ cak = dict(cond_lat=cond_lat)
474
+ mv_image = (
475
+ torch.from_numpy(numpy.array(multiview_source_image)).to(self.vae.device)
476
+ / 255.0
477
+ )
478
+ mv_image = (
479
+ mv_image.permute(2, 0, 1)
480
+ .to(self.vae.device)
481
+ .to(self.vae.dtype)
482
+ .unsqueeze(0)
483
+ )
484
+ latents = (
485
+ self.vae.encode(mv_image * 2.0 - 1.0).latent_dist.sample()
486
+ * self.vae.config.scaling_factor
487
+ )
488
+ latents: torch.Tensor = (
489
+ super()
490
+ .__call__(
491
+ None,
492
+ *args,
493
+ cross_attention_kwargs=cak,
494
+ guidance_scale=guidance_scale,
495
+ num_images_per_prompt=1,
496
+ prompt_embeds=encoder_hidden_states,
497
+ num_inference_steps=num_inference_steps,
498
+ output_type="latent",
499
+ width=width,
500
+ height=height,
501
+ latents=latents,
502
+ edit_strength=edit_strength,
503
+ **kwargs,
504
+ )
505
+ .images
506
+ )
507
+ latents = unscale_latents(latents)
508
+ if not output_type == "latent":
509
+ image = unscale_image(
510
+ self.vae.decode(
511
+ latents / self.vae.config.scaling_factor, return_dict=False
512
+ )[0]
513
+ )
514
+ else:
515
+ image = latents
516
+
517
+ image = self.image_processor.postprocess(image, output_type=output_type)
518
+ if not return_dict:
519
+ return (image,)
520
+
521
+ return ImagePipelineOutput(images=image)
522
+
523
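+ # Encode a [0, 1] image batch into scaled VAE latents (see scale_latents above).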
+ @torch.no_grad()
524
+ def encode_target_images(self, images):
525
+ dtype = next(self.vae.parameters()).dtype
526
+ # equivalent to scaling images to [-1, 1] first and then calling scale_image
527
+ images = (images - 0.5) / 0.8 # [-0.625, 0.625]
528
+ posterior = self.vae.encode(images.to(dtype)).latent_dist
529
+ latents = posterior.sample() * self.vae.config.scaling_factor
530
+ latents = scale_latents(latents)
531
+ return latents
532
+
533
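+ # SDEdit-style editing: encode the input grid to latents, re-noise them up to an
+ # edit_strength fraction of the schedule, and denoise with forward_sdedit; the
+ # reference "write" pass is skipped via dont_forward_cond_state.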
+ @torch.no_grad()
534
+ def sdedit(
535
+ self,
536
+ image,
537
+ *args,
538
+ cond_image: Image.Image = None,
539
+ output_type: Optional[str] = "pil",
540
+ width=640,
541
+ height=960,
542
+ num_inference_steps=75,
543
+ edit_strength=1.0,
544
+ return_dict=True,
545
+ guidance_scale=0.0,
546
+ **kwargs,
547
+ ):
548
+ self.prepare()
549
+ if image is None:
550
+ raise ValueError(
551
+ "Inputting embeddings not supported for this pipeline. Please pass an image."
552
+ )
553
+ assert not isinstance(image, torch.Tensor)
554
+ image = to_rgb_image(image)
555
+
556
+ # cond_lat = self.encode_condition_image(image_guidance)
557
+ if hasattr(self, "encode_prompt"):
558
+ encoder_hidden_states = self.encode_prompt([""], self.device, 1, False)[0]
559
+ else:
560
+ encoder_hidden_states = self._encode_prompt([""], self.device, 1, False)
561
+ # negative_lat = self.encode_condition_image(torch.zeros_like(image_guidance))
562
+ # cond_lat = torch.cat([negative_lat, cond_lat])
563
+ # encoded = self.vision_encoder(image_guidance_2, output_hidden_states=False)
564
+
565
+ # global_embeds = encoded.image_embeds
566
+ # global_embeds = global_embeds.unsqueeze(-2)
567
+ # prompt = ""
568
+
569
+ # ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
570
+ # encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
571
+ # cak = dict(cond_lat=cond_lat)
572
+ image = torch.from_numpy(numpy.array(image)).to(self.vae.device) / 255.0
573
+ image = image.permute(2, 0, 1).unsqueeze(0)
574
+ if self.vae.dtype == torch.float16:
575
+ image = image.half()
576
+ # image = image.permute(2, 0, 1).to(self.vae.device).to(self.vae.dtype).unsqueeze(0)
577
+
578
+ latents = self.encode_target_images(image)
579
+ if cond_image is not None:
580
+ cond_image = to_rgb_image(cond_image)
581
+ cond_image = (
582
+ torch.from_numpy(numpy.array(cond_image)).to(self.vae.device) / 255.0
583
+ )
584
+ cond_image = cond_image.permute(2, 0, 1).unsqueeze(0)
585
+ if self.vae.dtype == torch.float16:
586
+ cond_image = cond_image.half()
587
+ cond_lat = self.encode_condition_image(cond_image)
588
+ else:
589
+ cond_lat = self.encode_condition_image(torch.zeros_like(image)).to(
590
+ self.vae.device
591
+ )
592
+ cak = dict(cond_lat=cond_lat, dont_forward_cond_state=True)
593
+ latents = self.forward_sdedit(
594
+ latents,
595
+ cross_attention_kwargs=cak,
596
+ guidance_scale=guidance_scale,
597
+ num_images_per_prompt=1,
598
+ prompt_embeds=encoder_hidden_states,
599
+ num_inference_steps=num_inference_steps,
600
+ output_type="latent",
601
+ width=width,
602
+ height=height,
603
+ edit_strength=edit_strength,
604
+ **kwargs,
605
+ ).images
606
+ # latents = unscale_latents(latents)
607
+ if not output_type == "latent":
608
+ image = unscale_image(
609
+ self.vae.decode(
610
+ latents / self.vae.config.scaling_factor, return_dict=False
611
+ )[0]
612
+ )
613
+ else:
614
+ image = latents
615
+
616
+ image = self.image_processor.postprocess(image, output_type=output_type)
617
+ if not return_dict:
618
+ return (image,)
619
+
620
+ return ImagePipelineOutput(images=image)
621
+
622
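+ # Refinement path: encode the input grid, then denoise with forward_pipeline,
+ # which concatenates the encoded grid as a conditioning latent along the channel
+ # dimension; optional edit_image latents initialize the noisy sample.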
+ @torch.no_grad()
623
+ def refine(
624
+ self,
625
+ image: Image.Image = None,
626
+ edit_image: Image.Image = None,
627
+ prompt: Optional[str] = "",
628
+ *args,
629
+ output_type: Optional[str] = "pil",
630
+ width=640,
631
+ height=960,
632
+ num_inference_steps=28,
633
+ edit_strength=1.0,
634
+ return_dict=True,
635
+ guidance_scale=4.0,
636
+ **kwargs,
637
+ ):
638
+ self.prepare()
639
+ if image is None:
640
+ raise ValueError(
641
+ "Inputting embeddings not supported for this pipeline. Please pass an image."
642
+ )
643
+ assert not isinstance(image, torch.Tensor)
644
+ image = to_rgb_image(image)
645
+
646
+ # cond_lat = self.encode_condition_image(image_guidance)
647
+ if hasattr(self, "encode_prompt"):
648
+ encoder_hidden_states = self.encode_prompt(prompt, self.device, 1, False)[0]
649
+ else:
650
+ encoder_hidden_states = self._encode_prompt(prompt, self.device, 1, False)
651
+ # negative_lat = self.encode_condition_image(torch.zeros_like(image_guidance))
652
+ # cond_lat = torch.cat([negative_lat, cond_lat])
653
+ # encoded = self.vision_encoder(image_guidance_2, output_hidden_states=False)
654
+
655
+ # global_embeds = encoded.image_embeds
656
+ # global_embeds = global_embeds.unsqueeze(-2)
657
+ # prompt = ""
658
+
659
+ # ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
660
+ # encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
661
+ # cak = dict(cond_lat=cond_lat)
662
+ latents_edit = None
663
+ if edit_image is not None:
664
+ edit_image = to_rgb_image(edit_image)
665
+ edit_image = (
666
+ torch.from_numpy(numpy.array(edit_image)).to(self.vae.device) / 255.0
667
+ )
668
+ edit_image = edit_image.permute(2, 0, 1).unsqueeze(0)
669
+ if self.vae.dtype == torch.float16:
670
+ edit_image = edit_image.half()
671
+ latents_edit = self.encode_target_images(edit_image)
672
+ image = torch.from_numpy(numpy.array(image)).to(self.vae.device) / 255.0
673
+ image = image.permute(2, 0, 1).unsqueeze(0)
674
+ if self.vae.dtype == torch.float16:
675
+ image = image.half()
676
+ # image = torch.nn.functional.interpolate(
677
+ # image, (height*4, width*4), mode="bilinear", align_corners=False)
678
+ # image = image[...,:320,:320]
679
+ height, width = image.shape[-2:]
680
+ # image = image[...,:640,:]
681
+ # image[...,:320,:] = torch.ones_like(image[...,:320,:])
682
+ # image = image.permute(2, 0, 1).to(self.vae.device).to(self.vae.dtype).unsqueeze(0)
683
+ # height = height * 4
684
+ # width = width * 4
685
+ latents = self.encode_target_images(image)
686
+ # latents[...,-40:,:] = torch.randn_like(latents[...,-40:,:])
687
+
688
+ cond_lat = self.encode_condition_image(torch.zeros_like(image)).to(
689
+ self.vae.device
690
+ )
691
+ cak = dict(cond_lat=cond_lat, dont_forward_cond_state=True)
692
+ latents = self.forward_pipeline(
693
+ latents_edit,
694
+ latents,
695
+ cross_attention_kwargs=cak,
696
+ guidance_scale=guidance_scale,
697
+ num_images_per_prompt=1,
698
+ prompt_embeds=encoder_hidden_states,
699
+ num_inference_steps=num_inference_steps,
700
+ output_type="latent",
701
+ width=width,
702
+ height=height,
703
+ edit_strength=edit_strength,
704
+ **kwargs,
705
+ ).images
706
+ # latents = unscale_latents(latents)
707
+ if not output_type == "latent":
708
+ image = unscale_image(
709
+ self.vae.decode(
710
+ latents / self.vae.config.scaling_factor, return_dict=False
711
+ )[0]
712
+ )
713
+ else:
714
+ image = latents
715
+
716
+ image = self.image_processor.postprocess(image, output_type=output_type)
717
+ if not return_dict:
718
+ return (image,)
719
+
720
+ return ImagePipelineOutput(images=image)
721
+
722
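+ # Same as the stock Stable Diffusion prepare_latents, except that when `latents`
+ # are provided they are noised to the given `timestep` instead of being used
+ # as-is (SDEdit-style initialization).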
+ def prepare_latents(
723
+ self,
724
+ batch_size,
725
+ num_channels_latents,
726
+ height,
727
+ width,
728
+ dtype,
729
+ device,
730
+ generator,
731
+ latents=None,
732
+ timestep=None,
733
+ ):
734
+ shape = (
735
+ batch_size,
736
+ num_channels_latents,
737
+ height // self.vae_scale_factor,
738
+ width // self.vae_scale_factor,
739
+ )
740
+ if isinstance(generator, list) and len(generator) != batch_size:
741
+ raise ValueError(
742
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
743
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
744
+ )
745
+
746
+ if latents is None:
747
+ latents = randn_tensor(
748
+ shape, generator=generator, device=device, dtype=dtype
749
+ )
750
+ # scale the initial noise by the standard deviation required by the scheduler
751
+ latents = latents * self.scheduler.init_noise_sigma
752
+
753
+ else:
754
+ if timestep is None:
755
+ raise ValueError(
756
+ "When passing `latents` you also need to pass `timestep`."
757
+ )
758
+ latents = latents.to(device)
759
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
760
+ # get latents
761
+ latents = self.scheduler.add_noise(latents, noise, timestep)
762
+
763
+ return latents
764
+
765
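+ # Denoising loop over a truncated schedule: `latents` are noised to the first
+ # kept timestep and then denoised step by step (no channel-wise conditioning).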
+ @torch.no_grad()
766
+ def forward_sdedit(
767
+ self,
768
+ latents: torch.Tensor,
769
+ cross_attention_kwargs: dict,
770
+ guidance_scale: float,
771
+ num_images_per_prompt: int,
772
+ prompt_embeds,
773
+ num_inference_steps: int,
774
+ output_type: str,
775
+ width: int,
776
+ height: int,
777
+ edit_strength: float = 1.0,
778
+ ):
779
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
780
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
781
+
782
+ batch_size = prompt_embeds.shape[0]
783
+ generator = torch.Generator(device=latents.device)
784
+ device = self._execution_device
785
+ # here `guidance_scale` is defined analogously to the guidance weight `w` in equation (2)
786
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
787
+ # corresponds to doing no classifier free guidance.
788
+ do_classifier_free_guidance = guidance_scale > 1.0
789
+
790
+ # 3. Encode input prompt
791
+ text_encoder_lora_scale = (
792
+ cross_attention_kwargs.get("scale", None)
793
+ if cross_attention_kwargs is not None
794
+ else None
795
+ )
796
+ prompt_embeds = self._encode_prompt(
797
+ None,
798
+ device,
799
+ num_images_per_prompt,
800
+ do_classifier_free_guidance,
801
+ None,
802
+ prompt_embeds=prompt_embeds,
803
+ negative_prompt_embeds=None,
804
+ lora_scale=text_encoder_lora_scale,
805
+ )
806
+ # 4. Prepare timesteps
807
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
808
809
+ timesteps = self.scheduler.timesteps
810
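+ # keep only the last `edit_strength` fraction of the schedule (the lower-noise
+ # steps), preserving the original descending order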
+ timesteps = reversed(reversed(timesteps)[: int(edit_strength * len(timesteps))])
811
+
812
+ # 5. Prepare latent variables
813
+ num_channels_latents = self.unet.config.in_channels
814
+
815
+ latents = self.prepare_latents(
816
+ batch_size * num_images_per_prompt,
817
+ num_channels_latents,
818
+ height,
819
+ width,
820
+ prompt_embeds.dtype,
821
+ device,
822
+ generator,
823
+ latents,
824
+ timesteps[0:1],
825
+ )
826
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
827
+ # if do_classifier_free_guidance:
828
+ # cond_latent = cond_latent.expand(batch_size * 2, -1, -1, -1)
829
+
830
+ # 7. Denoising loop
831
+ num_warmup_steps = 0
832
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
833
+ for i, t in enumerate(timesteps):
834
+ # expand the latents if we are doing classifier free guidance
835
+ latent_model_input = (
836
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
837
+ )
838
+ latent_model_input = self.scheduler.scale_model_input(
839
+ latent_model_input, t
840
+ )
841
842
+
843
+ # predict the noise residual
844
+ noise_pred = self.unet(
845
+ latent_model_input,
846
+ t,
847
+ encoder_hidden_states=prompt_embeds,
848
+ cross_attention_kwargs=cross_attention_kwargs,
849
+ return_dict=False,
850
+ )[0]
851
852
+
853
+ # perform guidance
854
+ if do_classifier_free_guidance:
855
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
856
+ noise_pred = noise_pred_uncond + guidance_scale * (
857
+ noise_pred_text - noise_pred_uncond
858
+ )
859
+
860
+ # compute the previous noisy sample x_t -> x_t-1
861
+ latents = self.scheduler.step(
862
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
863
+ )[0]
864
+
865
+ # call the callback, if provided
866
+ if i == len(timesteps) - 1 or (
867
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
868
+ ):
869
+ progress_bar.update()
870
+ latents = unscale_latents(latents)
871
+ if not output_type == "latent":
872
+ image = self.vae.decode(
873
+ latents / self.vae.config.scaling_factor, return_dict=False
874
+ )[0]
875
+ image, has_nsfw_concept = self.run_safety_checker(
876
+ image, device, prompt_embeds.dtype
877
+ )
878
+ else:
879
+ image = latents
880
+ has_nsfw_concept = None
881
+
882
+ if has_nsfw_concept is None:
883
+ do_denormalize = [True] * image.shape[0]
884
+ else:
885
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
886
+
887
+ image = self.image_processor.postprocess(
888
+ image, output_type=output_type, do_denormalize=do_denormalize
889
+ )
890
+
891
+ # Offload last model to CPU
892
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
893
+ self.final_offload_hook.offload()
894
+
895
+ return StableDiffusionPipelineOutput(
896
+ images=image, nsfw_content_detected=has_nsfw_concept
897
+ )
898
+
899
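+ # Like forward_sdedit, but `cond_latent` is concatenated to the model input along
+ # the channel dimension at every step (hence in_channels // 2 for the sample).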
+ @torch.no_grad()
900
+ def forward_pipeline(
901
+ self,
902
+ latents: torch.Tensor,
903
+ cond_latent: torch.Tensor,
904
+ cross_attention_kwargs: dict,
905
+ guidance_scale: float,
906
+ num_images_per_prompt: int,
907
+ prompt_embeds,
908
+ num_inference_steps: int,
909
+ output_type: str,
910
+ width: int,
911
+ height: int,
912
+ edit_strength: float = 1.0,
913
+ ):
914
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
915
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
916
+
917
+ batch_size = 1
918
+ generator = torch.Generator(device=cond_latent.device)
919
+ device = self._execution_device
920
+ # here `guidance_scale` is defined analogously to the guidance weight `w` in equation (2)
921
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
922
+ # corresponds to doing no classifier free guidance.
923
+ do_classifier_free_guidance = guidance_scale > 1.0
924
+
925
+ # 3. Encode input prompt
926
+ text_encoder_lora_scale = (
927
+ cross_attention_kwargs.get("scale", None)
928
+ if cross_attention_kwargs is not None
929
+ else None
930
+ )
931
+ prompt_embeds = self._encode_prompt(
932
+ None,
933
+ device,
934
+ num_images_per_prompt,
935
+ do_classifier_free_guidance,
936
+ None,
937
+ prompt_embeds=prompt_embeds,
938
+ negative_prompt_embeds=None,
939
+ lora_scale=text_encoder_lora_scale,
940
+ )
941
+ # 4. Prepare timesteps
942
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
943
944
+ timesteps = self.scheduler.timesteps
945
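+ # keep only the last `edit_strength` fraction of the schedule, preserving order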
+ timesteps = reversed(reversed(timesteps)[: int(edit_strength * len(timesteps))])
946
+
947
+ # 5. Prepare latent variables
948
+ num_channels_latents = self.unet.config.in_channels // 2
949
+
950
+ latents = self.prepare_latents(
951
+ batch_size * num_images_per_prompt,
952
+ num_channels_latents,
953
+ height,
954
+ width,
955
+ prompt_embeds.dtype,
956
+ device,
957
+ generator,
958
+ latents,
959
+ timesteps[0:1],
960
+ )
961
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
962
+ if do_classifier_free_guidance:
963
+ cond_latent = cond_latent.expand(batch_size * 2, -1, -1, -1)
964
+
965
+ # 7. Denoising loop
966
+ num_warmup_steps = 0
967
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
968
+ for i, t in enumerate(timesteps):
969
+ # expand the latents if we are doing classifier free guidance
970
+ latent_model_input = (
971
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
972
+ )
973
+ latent_model_input = self.scheduler.scale_model_input(
974
+ latent_model_input, t
975
+ )
976
+ latent_model_input = torch.cat([latent_model_input, cond_latent], dim=1)
977
+
978
+ # predict the noise residual
979
+ noise_pred = self.unet(
980
+ latent_model_input,
981
+ t,
982
+ encoder_hidden_states=prompt_embeds,
983
+ cross_attention_kwargs=cross_attention_kwargs,
984
+ return_dict=False,
985
+ )[0]
986
+
987
+ # perform guidance
988
+ if do_classifier_free_guidance:
989
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
990
+ noise_pred = noise_pred_uncond + guidance_scale * (
991
+ noise_pred_text - noise_pred_uncond
992
+ )
993
+
994
+ # compute the previous noisy sample x_t -> x_t-1
995
+ latents = self.scheduler.step(
996
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
997
+ )[0]
998
+
999
+ # call the callback, if provided
1000
+ if i == len(timesteps) - 1 or (
1001
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1002
+ ):
1003
+ progress_bar.update()
1004
+ latents = unscale_latents(latents)
1005
+ if not output_type == "latent":
1006
+ image = self.vae.decode(
1007
+ latents / self.vae.config.scaling_factor, return_dict=False
1008
+ )[0]
1009
+ image, has_nsfw_concept = self.run_safety_checker(
1010
+ image, device, prompt_embeds.dtype
1011
+ )
1012
+ else:
1013
+ image = latents
1014
+ has_nsfw_concept = None
1015
+
1016
+ if has_nsfw_concept is None:
1017
+ do_denormalize = [True] * image.shape[0]
1018
+ else:
1019
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1020
+
1021
+ image = self.image_processor.postprocess(
1022
+ image, output_type=output_type, do_denormalize=do_denormalize
1023
+ )
1024
+
1025
+ # Offload last model to CPU
1026
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1027
+ self.final_offload_hook.offload()
1028
+
1029
+ return StableDiffusionPipelineOutput(
1030
+ images=image, nsfw_content_detected=has_nsfw_concept
1031
+ )
1032
+
1033
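+ # Standard Zero123++ generation: the condition image is encoded twice, once with
+ # the VAE (reference latents for the attention "write" pass) and once with the
+ # CLIP vision encoder (global embeddings blended into the prompt embeddings via
+ # the ramping coefficients).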
+ @torch.no_grad()
1034
+ def __call__(
1035
+ self,
1036
+ image: Image.Image = None,
1037
+ source_image: Image.Image = None,
1038
+ prompt="",
1039
+ *args,
1040
+ num_images_per_prompt: Optional[int] = 1,
1041
+ guidance_scale=4.0,
1042
+ depth_image: Image.Image = None,
1043
+ output_type: Optional[str] = "pil",
1044
+ width=640,
1045
+ height=960,
1046
+ num_inference_steps=28,
1047
+ return_dict=True,
1048
+ **kwargs,
1049
+ ):
1050
+ self.prepare()
1051
+ if image is None:
1052
+ raise ValueError(
1053
+ "Inputting embeddings not supported for this pipeline. Please pass an image."
1054
+ )
1055
+ assert not isinstance(image, torch.Tensor)
1056
+ image = to_rgb_image(image)
1057
+ image_1 = self.feature_extractor_vae(
1058
+ images=image, return_tensors="pt"
1059
+ ).pixel_values
1060
+ image_2 = self.feature_extractor_clip(
1061
+ images=image, return_tensors="pt"
1062
+ ).pixel_values
1063
+ # image_source = to_rgb_image(source_image)
1064
+ # image_source_latents = self.feature_extractor_vae(images=image_source, return_tensors="pt")
1065
+ if depth_image is not None and hasattr(self.unet, "controlnet"):
1066
+ depth_image = to_rgb_image(depth_image)
1067
+ depth_image = self.depth_transforms_multi(depth_image).to(
1068
+ device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
1069
+ )
1070
+ image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
1071
+ image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
1072
+ cond_lat = self.encode_condition_image(image)
1073
+ if guidance_scale > 1:
1074
+ negative_lat = self.encode_condition_image(torch.zeros_like(image))
1075
+ cond_lat = torch.cat([negative_lat, cond_lat])
1076
+ encoded = self.vision_encoder(image_2, output_hidden_states=False)
1077
+ global_embeds = encoded.image_embeds
1078
+ global_embeds = global_embeds.unsqueeze(-2)
1079
+
1080
+ if hasattr(self, "encode_prompt"):
1081
+ encoder_hidden_states = self.encode_prompt(
1082
+ prompt, self.device, num_images_per_prompt, False
1083
+ )[0]
1084
+ else:
1085
+ encoder_hidden_states = self._encode_prompt(
1086
+ prompt, self.device, num_images_per_prompt, False
1087
+ )
1088
+ ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
1089
+ encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
1090
+ cak = dict(cond_lat=cond_lat)
1091
+ if hasattr(self.unet, "controlnet"):
1092
+ cak["control_depth"] = depth_image
1093
+ latents: torch.Tensor = (
1094
+ super()
1095
+ .__call__(
1096
+ None,
1097
+ *args,
1098
+ cross_attention_kwargs=cak,
1099
+ guidance_scale=guidance_scale,
1100
+ num_images_per_prompt=num_images_per_prompt,
1101
+ prompt_embeds=encoder_hidden_states,
1102
+ num_inference_steps=num_inference_steps,
1103
+ output_type="latent",
1104
+ width=width,
1105
+ height=height,
1106
+ latents=None,
1107
+ **kwargs,
1108
+ )
1109
+ .images
1110
+ )
1111
+ latents = unscale_latents(latents)
1112
+ if not output_type == "latent":
1113
+ image = unscale_image(
1114
+ self.vae.decode(
1115
+ latents / self.vae.config.scaling_factor, return_dict=False
1116
+ )[0]
1117
+ )
1118
+ else:
1119
+ image = latents
1120
+
1121
+ image = self.image_processor.postprocess(image, output_type=output_type)
1122
+ if not return_dict:
1123
+ return (image,)
1124
+
1125
+ return ImagePipelineOutput(images=image)