# OFT network module

import os
from typing import List, Optional, Type, Union

from diffusers import AutoencoderKL
import einops
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class OFTModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear module instead of replacing the module itself.
    """

    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
        split_dims: Optional[List[int]] = None,
    ):
        """
        dim -> number of diagonal blocks
        alpha -> constraint on the rotation
        split_dims is used to mimic the split qkv of FLUX, matching Diffusers
        """
        super().__init__()
        self.oft_name = oft_name
        self.num_blocks = dim

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().numpy()
        self.register_buffer("alpha", torch.tensor(alpha))

        # No conv2d in FLUX
        # if "Linear" in org_module.__class__.__name__:
        self.out_dim = org_module.out_features
        # elif "Conv" in org_module.__class__.__name__:
        #     out_dim = org_module.out_channels

        if split_dims is None:
            split_dims = [self.out_dim]
        else:
            assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim"
        self.split_dims = split_dims

        # assert all dims are divisible by num_blocks
        for split_dim in self.split_dims:
            assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks"

        self.constraint = [alpha * split_dim for split_dim in self.split_dims]
        self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims]
        self.oft_blocks = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size]
        )
        self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size]

        self.shape = org_module.weight.shape
        self.multiplier = multiplier
        self.org_module = [org_module]  # wrap in a list so it is not registered as a submodule
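
    # With the default dim (num_blocks) = 4, a FLUX double-block qkv projection
    # (out_features = 9216, split_dims = [3072, 3072, 3072]) gets, per split, 4 trainable
    # blocks of size 768x768 (3072 // 4), initialized to zero (i.e. the identity rotation).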

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        if self.I[0].device != self.oft_blocks[0].device:
            self.I = [I.to(self.oft_blocks[0].device) for I in self.I]

        block_R_weighted_list = []
        for i in range(len(self.oft_blocks)):
            # make Q skew-symmetric, then clamp its norm to the constraint
            block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2)
            norm_Q = torch.norm(block_Q.flatten())
            new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i])
            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))

            I = self.I[i]
            block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())

            # interpolate between the identity and the full rotation with the multiplier
            block_R_weighted = multiplier * (block_R - I) + I
            block_R_weighted_list.append(block_R_weighted)

        return block_R_weighted_list
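
    # The blocks are parametrized via the Cayley transform: for skew-symmetric Q (Q^T = -Q),
    # R = (I + Q)(I - Q)^-1 is orthogonal, so each block applies a pure rotation to the frozen
    # weight. A minimal sanity check of this identity (a sketch, not part of the training code):
    #
    #   q = torch.randn(4, 4)
    #   Q = q - q.T  # skew-symmetric
    #   R = (torch.eye(4) + Q) @ torch.linalg.inv(torch.eye(4) - Q)
    #   assert torch.allclose(R.T @ R, torch.eye(4), atol=1e-5)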

    def forward(self, x, scale=None):
        if self.multiplier == 0.0:
            return self.org_forward(x)

        org_module = self.org_module[0]
        org_dtype = x.dtype

        R = self.get_weight()
        W = org_module.weight.to(torch.float32)
        B = org_module.bias.to(torch.float32)

        # split W to match R
        results = []
        d2 = 0
        for i in range(len(R)):
            d1 = d2
            d2 += self.split_dims[i]

            W1 = W[d1:d2]
            W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
            RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
            RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p")

            B1 = B[d1:d2]
            result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype))
            results.append(result)

        result = torch.cat(results, dim=-1)
        return result
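

# Note: at initialization oft_blocks are all zeros, so every R above is the identity and the
# patched forward reproduces the original Linear exactly; training then moves the weights only
# along the orthogonal rotations produced by the Cayley transform.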


class OFTInfModule(OFTModule):
    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
        split_dims: Optional[List[int]] = None,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims)
        self.enabled = True
        self.network: Optional["OFTNetwork"] = None

    def set_network(self, network):
        self.network = network

    def forward(self, x, scale=None):
        if not self.enabled:
            return self.org_forward(x)
        return super().forward(x, scale)

    def merge_to(self, multiplier=None):
        # get org weight
        org_sd = self.org_module[0].state_dict()
        W = org_sd["weight"].to(torch.float32)

        R = [r.to(torch.float32) for r in self.get_weight(multiplier)]  # get_weight returns a list

        d2 = 0
        W_list = []
        for i in range(len(self.oft_blocks)):
            d1 = d2
            d2 += self.split_dims[i]

            W1 = W[d1:d2]
            W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
            W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
            W1 = einops.rearrange(W1, "k m p -> (k m) p")
            W_list.append(W1)

        W = torch.cat(W_list, dim=0)  # the splits lie along the output dimension

        # convert back to original dtype
        W = W.to(org_sd["weight"].dtype)

        # set weight to org_module
        org_sd["weight"] = W
        self.org_module[0].load_state_dict(org_sd)
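
    # Merging is a one-way, in-place rewrite of the wrapped Linear's weight: the blockwise-rotated
    # weight replaces the original, after which the module runs at full speed without the patched
    # forward. OFTNetwork.backup_weights / restore_weights below exist to undo it.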


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: AutoencoderKL,
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:  # should be set explicitly
        logger.info("network_alpha is not set, using the default value 1e-3")
        network_alpha = 1e-3
    elif network_alpha >= 1:
        logger.warning(
            "network_alpha seems too large (>= 1, maybe the default value is too large); consider setting a smaller value such as 1e-3"
        )

    # attn only, or all linear (FFN) layers
    enable_all_linear = kwargs.get("enable_all_linear", None)
    # enable_conv = kwargs.get("enable_conv", None)
    if enable_all_linear is not None:
        enable_all_linear = bool(enable_all_linear)
    # if enable_conv is not None:
    #     enable_conv = bool(enable_conv)

    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=network_dim,
        alpha=network_alpha,
        enable_all_linear=enable_all_linear,
        verbose=True,
    )
    return network
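

# A minimal usage sketch (hypothetical variable names; the trainer normally calls create_network
# with the loaded FLUX model and passes extra options via --network_args):
#
#   network = create_network(1.0, network_dim=4, network_alpha=1e-3, vae=vae,
#                            text_encoder=text_encoders, unet=flux_model, enable_all_linear=True)
#   network.apply_to(text_encoders, flux_model, apply_text_encoder=False, apply_unet=True)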


# Create network from weights for inference. Weights are not loaded here (because they may be merged).
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # infer dim and alpha from the weights, and whether all linear layers are targeted
    dim = None
    alpha = None
    all_linear = None
    for name, param in weights_sd.items():
        if name.endswith(".alpha"):
            if alpha is None:
                alpha = param.item()
        elif "qkv" in name:
            continue  # ignore qkv
        else:
            if dim is None:
                dim = param.size()[0]
            if all_linear is None and "_mlp" in name:
                all_linear = True
        if dim is not None and alpha is not None and all_linear is not None:
            break
    if all_linear is None:
        all_linear = False

    module_class = OFTInfModule if for_inference else OFTModule
    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=dim,
        alpha=alpha,
        enable_all_linear=all_linear,
        module_class=module_class,
    )
    return network, weights_sd
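

# Note on the detection heuristic above: dim is read from the leading dimension of a stored
# oft_blocks tensor (the number of blocks), alpha from any ".alpha" buffer, and the presence of
# an "_mlp" key implies the weights were trained with enable_all_linear=True.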


class OFTNetwork(torch.nn.Module):
    FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"]
    FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"]
    OFT_PREFIX_UNET = "oft_unet"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        dim: int = 4,
        alpha: float = 1,
        enable_all_linear: Optional[bool] = False,
        module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.train_t5xxl = False  # for compatibility with LoRA
        self.multiplier = multiplier
        self.dim = dim
        self.alpha = alpha

        logger.info(
            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}"
        )

        # create module instances
        def create_modules(
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ) -> List[OFTModule]:
            prefix = self.OFT_PREFIX_UNET
            ofts = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = "Linear" in child_module.__class__.__name__
                        if is_linear:
                            oft_name = prefix + "." + name + "." + child_name
                            oft_name = oft_name.replace(".", "_")
                            # logger.info(oft_name)

                            if "double" in oft_name and "qkv" in oft_name:
                                split_dims = [3072] * 3
                            elif "single" in oft_name and "linear1" in oft_name:
                                split_dims = [3072] * 3 + [12288]
                            else:
                                split_dims = None

                            oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims)
                            ofts.append(oft)
            return ofts

        # choose target modules: attention only, or all linear layers
        if enable_all_linear:
            target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR
        else:
            target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY

        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
        logger.info(f"create OFT for FLUX: {len(self.unet_ofts)} modules.")

        # assert there are no duplicated module names
        names = set()
        for oft in self.unet_ofts:
            assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
            names.add(oft.oft_name)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for oft in self.unet_ofts:
            oft.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        assert apply_unet, "apply_unet must be True"

        for oft in self.unet_ofts:
            oft.apply_to()
            self.add_module(oft.oft_name, oft)

    # returns whether the weights can be merged into the base model
    def is_mergeable(self):
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        logger.info("enable OFT for U-Net")

        for oft in self.unet_ofts:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(oft.oft_name):
                    sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
            oft.load_state_dict(sd_for_lora, False)
            oft.merge_to()

        logger.info("weights are merged")

    # TODO: it might be good to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(ofts):
            params = []
            for oft in ofts:
                params.extend(oft.parameters())

            # log the number of trainable params
            num_params = 0
            for p in params:
                num_params += p.numel()
            logger.info(f"OFT params: {num_params}")
            return params

        param_data = {"params": enumerate_params(self.unet_ofts)}
        if unet_lr is not None:
            param_data["lr"] = unet_lr
        all_params.append(param_data)

        return all_params
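
    # A minimal sketch of how a trainer might consume this (hypothetical learning rates):
    #
    #   params = network.prepare_optimizer_params(None, unet_lr=1e-4, default_lr=1e-4)
    #   optimizer = torch.optim.AdamW(params)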

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        # back up the original weights before merging
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                org_module._lora_restored = True

    def restore_weights(self):
        # restore the backed-up original weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        # merge the OFT rotations into the original weights ahead of time
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            oft.merge_to()
            # sd = org_module.state_dict()
            # org_weight = sd["weight"]
            # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            # sd["weight"] = org_weight + lora_weight
            # assert sd["weight"].shape == org_weight.shape
            # org_module.load_state_dict(sd)
            org_module._lora_restored = False
            oft.enabled = False
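

if __name__ == "__main__":
    # A minimal self-check (a sketch, not part of the training pipeline): wrap a plain Linear
    # whose out_features are split like a FLUX qkv projection, and verify that the freshly
    # initialized OFT (all-zero blocks, i.e. identity rotations) reproduces the original output.
    linear = torch.nn.Linear(64, 96)
    oft = OFTModule("oft_unet_test_qkv", linear, multiplier=1.0, dim=4, alpha=1e-3, split_dims=[32, 32, 32])
    oft.apply_to()  # patches linear.forward

    x = torch.randn(2, 64)
    y_wrapped = linear(x)  # goes through OFTModule.forward
    y_original = oft.org_forward(x)
    assert torch.allclose(y_wrapped, y_original, atol=1e-5), "identity OFT should match the original Linear"
    logger.info("OFT identity self-check passed")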