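"""ImageDream-style cross-view attention for diffusers UNets.

`add_imagedream_attn_processor` swaps every self-attention ("attn1") processor
in a `UNet2DConditionModel` for `ImageDreamAttnProcessor2_0`, which folds the
view axis of a multi-view batch into the token sequence so that all views of
one object attend to each other. Cross-attention ("attn2") is left untouched.
"""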
from typing import Optional

import torch
import torch.nn as nn

from diffusers.models import UNet2DConditionModel
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor2_0


def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
    """Replace every self-attention ("attn1") processor with a cross-view one."""
    attn_procs = {}
    for key, attn_processor in unet.attn_processors.items():
        # Self-attention processors are registered under keys containing "attn1",
        # e.g. "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor";
        # cross-attention ("attn2") processors are kept as-is.
        if "attn1" in key:
            attn_procs[key] = ImageDreamAttnProcessor2_0()
        else:
            attn_procs[key] = attn_processor
    unet.set_attn_processor(attn_procs)
    return unet
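
# Usage sketch (an illustrative addition, not part of the original module; the
# base checkpoint and view count below are assumptions). With recent diffusers
# versions, `num_views` reaches the patched processors through the UNet's
# `cross_attention_kwargs` forwarding:
#
#   unet = UNet2DConditionModel.from_pretrained(
#       "stabilityai/stable-diffusion-2-1", subfolder="unet"
#   )
#   unet = add_imagedream_attn_processor(unet)
#   noise_pred = unet(
#       latents,  # (real_B * num_views, 4, h, w)
#       timestep,
#       encoder_hidden_states=text_embeds,
#       cross_attention_kwargs={"num_views": 4},
#   ).sample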


class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
    """Scaled-dot-product self-attention that folds the view axis of a
    multi-view batch into the token sequence, so that every view of an
    object attends to every other view."""

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 1,
        *args,
        **kwargs,
    ):
        # Single-view batches need no reshaping; defer to the stock processor.
        if num_views == 1:
            return super().__call__(
                attn=attn,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                temb=temb,
                *args,
                **kwargs,
            )

        input_ndim = hidden_states.ndim
        B = hidden_states.size(0)
        if B % num_views:
            raise ValueError(
                f"`batch_size` (got {B}) must be a multiple of `num_views` (got {num_views})."
            )
        real_B = B // num_views

        # Fold the view axis into the token sequence so self-attention runs
        # jointly over all `num_views` views of each object.
        if input_ndim == 4:
            C, H, W = hidden_states.shape[1:]
            # (real_B * V, C, H, W) -> (real_B, C, V * H, W); the parent class
            # flattens the spatial axes, so the sequence then spans all views.
            hidden_states = (
                hidden_states.reshape(real_B, num_views, C, H, W)
                .transpose(1, 2)
                .reshape(real_B, C, num_views * H, W)
            )
        else:
            # (real_B * V, L, C) -> (real_B, V * L, C)
            hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))

        # Standard SDPA attention over the view-merged sequence.
        hidden_states = super().__call__(
            attn=attn,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            temb=temb,
            *args,
            **kwargs,
        )

        # Unfold the view axis back out of the sequence.
        if input_ndim == 4:
            hidden_states = (
                hidden_states.reshape(real_B, C, num_views, H, W)
                .transpose(1, 2)
                .reshape(B, C, H, W)
            )
        else:
            hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))
        return hidden_states
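

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # module): build a stand-alone diffusers `Attention` layer, install the
    # cross-view processor, and check that shapes round-trip for a dummy
    # 4-view batch. All sizes below are arbitrary.
    dim, num_views, seq_len = 64, 4, 256
    attn = Attention(query_dim=dim, heads=4, dim_head=16)
    attn.set_processor(ImageDreamAttnProcessor2_0())
    x = torch.randn(2 * num_views, seq_len, dim)  # (real_B * V, L, C)
    out = attn(x, num_views=num_views)
    assert out.shape == x.shape
    print("cross-view attention ok:", tuple(out.shape))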