Huiwenshi commited on
Commit
4b39f57
·
verified ·
1 Parent(s): 323fb7a

Update hy3dshape/hy3dshape/pipelines.py

Browse files
Files changed (1) hide show
  1. hy3dshape/hy3dshape/pipelines.py +13 -8
hy3dshape/hy3dshape/pipelines.py CHANGED
@@ -27,6 +27,7 @@ from diffusers.utils.torch_utils import randn_tensor
27
  from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
28
  from tqdm import tqdm
29
 
 
30
  from .models.autoencoders import ShapeVAE
31
  from .models.autoencoders import SurfaceExtractors
32
  from .utils import logger, synchronize_timer, smart_load_model
@@ -601,8 +602,9 @@ class Hunyuan3DDiTPipeline:
601
  batch_size = image.shape[0]
602
 
603
  t_dtype = torch.long
 
604
  timesteps, num_inference_steps = retrieve_timesteps(
605
- self.scheduler, num_inference_steps, device, timesteps, sigmas)
606
 
607
  latents = self.prepare_latents(batch_size, dtype, device, generator)
608
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -621,7 +623,7 @@ class Hunyuan3DDiTPipeline:
621
  latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
622
  else:
623
  latent_model_input = latents
624
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
625
 
626
  # predict the noise residual
627
  timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
@@ -642,11 +644,11 @@ class Hunyuan3DDiTPipeline:
642
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
643
 
644
  # compute the previous noisy sample x_t -> x_t-1
645
- outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
646
  latents = outputs.prev_sample
647
 
648
  if callback is not None and i % callback_steps == 0:
649
- step_idx = i // getattr(self.scheduler, "order", 1)
650
  callback(step_idx, t, outputs)
651
 
652
  return self._export(
@@ -733,11 +735,13 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
733
 
734
  batch_size = image.shape[0]
735
 
 
 
736
  # 5. Prepare timesteps
737
  # NOTE: this is slightly different from common usage, we start from 0.
738
  sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
739
  timesteps, num_inference_steps = retrieve_timesteps(
740
- self.scheduler,
741
  num_inference_steps,
742
  device,
743
  sigmas=sigmas,
@@ -760,7 +764,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
760
 
761
  # NOTE: we assume model get timesteps ranged from 0 to 1
762
  timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
763
- timestep = timestep / self.scheduler.config.num_train_timesteps
764
  noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)
765
 
766
  if do_classifier_free_guidance:
@@ -768,13 +772,14 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
768
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
769
 
770
  # compute the previous noisy sample x_t -> x_t-1
771
- outputs = self.scheduler.step(noise_pred, t, latents)
772
  latents = outputs.prev_sample
773
 
774
  if callback is not None and i % callback_steps == 0:
775
- step_idx = i // getattr(self.scheduler, "order", 1)
776
  callback(step_idx, t, outputs)
777
 
 
778
  return self._export(
779
  latents,
780
  output_type,
 
27
  from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
28
  from tqdm import tqdm
29
 
30
+ import copy
31
  from .models.autoencoders import ShapeVAE
32
  from .models.autoencoders import SurfaceExtractors
33
  from .utils import logger, synchronize_timer, smart_load_model
 
602
  batch_size = image.shape[0]
603
 
604
  t_dtype = torch.long
605
+ inner_scheduler = copy.deepcopy(scheduler)
606
  timesteps, num_inference_steps = retrieve_timesteps(
607
+ inner_scheduler, num_inference_steps, device, timesteps, sigmas)
608
 
609
  latents = self.prepare_latents(batch_size, dtype, device, generator)
610
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
623
  latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
624
  else:
625
  latent_model_input = latents
626
+ latent_model_input = inner_scheduler.scale_model_input(latent_model_input, t)
627
 
628
  # predict the noise residual
629
  timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
 
644
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
645
 
646
  # compute the previous noisy sample x_t -> x_t-1
647
+ outputs = inner_scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
648
  latents = outputs.prev_sample
649
 
650
  if callback is not None and i % callback_steps == 0:
651
+ step_idx = i // getattr(inner_scheduler, "order", 1)
652
  callback(step_idx, t, outputs)
653
 
654
  return self._export(
 
735
 
736
  batch_size = image.shape[0]
737
 
738
+ inner_scheduler = copy.deepcopy(scheduler)
739
+
740
  # 5. Prepare timesteps
741
  # NOTE: this is slightly different from common usage, we start from 0.
742
  sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
743
  timesteps, num_inference_steps = retrieve_timesteps(
744
+ inner_scheduler,
745
  num_inference_steps,
746
  device,
747
  sigmas=sigmas,
 
764
 
765
  # NOTE: we assume model get timesteps ranged from 0 to 1
766
  timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
767
+ timestep = timestep / inner_scheduler.config.num_train_timesteps
768
  noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)
769
 
770
  if do_classifier_free_guidance:
 
772
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
773
 
774
  # compute the previous noisy sample x_t -> x_t-1
775
+ outputs = inner_scheduler.step(noise_pred, t, latents)
776
  latents = outputs.prev_sample
777
 
778
  if callback is not None and i % callback_steps == 0:
779
+ step_idx = i // getattr(inner_scheduler, "order", 1)
780
  callback(step_idx, t, outputs)
781
 
782
+ del inner_scheduler
783
  return self._export(
784
  latents,
785
  output_type,