File size: 2,340 Bytes
aaa261a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import Optional

import torch
import torch.nn as nn
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor2_0



def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
    """Install ``ImageDreamAttnProcessor2_0`` on the ``attn1`` layers of *unet*.

    Every attention processor whose key contains ``"attn1"`` is replaced by a
    freshly constructed ``ImageDreamAttnProcessor2_0``. In diffusers'
    transformer blocks these are presumably the self-attention layers; verify
    this for custom UNets. All other processors are kept unchanged.

    The UNet is modified in place via ``set_attn_processor`` and the same
    object is returned for convenience.
    """
    # A distinct processor object per matching layer; everything else passes through.
    new_processors = {
        name: ImageDreamAttnProcessor2_0() if "attn1" in name else current
        for name, current in unet.attn_processors.items()
    }
    unet.set_attn_processor(new_processors)
    return unet


class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
    """Multi-view attention processor built on diffusers' ``AttnProcessor2_0``.

    The batch is treated as ``real_batch * num_views`` entries, with the views
    of one sample stored consecutively. When ``num_views > 1``, the tokens of
    all views of a sample are concatenated along the sequence axis. A single
    attention call therefore spans every view, and the result is then split
    back into per-view batch entries. With ``num_views == 1`` this is a plain
    pass-through to ``AttnProcessor2_0``.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 1,
        *args,
        **kwargs,
    ):
        """Apply attention, jointly over all views when ``num_views > 1``.

        Args:
            attn: The ``Attention`` module whose projections are used.
            hidden_states: A ``(B, L, C)`` token tensor or a ``(B, C, H, W)``
                feature map. ``B`` must be a multiple of ``num_views``.
            encoder_hidden_states: Forwarded unchanged to ``AttnProcessor2_0``.
            attention_mask: Forwarded unchanged to ``AttnProcessor2_0``.
            temb: Forwarded unchanged to ``AttnProcessor2_0``.
            num_views: Number of consecutive batch entries forming one
                multi-view sample.
            *args: Extra positional arguments forwarded to the parent processor.
            **kwargs: Extra keyword arguments forwarded to the parent processor.

        Returns:
            A tensor with the input's batch size ``B`` and layout: ``(B, L, C')``
            for 3-D input, ``(B, C', H, W)`` for 4-D input.

        Raises:
            ValueError: If ``num_views < 1`` or ``B`` is not a multiple of
                ``num_views``.
        """
        if num_views == 1:
            # Forward positionally: in ``f(attn=..., *args)`` Python binds
            # ``args`` to the leading positional parameters first, so any extra
            # positional argument would collide with ``attn``/``hidden_states``.
            return super().__call__(
                attn,
                hidden_states,
                encoder_hidden_states,
                attention_mask,
                temb,
                *args,
                **kwargs,
            )
        if num_views < 1:
            raise ValueError(
                f"`num_views` must be a positive integer (got {num_views})."
            )

        input_ndim = hidden_states.ndim
        B = hidden_states.size(0)
        if B % num_views:
            raise ValueError(
                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
            )
        real_B = B // num_views

        if input_ndim == 4:
            height, width = hidden_states.shape[2:]
            # (B, C, H, W) -> (B, H*W, C): channels must end up as the last
            # axis before folding views, otherwise views get mixed into C.
            hidden_states = hidden_states.reshape(
                B, hidden_states.size(1), height * width
            ).transpose(1, 2)

        # (B, L, C) -> (real_B, num_views * L, C). Consecutive batch entries
        # (the views of one sample) become one long token sequence.
        # NOTE(review): this changes the batch size seen by the parent from B to
        # real_B. A non-None `encoder_hidden_states` or a per-entry
        # `attention_mask` with batch B would no longer match, so this is
        # presumably only used for self-attention without a mask. Likewise, any
        # normalisation inside the parent now sees all views of a sample
        # together. Confirm both are intended.
        hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))
        hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            *args,
            **kwargs,
        )
        # Undo the fold: (real_B, num_views * L, C') -> (B, L, C').
        hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))

        if input_ndim == 4:
            # (B, H*W, C') -> (B, C', H, W), mirroring the forward conversion.
            hidden_states = hidden_states.transpose(1, 2).reshape(
                B, -1, height, width
            )
        return hidden_states