from typing import Optional

import torch
import torch.nn as nn
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor2_0


def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
    """Swap the self-attention ("attn1") processors of `unet` for multiview-aware
    ones, leaving cross-attention ("attn2") processors untouched."""
    attn_procs = {}
    for key, attn_processor in unet.attn_processors.items():
        if "attn1" in key:
            attn_procs[key] = ImageDreamAttnProcessor2_0()
        else:
            attn_procs[key] = attn_processor
    unet.set_attn_processor(attn_procs)
    return unet


class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
    """Multiview self-attention: folds the view axis out of the batch dimension
    and into the sequence dimension so attention spans all `num_views` views of
    the same sample, then restores the original layout afterwards."""

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 1,
        *args,
        **kwargs,
    ):
        # With a single view there is nothing to fold; defer to the stock
        # scaled-dot-product attention processor.
        if num_views == 1:
            return super().__call__(
                attn,
                hidden_states,
                encoder_hidden_states,
                attention_mask,
                temb,
                *args,
                **kwargs,
            )

        input_ndim = hidden_states.ndim
        B = hidden_states.size(0)
        if B % num_views:
            raise ValueError(
                f"`batch_size` ({B}) must be a multiple of `num_views` ({num_views})."
            )
        real_B = B // num_views  # number of distinct samples (scenes)

        if input_ndim == 4:
            # (B, C, H, W) -> (real_B, C, num_views * H, W): stacking the views
            # along the height axis keeps C as the channel dimension, so the
            # base processor's spatial flattening yields one joint sequence
            # covering all views.
            C, H, W = hidden_states.shape[1:]
            hidden_states = (
                hidden_states.reshape(real_B, num_views, C, H, W)
                .transpose(1, 2)
                .reshape(real_B, C, num_views * H, W)
            )
        else:
            # (B, L, C) -> (real_B, num_views * L, C): concatenate the per-view
            # token sequences so self-attention attends across views.
            hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))

        hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            *args,
            **kwargs,
        )

        # Undo the folding so downstream layers see the original per-view layout.
        if input_ndim == 4:
            hidden_states = (
                hidden_states.reshape(real_B, C, num_views, H, W)
                .transpose(1, 2)
                .reshape(B, C, H, W)
            )
        else:
            hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))
        return hidden_states
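

# Usage sketch (assumptions flagged inline). This is a minimal illustration of
# wiring the processor into a Stable Diffusion UNet, not part of the module's
# API: the checkpoint ID and tensor shapes below are illustrative assumptions.
# `num_views` reaches the processor through `cross_attention_kwargs`, which
# diffusers forwards to attention processors whose signatures accept it; the
# unmodified "attn2" processors simply ignore the extra keyword.
if __name__ == "__main__":
    # Assumed checkpoint: any SD 1.x-style UNet (4 latent channels, 768-dim
    # text embeddings) would do.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    unet = add_imagedream_attn_processor(unet)

    num_views = 4
    sample = torch.randn(num_views, 4, 64, 64)  # one scene, four views
    timestep = torch.tensor(999)
    encoder_hidden_states = torch.randn(num_views, 77, 768)  # per-view prompts

    out = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        cross_attention_kwargs={"num_views": num_views},
    ).sample
    print(out.shape)  # torch.Size([4, 4, 64, 64])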