kiigii committed (verified)
Commit 1555d91 · 1 Parent(s): 9c226de

Delete attention_processor.py

Files changed (1):
  1. attention_processor.py +0 -70
attention_processor.py DELETED
@@ -1,70 +0,0 @@
from typing import Optional

import torch
import torch.nn as nn
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor2_0


def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
    attn_procs = {}
    for key, attn_processor in unet.attn_processors.items():
        if "attn1" in key:
            attn_procs[key] = ImageDreamAttnProcessor2_0()
        else:
            attn_procs[key] = attn_processor
    unet.set_attn_processor(attn_procs)
    return unet


class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 1,
        *args,
        **kwargs,
    ):
        if num_views == 1:
            return super().__call__(
                attn=attn,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                temb=temb,
                *args,
                **kwargs,
            )

        input_ndim = hidden_states.ndim
        B = hidden_states.size(0)
        if B % num_views:
            raise ValueError(
                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
            )
        real_B = B // num_views
        if input_ndim == 4:
            H, W = hidden_states.shape[2:]
            hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
        else:
            hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))
        hidden_states = super().__call__(
            attn=attn,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            temb=temb,
            *args,
            **kwargs,
        )
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(B, -1, H, W)
        else:
            hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))
        return hidden_states
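
The deleted module wires a multi-view self-attention processor into a diffusers UNet: add_imagedream_attn_processor swaps ImageDreamAttnProcessor2_0 into every self-attention ("attn1") layer, and when num_views > 1 the processor folds the views stacked along the batch dimension into one longer token sequence, runs ordinary attention over it, and unfolds the result back to the original batch shape, so all views of an object attend to each other. A minimal usage sketch follows; the checkpoint id, latent shapes, and the routing of num_views through cross_attention_kwargs are illustrative assumptions, not taken from this commit.

import torch
from diffusers.models import UNet2DConditionModel

# The module removed by this commit (hypothetical import path).
from attention_processor import add_imagedream_attn_processor

# Any SD-style UNet should work here; this checkpoint id is only an example.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet = add_imagedream_attn_processor(unet)

num_views = 4  # views of the same object, stacked along the batch dimension
sample = torch.randn(num_views, 4, 32, 32)               # (B = num_views, C, H, W) latents
encoder_hidden_states = torch.randn(num_views, 77, 768)  # conditioning tokens

# diffusers forwards cross_attention_kwargs to each attention processor, so
# num_views reaches ImageDreamAttnProcessor2_0 in the "attn1" layers above.
out = unet(
    sample,
    timestep=10,
    encoder_hidden_states=encoder_hidden_states,
    cross_attention_kwargs={"num_views": num_views},
).sample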