set num_views as attr of attn_processor to support torch.compile
pipeline_imagedream.py (+24 −11)
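Before this change, `num_views` was threaded through `cross_attention_kwargs` into every UNet call, with the kwargs dict rebuilt on each invocation. Under `torch.compile`, per-call container kwargs like that tend to become dynamic graph inputs, inviting guards and recompiles. This commit instead stores `num_views` as a plain attribute on the ImageDream attention processors and updates it once, outside the denoising loop, via a small `set_num_views` helper, so the UNet's call signature stays static across steps.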
```diff
@@ -76,7 +76,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
         weight_name: Union[str, List[str]] = "ip-adapter-plus_imagedream.bin",
         image_encoder_folder: Optional[str] = "image_encoder",
         **kwargs,
-    ):
+    ) -> None:
         super().load_ip_adapter(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             subfolder=subfolder,
@@ -89,12 +89,17 @@ class ImageDreamPipeline(StableDiffusionPipeline):
         if weight_name == "ip-adapter-plus_imagedream.bin":
             setattr(self.image_encoder, "visual_projection", nn.Identity())
             add_imagedream_attn_processor(self.unet)
+            set_num_views(self.unet, self.num_views + 1)
             logging.set_verbosity_error()
             print(
                 "ImageDream Cross-Attention uses `num_views` kwarg, "
                 "and set logging verbosity to error."
             )
 
+    def unload_ip_adapter(self) -> None:
+        super().unload_ip_adapter()
+        set_num_views(self.unet, self.num_views)
+
     def encode_image_to_latents(
         self,
         image: PipelineImageInput,
```
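Note the symmetry: loading the IP adapter bumps the processor view count to `self.num_views + 1` (the reference image rides along as one extra view next to the camera views), and the new `unload_ip_adapter` override restores the plain count. A minimal usage sketch; the repo ID and subfolder are placeholders, not pinned to a real checkpoint:

```python
# Placeholders throughout; only the load/unload flow is the point.
pipe.load_ip_adapter(
    "user/imagedream-ipadapter",                 # hypothetical repo ID
    subfolder="models",                          # hypothetical subfolder
    weight_name="ip-adapter-plus_imagedream.bin",
)
# While loaded: attn1 processors run with num_views + 1 views,
# the extra slot carrying the image-prompt tokens.
pipe.unload_ip_adapter()
# After unload: processors fall back to the plain num_views count.
```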
```diff
@@ -326,9 +331,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
 
-        cross_attention_kwargs = {"num_views": num_views}
-        if self.cross_attention_kwargs is not None:
-            cross_attention_kwargs.update(self.cross_attention_kwargs)
+        set_num_views(self.unet, num_views)
 
         # fmt: off
         # 7. Denoising loop
```
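This hunk is the heart of the change: rather than rebuilding a kwargs dict every step and threading it through the UNet, the view count is written onto the processors once before the loop. A rough before/after sketch; variable names follow the diff, and the `torch.compile` call is illustrative:

```python
# Before: num_views rides along as a per-step kwarg. The dict is a fresh
# object on every call, and torch.compile flattens such containers into
# graph inputs -- an invitation for guards and recompilation.
noise_pred = unet(
    latents, t,
    encoder_hidden_states=prompt_embeds,
    cross_attention_kwargs={"num_views": num_views},
).sample

# After: num_views is a plain int attribute, set once outside the loop,
# so the compiler can specialize on it as a constant.
set_num_views(unet, num_views)
compiled_unet = torch.compile(unet)   # illustrative; mode/options omitted
noise_pred = compiled_unet(
    latents, t, encoder_hidden_states=prompt_embeds
).sample
```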
```diff
@@ -352,7 +355,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
                     class_labels=camera,
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
@@ -508,7 +511,7 @@ def get_camera(
     # fmt: on
 
 
-def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
+def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> UNet2DConditionModel:
     attn_procs = {}
     for key, attn_processor in unet.attn_processors.items():
         if "attn1" in key:
@@ -519,7 +522,18 @@ def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
     return unet
 
 
+def set_num_views(unet: UNet2DConditionModel, num_views: int) -> UNet2DConditionModel:
+    for key, attn_processor in unet.attn_processors.items():
+        if isinstance(attn_processor, ImageDreamAttnProcessor2_0):
+            attn_processor.num_views = num_views
+    return unet
+
+
 class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
+    def __init__(self, num_views: int = 4):
+        super().__init__()
+        self.num_views = num_views
+
     def __call__(
         self,
         attn: Attention,
```
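`set_num_views` mutates the existing processors in place rather than calling `unet.set_attn_processor(...)` again, so switching between text-only and image-conditioned modes is just an integer write; the `num_views=4` default matches ImageDream's four camera views. A hypothetical toggle:

```python
# Text-only multiview: 4 camera views stacked in the batch.
set_num_views(pipe.unet, 4)
# Image-conditioned multiview: 4 camera views + 1 reference-image view.
set_num_views(pipe.unet, 5)
# If the UNet is compiled, each distinct value should compile once and
# then be served from the compiler's cache on later toggles.
```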
```diff
@@ -527,11 +541,10 @@ class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
-        num_views: int = 1,
         *args,
         **kwargs,
     ):
-        if num_views == 1:
+        if self.num_views == 1:
             return super().__call__(
                 attn=attn,
                 hidden_states=hidden_states,
@@ -544,11 +557,11 @@ class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
 
         input_ndim = hidden_states.ndim
         B = hidden_states.size(0)
-        if B % num_views:
+        if B % self.num_views:
             raise ValueError(
-                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
+                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {self.num_views})."
             )
-        real_B = B // num_views
+        real_B = B // self.num_views
         if input_ndim == 4:
             H, W = hidden_states.shape[2:]
             hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
```
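For context, the multi-view path assumes the caller stacked all views of each object along the batch axis; the reshape folds that view axis back out of the batch, presumably so one attention call can mix tokens across the views of the same object. A shape-only example with hypothetical sizes:

```python
import torch

num_views = 5                          # 4 camera views + 1 image-prompt view
real_B = 2                             # e.g. two objects, or one object x 2 for CFG
B = real_B * num_views                 # how the caller stacks the batch

x = torch.randn(B, 320, 32, 32)        # (B, C, H, W) hidden states
assert x.size(0) % num_views == 0      # the check the processor enforces
folded = x.reshape(real_B, -1, *x.shape[2:]).transpose(1, 2)
print(folded.shape)                    # torch.Size([2, 32, 1600, 32])
```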