Spaces:
Paused
Paused
Update mimicmotion/pipelines/pipeline_mimicmotion.py
Browse files
mimicmotion/pipelines/pipeline_mimicmotion.py
CHANGED
|
@@ -556,21 +556,17 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
| 556 |
# expand the latents if we are doing classifier free guidance
|
| 557 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 558 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 559 |
-
|
| 560 |
# Concatenate image_latents over channels dimension
|
| 561 |
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 562 |
-
|
| 563 |
# predict the noise residual
|
| 564 |
noise_pred = torch.zeros_like(image_latents)
|
| 565 |
noise_pred_cnt = image_latents.new_zeros((num_frames,))
|
| 566 |
weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
|
| 567 |
weight = torch.minimum(weight, 2 - weight)
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
def process_index(idx):
|
| 571 |
-
nonlocal noise_pred, noise_pred_cnt
|
| 572 |
-
result = torch.zeros_like(image_latents[:1, idx]) # Placeholder for thread-safe accumulation
|
| 573 |
-
|
| 574 |
# classification-free inference
|
| 575 |
pose_latents = self.pose_net(image_pose[idx].to(device))
|
| 576 |
_noise_pred = self.unet(
|
|
@@ -582,8 +578,8 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
| 582 |
image_only_indicator=image_only_indicator,
|
| 583 |
return_dict=False,
|
| 584 |
)[0]
|
| 585 |
-
|
| 586 |
-
|
| 587 |
# normal inference
|
| 588 |
_noise_pred = self.unet(
|
| 589 |
latent_model_input[1:, idx],
|
|
@@ -594,34 +590,26 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
| 594 |
image_only_indicator=image_only_indicator,
|
| 595 |
return_dict=False,
|
| 596 |
)[0]
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 602 |
-
futures = [executor.submit(process_index, idx) for idx in indices]
|
| 603 |
-
for future in concurrent.futures.as_completed(futures):
|
| 604 |
-
_noise_pred, idx = future.result()
|
| 605 |
-
noise_pred[:, idx] += _noise_pred
|
| 606 |
-
noise_pred_cnt[idx] += weight
|
| 607 |
-
progress_bar.update()
|
| 608 |
-
|
| 609 |
noise_pred.div_(noise_pred_cnt[:, None, None, None])
|
| 610 |
-
|
| 611 |
# perform guidance
|
| 612 |
if self.do_classifier_free_guidance:
|
| 613 |
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 614 |
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 615 |
-
|
| 616 |
# compute the previous noisy sample x_t -> x_t-1
|
| 617 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 618 |
-
|
| 619 |
if callback_on_step_end is not None:
|
| 620 |
callback_kwargs = {}
|
| 621 |
for k in callback_on_step_end_tensor_inputs:
|
| 622 |
callback_kwargs[k] = locals()[k]
|
| 623 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 624 |
-
|
| 625 |
latents = callback_outputs.pop("latents", latents)
|
| 626 |
|
| 627 |
self.pose_net.cpu()
|
|
|
|
| 556 |
# expand the latents if we are doing classifier free guidance
|
| 557 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 558 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 559 |
+
|
| 560 |
# Concatenate image_latents over channels dimension
|
| 561 |
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 562 |
+
|
| 563 |
# predict the noise residual
|
| 564 |
noise_pred = torch.zeros_like(image_latents)
|
| 565 |
noise_pred_cnt = image_latents.new_zeros((num_frames,))
|
| 566 |
weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
|
| 567 |
weight = torch.minimum(weight, 2 - weight)
|
| 568 |
+
for idx in indices:
|
| 569 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
# classification-free inference
|
| 571 |
pose_latents = self.pose_net(image_pose[idx].to(device))
|
| 572 |
_noise_pred = self.unet(
|
|
|
|
| 578 |
image_only_indicator=image_only_indicator,
|
| 579 |
return_dict=False,
|
| 580 |
)[0]
|
| 581 |
+
noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None]
|
| 582 |
+
|
| 583 |
# normal inference
|
| 584 |
_noise_pred = self.unet(
|
| 585 |
latent_model_input[1:, idx],
|
|
|
|
| 590 |
image_only_indicator=image_only_indicator,
|
| 591 |
return_dict=False,
|
| 592 |
)[0]
|
| 593 |
+
noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None]
|
| 594 |
+
|
| 595 |
+
noise_pred_cnt[idx] += weight
|
| 596 |
+
progress_bar.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
noise_pred.div_(noise_pred_cnt[:, None, None, None])
|
| 598 |
+
|
| 599 |
# perform guidance
|
| 600 |
if self.do_classifier_free_guidance:
|
| 601 |
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 602 |
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 603 |
+
|
| 604 |
# compute the previous noisy sample x_t -> x_t-1
|
| 605 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 606 |
+
|
| 607 |
if callback_on_step_end is not None:
|
| 608 |
callback_kwargs = {}
|
| 609 |
for k in callback_on_step_end_tensor_inputs:
|
| 610 |
callback_kwargs[k] = locals()[k]
|
| 611 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 612 |
+
|
| 613 |
latents = callback_outputs.pop("latents", latents)
|
| 614 |
|
| 615 |
self.pose_net.cpu()
|