Spaces: Running on Zero
update
Browse files
- app.py +1 -7
- attention_processor_faceid.py +426 -0
- helper.py +236 -0
- ipown.py +470 -0
- loader.py +95 -0
- requirements.txt +14 -0
- resampler.py +158 -0
- utils.py +170 -0
app.py
CHANGED
@@ -1,7 +1 @@
-import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import loader
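For context, the rewritten app.py delegates all startup work to loader.py. A minimal sketch of the intended flow, assuming loader.py invokes load_scripts() at import time (its final lines are not shown in this diff):

# app.py (entire file after this commit)
import loader  # importing the module is assumed to trigger loader.load_scripts(),
               # which downloads the scripts listed in the FILE_LIST secret and
               # executes the last one (typically the real Gradio app)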
attention_processor_faceid.py
ADDED
@@ -0,0 +1,426 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.lora import LoRALinearLayer


class LoRAAttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        rank=4,
        network_alpha=None,
        lora_scale=1.0,
    ):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class LoRAIPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class LoRAAttnProcessor2_0(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        rank=4,
        network_alpha=None,
        lora_scale=1.0,
    ):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class LoRAIPAttnProcessor2_0(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism.
    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale
        self.num_tokens = num_tokens

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
        #query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # for text
        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
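A small sketch (not part of the commit) of the token layout these processors assume: the pipeline concatenates the image-prompt tokens after the text embeddings, and each IP processor slices the last num_tokens entries back off so they can be attended to separately through to_k_ip / to_v_ip.

import torch

# Hypothetical shapes for illustration only.
batch, text_len, num_tokens, dim = 2, 77, 4, 768
text_embeds = torch.randn(batch, text_len, dim)      # from the text encoder
image_embeds = torch.randn(batch, num_tokens, dim)   # from the face/image projection model

# The pipeline passes this concatenation as encoder_hidden_states.
encoder_hidden_states = torch.cat([text_embeds, image_embeds], dim=1)

# Inside LoRAIPAttnProcessor.__call__ the split is recovered by position:
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_part = encoder_hidden_states[:, :end_pos, :]    # 77 text tokens
ip_part = encoder_hidden_states[:, end_pos:, :]      # 4 image tokens
assert text_part.shape[1] == text_len and ip_part.shape[1] == num_tokens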
helper.py
ADDED
@@ -0,0 +1,236 @@
import os
from PIL import Image
import uuid
import re
import torch  # needed by concat_tensor / merge_embeds below (missing in the original file)

def parse_prompt_attention(text):
    re_attention = re.compile(r"""
        \\\(|
        \\\)|
        \\\[|
        \\]|
        \\\\|
        \\|
        \(|
        \[|
        :([+-]?[.\d]+)\)|
        \)|
        ]|
        [^\\()\[\]:]+|
        :
    """, re.X)

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ')' and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res

def prompt_attention_to_invoke_prompt(attention):
    tokens = []
    for text, weight in attention:
        # Round weight to 2 decimal places
        weight = round(weight, 2)
        if weight == 1.0:
            tokens.append(text)
        elif weight < 1.0:
            if weight < 0.8:
                tokens.append(f"({text}){weight}")
            else:
                tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
        else:
            if weight < 1.3:
                tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
            else:
                tokens.append(f"({text}){weight}")
    return "".join(tokens)

def concat_tensor(t):
    t_list = torch.split(t, 1, dim=0)
    t = torch.cat(t_list, dim=1)
    return t

def merge_embeds(prompt_chanks, compel):
    num_chanks = len(prompt_chanks)
    if num_chanks != 0:
        power_prompt = 1/(num_chanks*(num_chanks+1)//2)
        prompt_embs = compel(prompt_chanks)
        t_list = list(torch.split(prompt_embs, 1, dim=0))
        for i in range(num_chanks):
            t_list[-(i+1)] = t_list[-(i+1)] * ((i+1)*power_prompt)
        prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
    else:
        prompt_emb = compel('')
    return prompt_emb

def detokenize(chunk, actual_prompt):
    chunk[-1] = chunk[-1].replace('</w>', '')
    chanked_prompt = ''.join(chunk).strip()
    while '</w>' in chanked_prompt:
        if actual_prompt[chanked_prompt.find('</w>')] == ' ':
            chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
        else:
            chanked_prompt = chanked_prompt.replace('</w>', '', 1)
    actual_prompt = actual_prompt.replace(chanked_prompt,'')
    return chanked_prompt.strip(), actual_prompt.strip()

def tokenize_line(line, tokenizer): # split into chunks
    actual_prompt = line.lower().strip()
    actual_tokens = tokenizer.tokenize(actual_prompt)
    max_tokens = tokenizer.model_max_length - 2
    comma_token = tokenizer.tokenize(',')[0]

    chunks = []
    chunk = []
    for item in actual_tokens:
        chunk.append(item)
        if len(chunk) == max_tokens:
            if chunk[-1] != comma_token:
                for i in range(max_tokens-1, -1, -1):
                    if chunk[i] == comma_token:
                        actual_chunk, actual_prompt = detokenize(chunk[:i+1], actual_prompt)
                        chunks.append(actual_chunk)
                        chunk = chunk[i+1:]
                        break
                else:
                    actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
                    chunks.append(actual_chunk)
                    chunk = []
            else:
                actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
                chunks.append(actual_chunk)
                chunk = []
    if chunk:
        actual_chunk, _ = detokenize(chunk, actual_prompt)
        chunks.append(actual_chunk)

    return chunks

def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):

    if compel_process_sd:
        return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)
    else:
        # fix bug: weight conversion with excessive emphasis
        prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")

        # Convert to Compel
        attention = parse_prompt_attention(prompt)
        global_attention_chanks = []

        for att in attention:
            for chank in att[0].split(','):
                temp_prompt_chanks = tokenize_line(chank, pipeline.tokenizer)
                for small_chank in temp_prompt_chanks:
                    temp_dict = {
                        "weight": round(att[1], 2),
                        "lenght": len(pipeline.tokenizer.tokenize(f'{small_chank},')),
                        "prompt": f'{small_chank},'
                    }
                    global_attention_chanks.append(temp_dict)

        max_tokens = pipeline.tokenizer.model_max_length - 2
        global_prompt_chanks = []
        current_list = []
        current_length = 0
        for item in global_attention_chanks:
            if current_length + item['lenght'] > max_tokens:
                global_prompt_chanks.append(current_list)
                current_list = [[item['prompt'], item['weight']]]
                current_length = item['lenght']
            else:
                if not current_list:
                    current_list.append([item['prompt'], item['weight']])
                else:
                    if item['weight'] != current_list[-1][1]:
                        current_list.append([item['prompt'], item['weight']])
                    else:
                        current_list[-1][0] += f" {item['prompt']}"
                current_length += item['lenght']
        if current_list:
            global_prompt_chanks.append(current_list)

        if only_convert_string:
            return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks])

        return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel)

def add_comma_after_pattern_ti(text):
    pattern = re.compile(r'\b\w+_\d+\b')
    modified_text = pattern.sub(lambda x: x.group() + ',', text)
    return modified_text

def save_image(img):
    path = "./tmp/"

    # Check if the input is a string (file path) and load the image if it is
    if isinstance(img, str):
        img = Image.open(img)  # Load the image from the file path

    # Ensure the output directory exists locally
    if not os.path.exists(path):
        os.makedirs(path)

    # Generate a unique filename
    unique_name = str(uuid.uuid4()) + ".webp"
    unique_name = os.path.join(path, unique_name)

    # Convert the image to WebP format
    webp_img = img.convert("RGB")  # Ensure the image is in RGB mode

    # Save the image in WebP format with high quality
    webp_img.save(unique_name, "WEBP", quality=90)

    # Open the saved WebP file and return it as a PIL Image object
    with Image.open(unique_name) as webp_file:
        webp_image = webp_file.copy()

    return unique_name
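A quick illustration (not part of the commit) of what parse_prompt_attention produces for an A1111-style weighted prompt; the numbers follow directly from the 1.1 and 1/1.1 bracket multipliers defined above.

from helper import parse_prompt_attention

# '(cat:1.2)' sets an explicit weight, '[rain]' divides the weight by 1.1.
print(parse_prompt_attention("a (cat:1.2) in [rain]"))
# -> [['a ', 1.0], ['cat', 1.2], [' in ', 1.0], ['rain', 0.9090909090909091]]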
ipown.py
ADDED
@@ -0,0 +1,470 @@
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
from utils import is_torch2_available

USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
if is_torch2_available() and (not USE_DAFAULT_ATTN):
    from attention_processor_faceid import (
        LoRAAttnProcessor2_0 as LoRAAttnProcessor,
    )
    from attention_processor_faceid import (
        LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
    )
else:
    from attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
from resampler import PerceiverAttention, FeedForward


class FacePerceiverResampler(torch.nn.Module):
    def __init__(
        self,
        *,
        dim=768,
        depth=4,
        dim_head=64,
        heads=16,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
    ):
        super().__init__()

        self.proj_in = torch.nn.Linear(embedding_dim, dim)
        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)
        self.layers = torch.nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, latents, x):
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        return self.norm_out(latents)


class MLPProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x


class ProjPlusModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.perceiver_resampler = FacePerceiverResampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=cross_attention_dim // 64,
            embedding_dim=clip_embeddings_dim,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )

    def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):

        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        out = self.perceiver_resampler(x, clip_embeds)
        if shortcut:
            out = x + scale * out
        return out


class IPAdapterFaceID:
    def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
        self.device = device
        self.ip_ckpt = ip_ckpt
        self.lora_rank = lora_rank
        self.num_tokens = num_tokens
        self.torch_dtype = torch_dtype

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            id_embeddings_dim=512,
            num_tokens=self.num_tokens,
        ).to(self.device, dtype=self.torch_dtype)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
                ).to(self.device, dtype=self.torch_dtype)
            else:
                attn_procs[name] = LoRAIPAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
                ).to(self.device, dtype=self.torch_dtype)
        unet.set_attn_processor(attn_procs)

    def load_ip_adapter(self):
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, faceid_embeds):

        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
        print(faceid_embeds.device)
        print(next(self.image_proj_model.parameters()).device)
        image_prompt_embeds = self.image_proj_model(faceid_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, LoRAIPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        faceid_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = faceid_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)

        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterFaceIDPlus:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.lora_rank = lora_rank
        self.num_tokens = num_tokens
        self.torch_dtype = torch_dtype

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=self.torch_dtype
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ProjPlusModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            id_embeddings_dim=512,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
            num_tokens=self.num_tokens,
        ).to(self.device, dtype=self.torch_dtype)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
                ).to(self.device, dtype=self.torch_dtype)
            else:
                attn_procs[name] = LoRAIPAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
                ).to(self.device, dtype=self.torch_dtype)
        unet.set_attn_processor(attn_procs)

    def load_ip_adapter(self):
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
        if isinstance(face_image, Image.Image):
            pil_image = [face_image]
        clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]

        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
        image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, LoRAIPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        face_image=None,
        faceid_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        s_scale=1.0,
        shortcut=False,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = faceid_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)

        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterFaceIDXL(IPAdapterFaceID):
    """SDXL"""

    def generate(
        self,
        faceid_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = faceid_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)

        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
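For readers unfamiliar with the IP-Adapter-FaceID workflow, here is a hedged usage sketch (not from this commit; the base model ID, checkpoint filename, and the insightface setup are illustrative assumptions):

import cv2
import torch
from diffusers import StableDiffusionPipeline
from insightface.app import FaceAnalysis

from ipown import IPAdapterFaceID

device = "cuda"

# 1) Extract a face-ID embedding with insightface (assumed to be installed).
app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
image = cv2.imread("person.jpg")            # illustrative input image
faces = app.get(image)
faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)

# 2) Wrap a Stable Diffusion pipeline with the FaceID adapter.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # illustrative base model
)
ip_model = IPAdapterFaceID(pipe, "ip-adapter-faceid_sd15.bin", device)  # illustrative checkpoint path

# 3) Generate images conditioned on the face embedding.
images = ip_model.generate(
    faceid_embeds=faceid_embeds,
    prompt="photo of a person in a garden",
    num_samples=2,
    num_inference_steps=30,
    seed=42,
)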
loader.py
ADDED
|
@@ -0,0 +1,95 @@
import os
from huggingface_hub import hf_hub_download

def load_script(file_str: str):
    """
    file_str: something like 'myorg/myrepo/mysubfolder/myscript.py'
    This function downloads the file from the Hugging Face Hub into ./ (current directory).
    """
    try:
        # Split the path by "/"
        parts = file_str.split("/")

        if len(parts) < 3:
            raise ValueError(
                f"Invalid file specification '{file_str}'. "
                "Expected format: 'repo_id/[subfolder]/filename'"
            )

        # First two parts form the repo_id (e.g. 'myorg/myrepo')
        repo_id = "/".join(parts[:2])

        # Last part is the actual filename (e.g. 'myscript.py')
        filename = parts[-1]

        # Anything between the second and last parts is a subfolder path
        subfolder = None
        if len(parts) > 3:
            subfolder = "/".join(parts[2:-1])

        # Retrieve HF token from environment
        hf_token = os.getenv("HF_TOKEN", None)

        # Download the file into current directory "."
        file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            token=hf_token,
            local_dir="."  # Download into current directory
        )

        print(f"Downloaded {filename} from {repo_id} to {file_path}")
        return file_path

    except Exception as e:
        print(f"Error downloading the script '{file_str}': {e}")
        return None


def load_scripts():
    """
    1. Get the path of the 'FILE_LIST' file from the environment variable FILE_LIST.
    2. Download that file list using load_script().
    3. Read its lines; each line is another file to be downloaded using load_script().
    4. After all lines are downloaded, execute the last file.
    """
    file_list = os.getenv("FILE_LIST", "").strip()
    if not file_list:
        print("No FILE_LIST environment variable set. Nothing to download.")
        return

    # Step 1: Download the file list itself
    file_list_path = load_script(file_list)
    if not file_list_path or not os.path.exists(file_list_path):
        print(f"Could not download or find file list: {file_list_path}")
        return

    # Step 2: Read each line in the downloaded file list
    try:
        with open(file_list_path, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]
    except Exception as e:
        print(f"Error reading file list: {e}")
        return

    # Step 3: Download each file from the lines
    downloaded_files = []
    for file_str in lines:
        file_path = load_script(file_str)
        if file_path:
            downloaded_files.append(file_path)

    # Step 4: Execute the last downloaded file
    if downloaded_files:
        last_file_path = downloaded_files[-1]
        print(f"Executing the last downloaded script: {last_file_path}")
        try:
            with open(last_file_path, 'r') as f:
                exec(f.read(), globals())
        except Exception as e:
            print(f"Error executing the last downloaded script: {e}")


# Run the load_scripts function
load_scripts()
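For reference, a minimal sketch of how a single FILE_LIST entry is decomposed by load_script() above; the repo, folder, and file names are hypothetical placeholders:

# Illustrative only: mirrors the splitting logic in load_script().
spec = "myorg/myrepo/scripts/run_app.py"
parts = spec.split("/")
repo_id = "/".join(parts[:2])                                  # "myorg/myrepo"
filename = parts[-1]                                           # "run_app.py"
subfolder = "/".join(parts[2:-1]) if len(parts) > 3 else None  # "scripts"
print(repo_id, subfolder, filename)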
requirements.txt
ADDED
@@ -0,0 +1,14 @@
insightface==0.7.3
diffusers
transformers
accelerate
safetensors
einops
onnxruntime-gpu
spaces==0.19.4
opencv-python
pyjwt
torchsde
compel
hidiffusion
git+https://github.com/tencent-ailab/IP-Adapter.git
resampler.py
ADDED
@@ -0,0 +1,158 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # materialize the transposed layout; shape stays (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
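A small shape check for the Resampler above; a minimal sketch assuming resampler.py is importable as written, with arbitrary dimensions:

import torch
from resampler import Resampler

model = Resampler(
    dim=1024, depth=4, dim_head=64, heads=12,
    num_queries=16, embedding_dim=512, output_dim=768,
)
feats = torch.randn(2, 257, 512)   # (batch, sequence length, embedding_dim)
out = model(feats)                 # latents resampled to (batch, num_queries, output_dim)
print(out.shape)                   # torch.Size([2, 16, 768])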
utils.py
ADDED
@@ -0,0 +1,170 @@
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import requests
from datetime import datetime, timedelta
import re

attn_maps = {}

def hook_fn(name):
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook

def register_cross_attention_hook(unet):
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet

def upscale(attn_map, target_size):
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    for i in range(0, 5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map

def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):

    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps

def attnmaps2images(net_attn_maps):

    #total_attn_scores = 0
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        #total_attn_scores += attn_map.mean().item()

        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        #print("norm: ", normalized_attn_map.shape)
        image = Image.fromarray(normalized_attn_map)

        #image = fix_save_attn_map(attn_map)
        images.append(image)

    #print(total_attn_scores)
    return images

def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")


class RemoteJson:
    def __init__(self, url, refresh_gap_seconds=3600, processor=None):
        """
        Initialize the RemoteJson manager.
        :param url: The URL of the remote JSON file.
        :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
        :param processor: Optional callback function to process the JSON after it's loaded successfully.
        """
        self.url = url
        self.refresh_gap_seconds = refresh_gap_seconds
        self.processor = processor
        self.json_data = None
        self.last_updated = None

    def _load_json(self):
        """
        Load JSON from the remote URL. If loading fails, return None.
        """
        try:
            response = requests.get(self.url)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            print(f"Failed to fetch JSON: {e}")
            return None

    def _should_refresh(self):
        """
        Check whether the JSON should be refreshed based on the time gap.
        """
        if not self.last_updated:
            return True  # If no last update, always refresh
        return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)

    def _update_json(self):
        """
        Fetch and load the JSON from the remote URL. If it fails, keep the previous data.
        """
        new_json = self._load_json()
        if new_json:
            self.json_data = new_json
            self.last_updated = datetime.now()
            print("JSON updated successfully.")
            if self.processor:
                self.json_data = self.processor(self.json_data)
        else:
            print("Failed to update JSON. Keeping the previous version.")

    def get(self):
        """
        Get the JSON, checking whether it needs to be refreshed.
        If refresh is required, it fetches the new data and applies the processor.
        """
        if self._should_refresh():
            print("Refreshing JSON...")
            self._update_json()
        else:
            print("Using cached JSON.")

        return self.json_data

def extract_key_value_pairs(input_string):
    # Define the regular expression to match [xxx:yyy] where yyy can have special characters
    pattern = r"\[([^\]]+):([^\]]+)\]"

    # Find all matches in the input string with the original matching string
    matches = re.finditer(pattern, input_string)

    # Convert matches to a list of dictionaries including the raw matching string
    result = [{"key": match.group(1), "value": match.group(2), "raw": match.group(0)} for match in matches]

    return result

def extract_characters(prefix, input_string):
    # Define the regular expression to match placeholders starting with the prefix and ending with a space, comma, or end of string
    pattern = rf"{prefix}([^\s,$]+)(?=\s|,|$)"

    # Find all matches in the input string
    matches = re.findall(pattern, input_string)

    # Return a list of dictionaries with the extracted placeholders
    result = [{"raw": f"{prefix}{match}", "key": match} for match in matches]

    return result
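To illustrate the prompt-parsing helpers above, a small usage sketch; the prompt text is made up, and the expected results follow directly from the regular expressions in utils.py:

from utils import extract_key_value_pairs, extract_characters

prompt = "a photo of @alice and @bob, [style:cinematic] [lora:detail-tweaker]"

print(extract_key_value_pairs(prompt))
# [{'key': 'style', 'value': 'cinematic', 'raw': '[style:cinematic]'},
#  {'key': 'lora', 'value': 'detail-tweaker', 'raw': '[lora:detail-tweaker]'}]

print(extract_characters("@", prompt))
# [{'raw': '@alice', 'key': 'alice'}, {'raw': '@bob', 'key': 'bob'}]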