import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import math


def register_attn_control(unet, controller, cache=None):
    """Patch every `Attention` module in `unet` with a forward that mirrors
    diffusers' AttnProcessor2_0, additionally caching q/k/v and the attention
    output of the self-attention layers selected by `controller`."""

    def attn_forward(self):
        def forward(
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
        ):
            residual = hidden_states
            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim
            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, height * width
                ).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape
                if encoder_hidden_states is None
                else encoder_hidden_states.shape
            )

            if attention_mask is not None:
                attention_mask = self.prepare_attention_mask(
                    attention_mask, sequence_length, batch_size
                )
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(
                    batch_size, self.heads, -1, attention_mask.shape[-1]
                )

            if self.group_norm is not None:
                hidden_states = self.group_norm(
                    hidden_states.transpose(1, 2)
                ).transpose(1, 2)

            q = self.to_q(hidden_states)

            # Self-attention when no encoder hidden states are passed in.
            is_self = encoder_hidden_states is None
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            inner_dim = k.shape[-1]
            head_dim = inner_dim // self.heads

            q = q.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

            # Cache q/k/v and the attention output for the selected self-attention layers.
            if is_self and controller.cur_self_layer in controller.self_layers:
                cache.add(q, k, v, hidden_states)

            hidden_states = hidden_states.transpose(1, 2).reshape(
                batch_size, -1, self.heads * head_dim
            )
            hidden_states = hidden_states.to(q.dtype)

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, height, width
                )

            if self.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / self.rescale_output_factor

            # Advance the per-pass self-attention layer counter.
            if is_self:
                controller.cur_self_layer += 1

            return hidden_states

        return forward

    def modify_forward(net, count):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == "Attention":  # spatial Transformer layer
                net.forward = attn_forward(net)
                return count + 1
            elif hasattr(net, "children"):
                count = modify_forward(subnet, count)
        return count

    cross_att_count = 0
    for net_name, net in unet.named_children():
        cross_att_count += modify_forward(net, 0)
    # Each transformer block holds one self- and one cross-attention layer,
    # so half of the patched Attention modules are self-attention.
    controller.num_self_layers = cross_att_count // 2
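
# --- Usage sketch (illustrative, not part of the original demo) ---------------
# A minimal sketch of how one might wire the controller and cache into a
# diffusers UNet; the pipeline and model id below are assumptions.
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   controller = Controller(self_layers=(0, 16))  # cache self-attention layers 0..15
#   cache = DataCache()
#   register_attn_control(pipe.unet, controller, cache=cache)
#   # Every subsequent UNet forward pass now runs the patched attention forward
#   # and fills `cache` for the selected self-attention layers.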


def load_image(image_path, size=None, mode="RGB"):
    img = Image.open(image_path).convert(mode)
    if size is None:
        # Snap to multiples of 64 so the VAE/UNet downsampling stages divide evenly.
        width, height = img.size
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        size = (new_width, new_height)
    img = img.resize(size, Image.BICUBIC)
    return ToTensor()(img).unsqueeze(0)


def adain(source, target, eps=1e-6):
    # Replace the per-sample statistics of the source (dims 2, 3) with the
    # batch-wide statistics of the target (dims 0, 2, 3).
    source_mean, source_std = (
        torch.mean(source, dim=(2, 3), keepdim=True),
        torch.std(source, dim=(2, 3), keepdim=True),
    )
    target_mean, target_std = (
        torch.mean(target, dim=(0, 2, 3), keepdim=True),
        torch.std(target, dim=(0, 2, 3), keepdim=True),
    )
    normalized_source = (source - source_mean) / (source_std + eps)
    transferred_source = normalized_source * target_std + target_mean
    return transferred_source
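
# A quick, self-contained check of adain's behavior (illustrative only): the
# shapes and values below are made up, but they show that the output adopts
# the target's channel-wise mean/std.
#
#   src = torch.randn(2, 4, 64, 64)          # e.g. latent feature maps
#   tgt = torch.randn(1, 4, 64, 64) * 3 + 5  # different statistics
#   out = adain(src, tgt)
#   # out.mean(dim=(0, 2, 3)) ≈ tgt.mean(dim=(0, 2, 3))
#   # out.std(dim=(0, 2, 3))  ≈ tgt.std(dim=(0, 2, 3))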


class Controller:
    def __init__(self, self_layers=(0, 16)):
        self.num_self_layers = -1  # set by register_attn_control
        self.cur_self_layer = 0    # index of the self-attention layer hit in the current pass
        self.self_layers = list(range(*self_layers))

    def step(self):
        # Reset the layer counter at the start of each denoising step.
        self.cur_self_layer = 0


class DataCache:
    def __init__(self):
        self.q = []
        self.k = []
        self.v = []
        self.out = []

    def clear(self):
        self.q.clear()
        self.k.clear()
        self.v.clear()
        self.out.clear()

    def add(self, q, k, v, out):
        self.q.append(q)
        self.k.append(k)
        self.v.append(v)
        self.out.append(out)

    def get(self):
        return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy()
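
# Sketch of the intended per-step pattern inside a denoising loop (inferred
# from how Controller and DataCache are used above, not copied from the demo):
#
#   for t in scheduler.timesteps:
#       controller.step()   # reset the self-attention layer counter
#       cache.clear()       # drop features cached at the previous step
#       noise_pred = unet(latents, t, encoder_hidden_states=text_emb).sample
#       qs, ks, vs, outs = cache.get()   # features from the selected layers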


def show_image(path, title, display_height=3, title_fontsize=12):
    img = Image.open(path)
    img_width, img_height = img.size
    aspect_ratio = img_width / img_height
    display_width = display_height * aspect_ratio

    plt.figure(figsize=(display_width, display_height))
    plt.imshow(img)
    plt.title(title, fontsize=title_fontsize, fontweight="bold", pad=20)
    plt.axis("off")
    plt.tight_layout()
    plt.show()