kiigii committed on
Commit
59cc439
·
verified ·
1 Parent(s): 1555d91

set num_views as attr of attn_processor to support torch.compile

Browse files
Files changed (1) hide show
  1. pipeline_imagedream.py +24 -11
pipeline_imagedream.py CHANGED
@@ -76,7 +76,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
76
  weight_name: Union[str, List[str]] = "ip-adapter-plus_imagedream.bin",
77
  image_encoder_folder: Optional[str] = "image_encoder",
78
  **kwargs,
79
- ):
80
  super().load_ip_adapter(
81
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
82
  subfolder=subfolder,
@@ -89,12 +89,17 @@ class ImageDreamPipeline(StableDiffusionPipeline):
89
  if weight_name == "ip-adapter-plus_imagedream.bin":
90
  setattr(self.image_encoder, "visual_projection", nn.Identity())
91
  add_imagedream_attn_processor(self.unet)
 
92
  logging.set_verbosity_error()
93
  print(
94
  "ImageDream Cross-Attention uses `num_views` kwarg, "
95
  "and set logging verbosity to error."
96
  )
97
 
 
 
 
 
98
  def encode_image_to_latents(
99
  self,
100
  image: PipelineImageInput,
@@ -326,9 +331,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
326
  guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
327
  ).to(device=device, dtype=latents.dtype)
328
 
329
- cross_attention_kwargs = {"num_views": num_views}
330
- if self.cross_attention_kwargs is not None:
331
- cross_attention_kwargs.update(self.cross_attention_kwargs)
332
 
333
  # fmt: off
334
  # 7. Denoising loop
@@ -352,7 +355,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
352
  class_labels=camera,
353
  encoder_hidden_states=prompt_embeds,
354
  timestep_cond=timestep_cond,
355
- cross_attention_kwargs=cross_attention_kwargs,
356
  added_cond_kwargs=added_cond_kwargs,
357
  return_dict=False,
358
  )[0]
@@ -508,7 +511,7 @@ def get_camera(
508
  # fmt: on
509
 
510
 
511
- def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
512
  attn_procs = {}
513
  for key, attn_processor in unet.attn_processors.items():
514
  if "attn1" in key:
@@ -519,7 +522,18 @@ def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
519
  return unet
520
 
521
 
 
 
 
 
 
 
 
522
  class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
 
 
 
 
523
  def __call__(
524
  self,
525
  attn: Attention,
@@ -527,11 +541,10 @@ class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
527
  encoder_hidden_states: Optional[torch.Tensor] = None,
528
  attention_mask: Optional[torch.Tensor] = None,
529
  temb: Optional[torch.Tensor] = None,
530
- num_views: int = 1,
531
  *args,
532
  **kwargs,
533
  ):
534
- if num_views == 1:
535
  return super().__call__(
536
  attn=attn,
537
  hidden_states=hidden_states,
@@ -544,11 +557,11 @@ class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
544
 
545
  input_ndim = hidden_states.ndim
546
  B = hidden_states.size(0)
547
- if B % num_views:
548
  raise ValueError(
549
- f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
550
  )
551
- real_B = B // num_views
552
  if input_ndim == 4:
553
  H, W = hidden_states.shape[2:]
554
  hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
 
76
  weight_name: Union[str, List[str]] = "ip-adapter-plus_imagedream.bin",
77
  image_encoder_folder: Optional[str] = "image_encoder",
78
  **kwargs,
79
+ ) -> None:
80
  super().load_ip_adapter(
81
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
82
  subfolder=subfolder,
 
89
  if weight_name == "ip-adapter-plus_imagedream.bin":
90
  setattr(self.image_encoder, "visual_projection", nn.Identity())
91
  add_imagedream_attn_processor(self.unet)
92
+ set_num_views(self.unet, self.num_views + 1)
93
  logging.set_verbosity_error()
94
  print(
95
  "ImageDream Cross-Attention uses `num_views` kwarg, "
96
  "and set logging verbosity to error."
97
  )
98
 
99
+ def unload_ip_adapter(self) -> None:
100
+ super().unload_ip_adapter()
101
+ set_num_views(self.unet, self.num_views)
102
+
103
  def encode_image_to_latents(
104
  self,
105
  image: PipelineImageInput,
 
331
  guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
332
  ).to(device=device, dtype=latents.dtype)
333
 
334
+ set_num_views(self.unet, num_views)
 
 
335
 
336
  # fmt: off
337
  # 7. Denoising loop
 
355
  class_labels=camera,
356
  encoder_hidden_states=prompt_embeds,
357
  timestep_cond=timestep_cond,
358
+ cross_attention_kwargs=self.cross_attention_kwargs,
359
  added_cond_kwargs=added_cond_kwargs,
360
  return_dict=False,
361
  )[0]
 
511
  # fmt: on
512
 
513
 
514
+ def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> UNet2DConditionModel:
515
  attn_procs = {}
516
  for key, attn_processor in unet.attn_processors.items():
517
  if "attn1" in key:
 
522
  return unet
523
 
524
 
525
+ def set_num_views(unet: UNet2DConditionModel, num_views: int) -> UNet2DConditionModel:
526
+ for key, attn_processor in unet.attn_processors.items():
527
+ if isinstance(attn_processor, ImageDreamAttnProcessor2_0):
528
+ attn_processor.num_views = num_views
529
+ return unet
530
+
531
+
532
  class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
533
+ def __init__(self, num_views: int = 4):
534
+ super().__init__()
535
+ self.num_views = num_views
536
+
537
  def __call__(
538
  self,
539
  attn: Attention,
 
541
  encoder_hidden_states: Optional[torch.Tensor] = None,
542
  attention_mask: Optional[torch.Tensor] = None,
543
  temb: Optional[torch.Tensor] = None,
 
544
  *args,
545
  **kwargs,
546
  ):
547
+ if self.num_views == 1:
548
  return super().__call__(
549
  attn=attn,
550
  hidden_states=hidden_states,
 
557
 
558
  input_ndim = hidden_states.ndim
559
  B = hidden_states.size(0)
560
+ if B % self.num_views:
561
  raise ValueError(
562
+ f"`batch_size`(got {B}) must be a multiple of `num_views`(got {self.num_views})."
563
  )
564
+ real_B = B // self.num_views
565
  if input_ndim == 4:
566
  H, W = hidden_states.shape[2:]
567
  hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)