JiantaoLin
commited on
Commit
Β·
c0dbb78
1
Parent(s):
2b11b58
new
Browse files- pipeline/kiss3d_wrapper.py +16 -16
pipeline/kiss3d_wrapper.py
CHANGED
|
@@ -67,7 +67,7 @@ def init_wrapper_from_config(config_path):
|
|
| 67 |
flux_dtype = config_['flux'].get('dtype', 'bf16')
|
| 68 |
flux_controlnet_pth = config_['flux'].get('controlnet', None)
|
| 69 |
# flux_lora_pth = config_['flux'].get('lora', None)
|
| 70 |
-
flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="
|
| 71 |
flux_redux_pth = config_['flux'].get('redux', None)
|
| 72 |
# taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype_[flux_dtype]).to(flux_device)
|
| 73 |
if flux_base_model_pth.endswith('safetensors'):
|
|
@@ -102,23 +102,23 @@ def init_wrapper_from_config(config_path):
|
|
| 102 |
# logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
|
| 103 |
|
| 104 |
# init multiview model
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
|
| 121 |
-
|
| 122 |
# logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
|
| 123 |
multiview_pipeline = None
|
| 124 |
|
|
|
|
| 67 |
flux_dtype = config_['flux'].get('dtype', 'bf16')
|
| 68 |
flux_controlnet_pth = config_['flux'].get('controlnet', None)
|
| 69 |
# flux_lora_pth = config_['flux'].get('lora', None)
|
| 70 |
+
flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="rgb_normal.safetensors", repo_type="model", token=access_token)
|
| 71 |
flux_redux_pth = config_['flux'].get('redux', None)
|
| 72 |
# taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype_[flux_dtype]).to(flux_device)
|
| 73 |
if flux_base_model_pth.endswith('safetensors'):
|
|
|
|
| 102 |
# logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
|
| 103 |
|
| 104 |
# init multiview model
|
| 105 |
+
logger.info('==> Loading multiview diffusion model ...')
|
| 106 |
+
multiview_device = config_['multiview'].get('device', 'cpu')
|
| 107 |
+
multiview_pipeline = DiffusionPipeline.from_pretrained(
|
| 108 |
+
config_['multiview']['base_model'],
|
| 109 |
+
custom_pipeline=config_['multiview']['custom_pipeline'],
|
| 110 |
+
torch_dtype=torch.float16,
|
| 111 |
+
)
|
| 112 |
+
multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
| 113 |
+
multiview_pipeline.scheduler.config, timestep_spacing='trailing'
|
| 114 |
+
)
|
| 115 |
|
| 116 |
+
unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
|
| 117 |
+
if unet_ckpt_path is not None:
|
| 118 |
+
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
| 119 |
+
multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
|
| 120 |
|
| 121 |
+
multiview_pipeline.to(multiview_device)
|
| 122 |
# logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
|
| 123 |
multiview_pipeline = None
|
| 124 |
|