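"""Video/image autoencoder with a multi-scale BSQ quantizer.

Defines the convolutional Encoder/Decoder pair and the AutoEncoder wrapper,
which quantizes encoder features with MultiScaleBSQ (from .multiscale_bsq)
unless a continuous VAE bottleneck is requested via args.use_vae.
"""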
import argparse

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from torch import Tensor, nn

from .conv import Conv
from .multiscale_bsq import MultiScaleBSQ
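# Map the encoder_dtype argument (None / 'fp32' / 'bf16') to a torch dtype.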
ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
class Normalize(nn.Module):
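    """GroupNorm / SyncBatchNorm / identity wrapper.

    With norm_axis="spatial", 5D inputs (B C T H W) are folded to (B*T) C H W so
    each frame is normalized independently; "spatial-temporal" normalizes the
    full 5D tensor directly.
    """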
def __init__(self, in_channels, norm_type, norm_axis="spatial"):
super().__init__()
self.norm_axis = norm_axis
assert norm_type in ['group', 'batch', "no"]
if norm_type == 'group':
if in_channels % 32 == 0:
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif in_channels % 24 == 0:
self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True)
else:
raise NotImplementedError
elif norm_type == 'batch':
            self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False)  # track_running_stats=True triggers an in-place-gradient RuntimeError
elif norm_type == 'no':
self.norm = nn.Identity()
def forward(self, x):
if self.norm_axis == "spatial":
if x.ndim == 4:
x = self.norm(x)
else:
B, C, T, H, W = x.shape
x = rearrange(x, "B C T H W -> (B T) C H W")
x = self.norm(x)
x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
elif self.norm_axis == "spatial-temporal":
x = self.norm(x)
else:
raise NotImplementedError
return x
def swish(x: Tensor) -> Tensor:
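    """SiLU activation: x * sigmoid(x)."""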
    try:
        return x * torch.sigmoid(x)
    except RuntimeError:
        # Fallback for CUDA OOM on large activations: compute on CPU in pinned
        # memory, then move the result back to the original device.
        device = x.device
        x = x.cpu().pin_memory()
        return (x * torch.sigmoid(x)).to(device=device)
class AttnBlock(nn.Module):
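    """Single-head self-attention over spatial positions, applied frame by
    frame (the time axis is folded into the batch), with a residual projection.
    """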
def __init__(self, in_channels, norm_type='group', cnn_param=None):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
self.q = Conv(in_channels, in_channels, kernel_size=1)
self.k = Conv(in_channels, in_channels, kernel_size=1)
self.v = Conv(in_channels, in_channels, kernel_size=1)
self.proj_out = Conv(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
B, _, T, _, _ = h_.shape
h_ = self.norm(h_)
h_ = rearrange(h_, "B C T H W -> (B T) C H W") # spatial attention only
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
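    """Pre-activation residual block: two norm-swish-conv stages and a 1x1
    shortcut when channel counts differ. cnn_param["res_conv_2d"] forces conv1
    ("half") or both convs ("full") to stay 2D even for 3D inputs.
    """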
def __init__(self, in_channels: int, out_channels: int, norm_type='group', cnn_param=None):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
if cnn_param["res_conv_2d"] in ["half", "full"]:
self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
else:
self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
self.norm2 = Normalize(out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
if cnn_param["res_conv_2d"] in ["full"]:
self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
else:
self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
if self.in_channels != self.out_channels:
self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
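    """Stride-2 convolutional downsampling with manual asymmetric padding."""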
def __init__(self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False):
super().__init__()
        assert spatial_down
        # Torch convs have no asymmetric padding, so pad manually: one pixel on
        # the right/bottom of the h-axis and w-axis, and nothing along the t-axis.
        if cnn_type == "2d":
            self.pad = (0, 1, 0, 1)
        elif cnn_type == "3d":
            self.pad = (0, 1, 0, 1, 0, 0)
        else:
            raise NotImplementedError(cnn_type)
self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0, cnn_type=cnn_type, temporal_down=temporal_down)
def forward(self, x: Tensor):
x = nn.functional.pad(x, self.pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
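    """Upsampling by nearest-neighbor interpolation (or PixelShuffle when
    use_pxsl is set) followed by a convolution; for 3D inputs the scale factor
    is (2, 2, 2) with temporal_up and (1, 2, 2) without.
    """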
def __init__(self, in_channels, cnn_type="2d", spatial_up=False, temporal_up=False, use_pxsl=False):
super().__init__()
if cnn_type == "2d":
self.scale_factor = 2
self.causal_offset = 0
else:
            assert spatial_up
if temporal_up:
self.scale_factor = (2,2,2)
self.causal_offset = -1
else:
self.scale_factor = (1,2,2)
self.causal_offset = 0
self.use_pxsl = use_pxsl
if self.use_pxsl:
self.conv = Conv(in_channels, in_channels*4, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
self.pxsl = nn.PixelShuffle(2)
else:
self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
def forward(self, x: Tensor):
if self.use_pxsl:
x = self.conv(x)
x = self.pxsl(x)
else:
            try:
                x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
            except RuntimeError:
                # Interpolating the full tensor can exhaust GPU memory; retry
                # one channel at a time and concatenate the shards.
                _xs = []
                for i in range(x.shape[1]):
                    _x = F.interpolate(x[:, i:i+1, ...], scale_factor=self.scale_factor, mode="nearest")
                    _xs.append(_x)
                x = torch.cat(_xs, dim=1)
x = self.conv(x)
return x
class Encoder(nn.Module):
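    """Convolutional encoder for images (B C H W) or videos (B C T H W).

    Downsamples spatially by patch_size and temporally by temporal_patch_size;
    the output has 2 * z_channels channels when use_vae=True (mean/logvar
    split) and z_channels otherwise.
    """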
def __init__(
self,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
in_channels = 3,
patch_size=8, temporal_patch_size=4,
norm_type='group', cnn_param=None,
use_checkpoint=False,
use_vae=True,
):
super().__init__()
self.max_down = np.log2(patch_size)
self.temporal_max_down = np.log2(temporal_patch_size)
self.temporal_down_offset = self.max_down - self.temporal_max_down
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.cnn_param = cnn_param
self.use_checkpoint = use_checkpoint
# downsampling
# cnn_param["cnn_type"] = "2d" for images, cnn_param["cnn_type"] = "3d" for videos
if cnn_param["conv_in_out_2d"] == "yes": # "yes" for video
self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
else:
self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
            # Downsample schedule: spatial stride 2 below max_down; temporal stride 2
            # only from temporal_down_offset on (strides 1,2,2 for a 4x8x8 video VAE)
            spatial_down = i_level < self.max_down
            temporal_down = self.temporal_down_offset <= i_level < self.max_down
if spatial_down or temporal_down:
down.downsample = Downsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_down=spatial_down, temporal_down=temporal_down)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
if cnn_param["cnn_attention"] == "yes":
self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
# end
self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
if cnn_param["conv_inner_2d"] == "yes":
self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
else:
self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
def forward(self, x, return_hidden=False):
if not self.use_checkpoint:
return self._forward(x, return_hidden=return_hidden)
else:
return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False)
def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
# downsampling
h0 = self.conv_in(x)
hs = [h0]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if hasattr(self.down[i_level], "downsample"):
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
hs_mid = [h]
h = self.mid.block_1(h)
if self.cnn_param["cnn_attention"] == "yes":
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
hs_mid.append(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
if return_hidden:
return h, hs, hs_mid
else:
return h
class Decoder(nn.Module):
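    """Convolutional decoder mirroring Encoder: upsamples latents spatially by
    patch_size and temporally by temporal_patch_size back to out_ch channels.
    """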
def __init__(
self,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
out_ch = 3,
patch_size=8, temporal_patch_size=4,
norm_type="group", cnn_param=None,
use_checkpoint=False,
use_freq_dec=False, # use frequency features for decoder
use_pxsf=False
):
super().__init__()
self.max_up = np.log2(patch_size)
self.temporal_max_up = np.log2(temporal_patch_size)
self.temporal_up_offset = self.max_up - self.temporal_max_up
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.ffactor = 2 ** (self.num_resolutions - 1)
self.cnn_param = cnn_param
self.use_checkpoint = use_checkpoint
self.use_freq_dec = use_freq_dec
self.use_pxsf = use_pxsf
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
# z to block_in
if cnn_param["conv_inner_2d"] == "yes":
self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d")
else:
self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
if cnn_param["cnn_attention"] == "yes":
self.mid.attn_1 = AttnBlock(block_in, norm_type=norm_type, cnn_param=cnn_param)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
            # Upsample schedule: strides 1,2,2 for a 4x8x8 video VAE, offset by 1
            # compared with the encoder; cf.
            # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
            spatial_up = 1 <= i_level <= self.max_up
            temporal_up = 1 <= i_level <= self.max_up and i_level >= self.temporal_up_offset + 1
if spatial_up or temporal_up:
up.upsample = Upsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_up=spatial_up, temporal_up=temporal_up, use_pxsl=self.use_pxsf)
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
if cnn_param["conv_in_out_2d"] == "yes":
self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
else:
self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
def forward(self, z):
if not self.use_checkpoint:
return self._forward(z)
else:
return checkpoint.checkpoint(self._forward, z, use_reentrant=False)
def _forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
if self.cnn_param["cnn_attention"] == "yes":
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if hasattr(self.up[i_level], "upsample"):
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class AutoEncoder(nn.Module):
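    """Encoder -> (multi-scale BSQ) quantizer -> Decoder.

    When args.use_vae is False, encoder features are quantized with
    MultiScaleBSQ; the loss weights stored on this module are read by the
    training code rather than used here.
    """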
def __init__(self, args):
super().__init__()
self.args = args
cnn_param = dict(
cnn_type=args.cnn_type,
conv_in_out_2d=args.conv_in_out_2d,
res_conv_2d=args.res_conv_2d,
cnn_attention=args.cnn_attention,
cnn_norm_axis=args.cnn_norm_axis,
conv_inner_2d=args.conv_inner_2d,
)
self.encoder = Encoder(
ch=args.base_ch,
ch_mult=args.encoder_ch_mult,
num_res_blocks=args.num_res_blocks,
z_channels=args.codebook_dim,
patch_size=args.patch_size,
temporal_patch_size=args.temporal_patch_size,
cnn_param=cnn_param,
use_checkpoint=args.use_checkpoint,
use_vae=args.use_vae,
)
self.decoder = Decoder(
ch=args.base_ch,
ch_mult=args.decoder_ch_mult,
num_res_blocks=args.num_res_blocks,
z_channels=args.codebook_dim,
patch_size=args.patch_size,
temporal_patch_size=args.temporal_patch_size,
cnn_param=cnn_param,
use_checkpoint=args.use_checkpoint,
use_freq_dec=args.use_freq_dec,
use_pxsf=args.use_pxsf # pixelshuffle for upsampling
)
self.z_drop = nn.Dropout(args.z_drop)
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.codebook_dim = self.embed_dim = args.codebook_dim
self.gan_feat_weight = args.gan_feat_weight
self.video_perceptual_weight = args.video_perceptual_weight
self.recon_loss_type = args.recon_loss_type
self.l1_weight = args.l1_weight
self.use_vae = args.use_vae
self.kl_weight = args.kl_weight
self.lfq_weight = args.lfq_weight
self.image_gan_weight = args.image_gan_weight # image GAN loss weight
self.video_gan_weight = args.video_gan_weight # video GAN loss weight
self.perceptual_weight = args.perceptual_weight
self.flux_weight = args.flux_weight
self.cycle_weight = args.cycle_weight
self.cycle_feat_weight = args.cycle_feat_weight
self.cycle_gan_weight = args.cycle_gan_weight
self.flux_image_encoder = None
if not args.use_vae:
if args.quantizer_type == 'MultiScaleBSQ':
self.quantizer = MultiScaleBSQ(
dim = args.codebook_dim, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
codebook_size = args.codebook_size, # codebook size, must be a power of 2
entropy_loss_weight = args.entropy_loss_weight, # how much weight to place on entropy loss
diversity_gamma = args.diversity_gamma, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
preserve_norm=args.preserve_norm, # preserve norm of the input for BSQ
ln_before_quant=args.ln_before_quant, # use layer norm before quantization
ln_init_by_sqrt=args.ln_init_by_sqrt, # layer norm init value 1/sqrt(d)
commitment_loss_weight=args.commitment_loss_weight, # loss weight of commitment loss
new_quant=args.new_quant,
use_decay_factor=args.use_decay_factor,
mask_out=args.mask_out,
use_stochastic_depth=args.use_stochastic_depth,
drop_rate=args.drop_rate,
schedule_mode=args.schedule_mode,
keep_first_quant=args.keep_first_quant,
keep_last_quant=args.keep_last_quant,
remove_residual_detach=args.remove_residual_detach,
use_out_phi=args.use_out_phi,
use_out_phi_res=args.use_out_phi_res,
random_flip = args.random_flip,
flip_prob = args.flip_prob,
flip_mode = args.flip_mode,
max_flip_lvl = args.max_flip_lvl,
random_flip_1lvl = args.random_flip_1lvl,
flip_lvl_idx = args.flip_lvl_idx,
drop_when_test = args.drop_when_test,
drop_lvl_idx = args.drop_lvl_idx,
drop_lvl_num = args.drop_lvl_num,
)
self.quantize = self.quantizer
self.vocab_size = args.codebook_size
else:
raise NotImplementedError(f"{args.quantizer_type} not supported")
def forward(self, x):
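        """Encode x (image or video), quantize, and decode; returns
        (x_recon, vq_output), where vq_output carries the weighted commitment
        loss and the code indices.
        """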
is_image = x.ndim == 4
if not is_image:
B, C, T, H, W = x.shape
else:
B, C, H, W = x.shape
T = 1
enc_dtype = ptdtype[self.args.encoder_dtype]
with torch.amp.autocast("cuda", dtype=enc_dtype):
h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
hs = [_h.detach() for _h in hs]
hs_mid = [_h.detach() for _h in hs_mid]
h = h.to(dtype=torch.float32)
# Multiscale LFQ
z, all_indices, _, _, all_loss, _ = self.quantizer(h)
x_recon = self.decoder(z)
vq_output = {
"commitment_loss": torch.mean(all_loss) * self.lfq_weight, # here commitment loss is sum of commitment loss and entropy penalty
"encodings": all_indices,
}
return x_recon, vq_output
def encode_for_raw_features(self, x, scale_schedule, return_residual_norm_per_scale=False):
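        """Run the encoder under the configured autocast dtype and return the
        float32 features plus detached hidden states (scale_schedule and
        return_residual_norm_per_scale are unused here).
        """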
is_image = x.ndim == 4
if not is_image:
B, C, T, H, W = x.shape
else:
B, C, H, W = x.shape
T = 1
enc_dtype = ptdtype[self.args.encoder_dtype]
with torch.amp.autocast("cuda", dtype=enc_dtype):
h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
hs = [_h.detach() for _h in hs]
hs_mid = [_h.detach() for _h in hs_mid]
h = h.to(dtype=torch.float32)
return h, hs, hs_mid
def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
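        """Encode x and quantize into multi-scale (bit) indices following
        scale_schedule.
        """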
h, hs, hs_mid = self.encode_for_raw_features(x, scale_schedule, return_residual_norm_per_scale)
# Multiscale LFQ
z, all_indices, all_bit_indices, residual_norm_per_scale, all_loss, var_input = self.quantizer(h, scale_schedule=scale_schedule, return_residual_norm_per_scale=return_residual_norm_per_scale)
return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input
def decode(self, z):
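        """Decode latents and clamp the reconstruction to [-1, 1]."""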
x_recon = self.decoder(z)
x_recon = torch.clamp(x_recon, min=-1, max=1)
return x_recon
def decode_from_indices(self, all_indices, scale_schedule, label_type):
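        """Rebuild latents from per-scale indices: convert each scale's indices
        to codes, upsample all codes to the final scale, sum them, and decode.
        """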
summed_codes = 0
for idx_Bl in all_indices:
codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
summed_codes += F.interpolate(codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up)
assert summed_codes.shape[-3] == 1
x_recon = self.decoder(summed_codes.squeeze(-3))
x_recon = torch.clamp(x_recon, min=-1, max=1)
return summed_codes, x_recon
@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--flux_weight", type=float, default=0)
parser.add_argument("--cycle_weight", type=float, default=0)
parser.add_argument("--cycle_feat_weight", type=float, default=0)
parser.add_argument("--cycle_gan_weight", type=float, default=0)
parser.add_argument("--cycle_loop", type=int, default=0)
parser.add_argument("--z_drop", type=float, default=0.)
return parser