working.

Files changed:
- xora/examples/image_to_video.py +87 -0
- xora/models/autoencoders/causal_video_autoencoder.py +3 -1
- xora/models/autoencoders/vae_encode.py +11 -41
- xora/models/autoencoders/video_autoencoder.py +912 -0
- xora/models/transformers/embeddings.py +125 -0
- xora/models/transformers/transformer3d.py +77 -4
- xora/pipelines/pipeline_video_pixart_alpha.py +181 -13
- xora/schedulers/rf.py +13 -4
- xora/utils/conditioning_method.py +7 -0
- xora/utils/dist_util.py +11 -0
xora/examples/image_to_video.py
ADDED
@@ -0,0 +1,87 @@
import torch
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
from pathlib import Path
from transformers import T5EncoderModel


model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
dtype = torch.float32
vae = CausalVideoAutoencoder.from_pretrained(
    pretrained_model_name_or_path=vae_local_path,
    revision=False,
    torch_dtype=torch.bfloat16,
    load_in_8bit=False,
).cuda()
transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
transformer_config = Transformer3DModel.load_config(transformer_config_path)
transformer = Transformer3DModel.from_config(transformer_config)
transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-first-frame-cond-4k-seq/ckpt/01822000/model.pt")
transformer_ckpt_state_dict = torch.load(transformer_local_path)
transformer.load_state_dict(transformer_ckpt_state_dict, True)
transformer = transformer.cuda()
unet = transformer
scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
patchifier = SymmetricPatchifier(patch_size=1)
# text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")

submodel_dict = {
    "unet": unet,
    "transformer": transformer,
    "patchifier": patchifier,
    "text_encoder": None,
    "scheduler": scheduler,
    "vae": vae,
}

pipeline = VideoPixArtAlphaPipeline.from_pretrained(
    model_name_or_path,
    safety_checker=None,
    revision=None,
    torch_dtype=dtype,
    **submodel_dict,
)

num_inference_steps = 20
num_images_per_prompt = 2
guidance_scale = 3
height = 512
width = 768
num_frames = 57
frame_rate = 25
# sample = {
#     "prompt": "A cat",  # (B, L, E)
#     "prompt_attention_mask": None,  # (B, L)
#     "negative_prompt": "Ugly deformed",
#     "negative_prompt_attention_mask": None,  # (B, L)
# }

sample = torch.load("/opt/sample.pt")
for _, item in sample.items():
    if item is not None:
        item = item.cuda()
media_items = torch.load("/opt/sample_media.pt")

images = pipeline(
    num_inference_steps=num_inference_steps,
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=guidance_scale,
    generator=None,
    output_type="pt",
    callback_on_step_end=None,
    height=height,
    width=width,
    num_frames=num_frames,
    frame_rate=frame_rate,
    **sample,
    is_video=True,
    vae_per_channel_normalize=True,
).images

print()
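A note on the example above: the `for _, item in sample.items(): item = item.cuda()` loop only rebinds the local name `item`, so the loaded tensors stay where they were. The sketch below is an editor's illustration, not part of the commit: it builds the inputs by hand following the commented-out `sample` dict in the script, and it assumes (from that comment only, not from the pipeline code) that the pipeline accepts a text `prompt` plus optional attention masks.

# Hypothetical input preparation, mirroring the commented-out `sample` dict in the example script.
sample = {
    "prompt": "A cat",                       # (B, L, E) after text encoding
    "prompt_attention_mask": None,           # (B, L)
    "negative_prompt": "Ugly deformed",
    "negative_prompt_attention_mask": None,  # (B, L)
}

# If the values are pre-computed tensors (as in /opt/sample.pt), move them to the GPU by
# assigning back into the dict; rebinding the loop variable alone has no effect.
sample = {k: (v.cuda() if isinstance(v, torch.Tensor) else v) for k, v in sample.items()}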
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -8,11 +8,13 @@ import torch
 import numpy as np
 from einops import rearrange
 from torch import nn
+from diffusers.utils import logging

 from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
 from xora.models.autoencoders.pixel_norm import PixelNorm
 from xora.models.autoencoders.vae import AutoencoderKLWrapper

+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 class CausalVideoAutoencoder(AutoencoderKLWrapper):
     @classmethod
@@ -138,7 +140,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
                 key = key.replace(k, v)

             if "norm" in key and key not in model_keys:
-
+                logger.info(f"Removing key {key} from state_dict as it is not present in the model")
                 continue

             converted_state_dict[key] = value
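The new `logger.info` call goes through diffusers' logging utility, so the "Removing key ..." messages only show up at INFO verbosity. A small, hedged sketch of how they could be surfaced when loading a checkpoint (the checkpoint path is hypothetical):

from diffusers.utils import logging as diffusers_logging
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder

# Raise diffusers' log level so the new "Removing key ..." messages become visible.
diffusers_logging.set_verbosity_info()

vae = CausalVideoAutoencoder.from_pretrained("/path/to/causal_vae_checkpoint")  # hypothetical path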
xora/models/autoencoders/vae_encode.py
CHANGED
@@ -1,44 +1,12 @@
 import torch
-from torch import nn
 from diffusers import AutoencoderKL
 from einops import rearrange
 from torch import Tensor
-from torch.nn import functional


 from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
-
-
-    def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
-        super().__init__()
-        stride: int = 2
-        self.padding = padding
-        self.in_channels = in_channels
-        self.dims = dims
-        self.conv = make_conv_nd(
-            dims=dims,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-        )
-
-    def forward(self, x, downsample_in_time=True):
-        conv = self.conv
-        if self.padding == 0:
-            if self.dims == 2:
-                padding = (0, 1, 0, 1)
-            else:
-                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
-
-            x = functional.pad(x, padding, mode="constant", value=0)
-
-        if self.dims == (2, 1) and not downsample_in_time:
-            return conv(x, skip_time_conv=True)
-
-        return conv(x)
-
+from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
+import xora.utils.dist_util


 def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
@@ -78,7 +46,7 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
     if channels != 3:
         raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
     if split_size > 1:
         if len(media_items) % split_size != 0:
@@ -86,14 +54,16 @@
         encode_bs = len(media_items) // split_size
         # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
         latents = []
+        dist_util.execute_graph()
         for image_batch in media_items.split(encode_bs):
             latents.append(vae.encode(image_batch).latent_dist.sample())
+        dist_util.execute_graph()
         latents = torch.cat(latents, dim=0)
     else:
         latents = vae.encode(media_items).latent_dist.sample()

     latents = normalize_latents(latents, vae, vae_per_channel_normalize)
-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
     return latents

@@ -104,7 +74,7 @@ def vae_decode(
     is_video_shaped = latents.dim() == 5
     batch_size = latents.shape[0]

-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         latents = rearrange(latents, "b c n h w -> (b n) c h w")
     if split_size > 1:
         if len(latents) % split_size != 0:
@@ -118,13 +88,13 @@
     else:
         images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)

-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
     return images


 def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
-    if isinstance(vae, (CausalVideoAutoencoder)):
+    if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         *_, fl, hl, wl = latents.shape
         temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
         latents = latents.to(vae.dtype)
@@ -148,7 +118,7 @@ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
     else:
         down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
         spatial = vae.config.patch_size * 2**down_blocks
-        temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae) else 1
+        temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae, VideoAutoencoder) else 1

     return (temporal, spatial, spatial)

@@ -168,4 +138,4 @@ def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_no
         + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
         if vae_per_channel_normalize
         else latents / vae.config.scaling_factor
-    )
+    )
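The widened `isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder))` guards decide whether a 5D video batch is flattened into a plain image batch before encoding. The following editor's sketch only illustrates that flattening path (what `vae_encode` does when the VAE is an ordinary 2D `AutoencoderKL`), reusing the same einops patterns as the diff; the shapes and the stand-in "latents" are purely illustrative.

import torch
from einops import rearrange

b, c, n, h, w = 1, 3, 8, 64, 64           # batch, channels, frames, height, width
media_items = torch.randn(b, c, n, h, w)

# For a 2D image VAE, frames are folded into the batch dimension before encoding ...
flat = rearrange(media_items, "b c n h w -> (b n) c h w")

# ... encoded per frame (vae.encode(flat).latent_dist.sample() in the real code) ...
latents = flat[:, :1] * 0.0               # stand-in for per-frame latents, shape (b*n, c', h', w')

# ... and unfolded back into a video-shaped latent afterwards.
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=b)
print(latents.shape)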
xora/models/autoencoders/video_autoencoder.py
ADDED
@@ -0,0 +1,912 @@
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
from typing import Any, Mapping, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional
|
| 11 |
+
|
| 12 |
+
from diffusers.utils import logging
|
| 13 |
+
|
| 14 |
+
from txt2img.models.layers.nn import Identity
|
| 15 |
+
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
| 16 |
+
from xora.models.autoencoders.pixel_norm import PixelNorm
|
| 17 |
+
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VideoAutoencoder(AutoencoderKLWrapper):
|
| 23 |
+
@classmethod
|
| 24 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
|
| 25 |
+
config_local_path = pretrained_model_name_or_path / "config.json"
|
| 26 |
+
config = cls.load_config(config_local_path, **kwargs)
|
| 27 |
+
video_vae = cls.from_config(config)
|
| 28 |
+
video_vae.to(kwargs["torch_dtype"])
|
| 29 |
+
|
| 30 |
+
model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
|
| 31 |
+
ckpt_state_dict = torch.load(model_local_path)
|
| 32 |
+
video_vae.load_state_dict(ckpt_state_dict)
|
| 33 |
+
|
| 34 |
+
statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json"
|
| 35 |
+
if statistics_local_path.exists():
|
| 36 |
+
with open(statistics_local_path, "r") as file:
|
| 37 |
+
data = json.load(file)
|
| 38 |
+
transposed_data = list(zip(*data["data"]))
|
| 39 |
+
data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)}
|
| 40 |
+
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
| 41 |
+
video_vae.register_buffer(
|
| 42 |
+
"mean_of_means", data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"]))
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return video_vae
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def from_config(config):
|
| 49 |
+
assert config["_class_name"] == "VideoAutoencoder", "config must have _class_name=VideoAutoencoder"
|
| 50 |
+
if isinstance(config["dims"], list):
|
| 51 |
+
config["dims"] = tuple(config["dims"])
|
| 52 |
+
|
| 53 |
+
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
| 54 |
+
|
| 55 |
+
double_z = config.get("double_z", True)
|
| 56 |
+
latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none")
|
| 57 |
+
use_quant_conv = config.get("use_quant_conv", True)
|
| 58 |
+
|
| 59 |
+
if use_quant_conv and latent_log_var == "uniform":
|
| 60 |
+
raise ValueError("uniform latent_log_var requires use_quant_conv=False")
|
| 61 |
+
|
| 62 |
+
encoder = Encoder(
|
| 63 |
+
dims=config["dims"],
|
| 64 |
+
in_channels=config.get("in_channels", 3),
|
| 65 |
+
out_channels=config["latent_channels"],
|
| 66 |
+
block_out_channels=config["block_out_channels"],
|
| 67 |
+
patch_size=config.get("patch_size", 1),
|
| 68 |
+
latent_log_var=latent_log_var,
|
| 69 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
| 70 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
| 71 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
decoder = Decoder(
|
| 75 |
+
dims=config["dims"],
|
| 76 |
+
in_channels=config["latent_channels"],
|
| 77 |
+
out_channels=config.get("out_channels", 3),
|
| 78 |
+
block_out_channels=config["block_out_channels"],
|
| 79 |
+
patch_size=config.get("patch_size", 1),
|
| 80 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
| 81 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
| 82 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
dims = config["dims"]
|
| 86 |
+
return VideoAutoencoder(
|
| 87 |
+
encoder=encoder,
|
| 88 |
+
decoder=decoder,
|
| 89 |
+
latent_channels=config["latent_channels"],
|
| 90 |
+
dims=dims,
|
| 91 |
+
use_quant_conv=use_quant_conv,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def config(self):
|
| 96 |
+
return SimpleNamespace(
|
| 97 |
+
_class_name="VideoAutoencoder",
|
| 98 |
+
dims=self.dims,
|
| 99 |
+
in_channels=self.encoder.conv_in.in_channels // (self.encoder.patch_size_t * self.encoder.patch_size**2),
|
| 100 |
+
out_channels=self.decoder.conv_out.out_channels // (self.decoder.patch_size_t * self.decoder.patch_size**2),
|
| 101 |
+
latent_channels=self.decoder.conv_in.in_channels,
|
| 102 |
+
block_out_channels=[
|
| 103 |
+
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
| 104 |
+
for i in range(len(self.encoder.down_blocks))
|
| 105 |
+
],
|
| 106 |
+
scaling_factor=1.0,
|
| 107 |
+
norm_layer=self.encoder.norm_layer,
|
| 108 |
+
patch_size=self.encoder.patch_size,
|
| 109 |
+
latent_log_var=self.encoder.latent_log_var,
|
| 110 |
+
use_quant_conv=self.use_quant_conv,
|
| 111 |
+
patch_size_t=self.encoder.patch_size_t,
|
| 112 |
+
add_channel_padding=self.encoder.add_channel_padding,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def is_video_supported(self):
|
| 117 |
+
"""
|
| 118 |
+
Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
|
| 119 |
+
"""
|
| 120 |
+
return self.dims != 2
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def downscale_factor(self):
|
| 124 |
+
return self.encoder.downsample_factor
|
| 125 |
+
|
| 126 |
+
def to_json_string(self) -> str:
|
| 127 |
+
import json
|
| 128 |
+
|
| 129 |
+
return json.dumps(self.config.__dict__)
|
| 130 |
+
|
| 131 |
+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
| 132 |
+
model_keys = set(name for name, _ in self.named_parameters())
|
| 133 |
+
|
| 134 |
+
key_mapping = {
|
| 135 |
+
".resnets.": ".res_blocks.",
|
| 136 |
+
"downsamplers.0": "downsample",
|
| 137 |
+
"upsamplers.0": "upsample",
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
converted_state_dict = {}
|
| 141 |
+
for key, value in state_dict.items():
|
| 142 |
+
for k, v in key_mapping.items():
|
| 143 |
+
key = key.replace(k, v)
|
| 144 |
+
|
| 145 |
+
if "norm" in key and key not in model_keys:
|
| 146 |
+
logger.info(f"Removing key {key} from state_dict as it is not present in the model")
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
converted_state_dict[key] = value
|
| 150 |
+
|
| 151 |
+
super().load_state_dict(converted_state_dict, strict=strict)
|
| 152 |
+
|
| 153 |
+
def last_layer(self):
|
| 154 |
+
if hasattr(self.decoder, "conv_out"):
|
| 155 |
+
if isinstance(self.decoder.conv_out, nn.Sequential):
|
| 156 |
+
last_layer = self.decoder.conv_out[-1]
|
| 157 |
+
else:
|
| 158 |
+
last_layer = self.decoder.conv_out
|
| 159 |
+
else:
|
| 160 |
+
last_layer = self.decoder.layers[-1]
|
| 161 |
+
return last_layer
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class Encoder(nn.Module):
|
| 165 |
+
r"""
|
| 166 |
+
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 170 |
+
The number of input channels.
|
| 171 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 172 |
+
The number of output channels.
|
| 173 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 174 |
+
The number of output channels for each block.
|
| 175 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 176 |
+
The number of layers per block.
|
| 177 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 178 |
+
The number of groups for normalization.
|
| 179 |
+
patch_size (`int`, *optional*, defaults to 1):
|
| 180 |
+
The patch size to use. Should be a power of 2.
|
| 181 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 182 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
| 183 |
+
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
| 184 |
+
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
def __init__(
|
| 188 |
+
self,
|
| 189 |
+
dims: Union[int, Tuple[int, int]] = 3,
|
| 190 |
+
in_channels: int = 3,
|
| 191 |
+
out_channels: int = 3,
|
| 192 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 193 |
+
layers_per_block: int = 2,
|
| 194 |
+
norm_num_groups: int = 32,
|
| 195 |
+
patch_size: Union[int, Tuple[int]] = 1,
|
| 196 |
+
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
| 197 |
+
latent_log_var: str = "per_channel",
|
| 198 |
+
patch_size_t: Optional[int] = None,
|
| 199 |
+
add_channel_padding: Optional[bool] = False,
|
| 200 |
+
):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.patch_size = patch_size
|
| 203 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
| 204 |
+
self.add_channel_padding = add_channel_padding
|
| 205 |
+
self.layers_per_block = layers_per_block
|
| 206 |
+
self.norm_layer = norm_layer
|
| 207 |
+
self.latent_channels = out_channels
|
| 208 |
+
self.latent_log_var = latent_log_var
|
| 209 |
+
if add_channel_padding:
|
| 210 |
+
in_channels = in_channels * self.patch_size**3
|
| 211 |
+
else:
|
| 212 |
+
in_channels = in_channels * self.patch_size_t * self.patch_size**2
|
| 213 |
+
self.in_channels = in_channels
|
| 214 |
+
output_channel = block_out_channels[0]
|
| 215 |
+
|
| 216 |
+
self.conv_in = make_conv_nd(
|
| 217 |
+
dims=dims,
|
| 218 |
+
in_channels=in_channels,
|
| 219 |
+
out_channels=output_channel,
|
| 220 |
+
kernel_size=3,
|
| 221 |
+
stride=1,
|
| 222 |
+
padding=1,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
self.down_blocks = nn.ModuleList([])
|
| 226 |
+
|
| 227 |
+
for i in range(len(block_out_channels)):
|
| 228 |
+
input_channel = output_channel
|
| 229 |
+
output_channel = block_out_channels[i]
|
| 230 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 231 |
+
|
| 232 |
+
down_block = DownEncoderBlock3D(
|
| 233 |
+
dims=dims,
|
| 234 |
+
in_channels=input_channel,
|
| 235 |
+
out_channels=output_channel,
|
| 236 |
+
num_layers=self.layers_per_block,
|
| 237 |
+
add_downsample=not is_final_block and 2**i >= patch_size,
|
| 238 |
+
resnet_eps=1e-6,
|
| 239 |
+
downsample_padding=0,
|
| 240 |
+
resnet_groups=norm_num_groups,
|
| 241 |
+
norm_layer=norm_layer,
|
| 242 |
+
)
|
| 243 |
+
self.down_blocks.append(down_block)
|
| 244 |
+
|
| 245 |
+
self.mid_block = UNetMidBlock3D(
|
| 246 |
+
dims=dims,
|
| 247 |
+
in_channels=block_out_channels[-1],
|
| 248 |
+
num_layers=self.layers_per_block,
|
| 249 |
+
resnet_eps=1e-6,
|
| 250 |
+
resnet_groups=norm_num_groups,
|
| 251 |
+
norm_layer=norm_layer,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# out
|
| 255 |
+
if norm_layer == "group_norm":
|
| 256 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
| 257 |
+
elif norm_layer == "pixel_norm":
|
| 258 |
+
self.conv_norm_out = PixelNorm()
|
| 259 |
+
self.conv_act = nn.SiLU()
|
| 260 |
+
|
| 261 |
+
conv_out_channels = out_channels
|
| 262 |
+
if latent_log_var == "per_channel":
|
| 263 |
+
conv_out_channels *= 2
|
| 264 |
+
elif latent_log_var == "uniform":
|
| 265 |
+
conv_out_channels += 1
|
| 266 |
+
elif latent_log_var != "none":
|
| 267 |
+
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
| 268 |
+
self.conv_out = make_conv_nd(dims, block_out_channels[-1], conv_out_channels, 3, padding=1)
|
| 269 |
+
|
| 270 |
+
self.gradient_checkpointing = False
|
| 271 |
+
|
| 272 |
+
@property
|
| 273 |
+
def downscale_factor(self):
|
| 274 |
+
return (
|
| 275 |
+
2 ** len([block for block in self.down_blocks if isinstance(block.downsample, Downsample3D)])
|
| 276 |
+
* self.patch_size
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
| 280 |
+
r"""The forward method of the `Encoder` class."""
|
| 281 |
+
|
| 282 |
+
downsample_in_time = sample.shape[2] != 1
|
| 283 |
+
|
| 284 |
+
# patchify
|
| 285 |
+
patch_size_t = self.patch_size_t if downsample_in_time else 1
|
| 286 |
+
sample = patchify(
|
| 287 |
+
sample,
|
| 288 |
+
patch_size_hw=self.patch_size,
|
| 289 |
+
patch_size_t=patch_size_t,
|
| 290 |
+
add_channel_padding=self.add_channel_padding,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
sample = self.conv_in(sample)
|
| 294 |
+
|
| 295 |
+
checkpoint_fn = (
|
| 296 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 297 |
+
if self.gradient_checkpointing and self.training
|
| 298 |
+
else lambda x: x
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
for down_block in self.down_blocks:
|
| 302 |
+
sample = checkpoint_fn(down_block)(sample, downsample_in_time=downsample_in_time)
|
| 303 |
+
|
| 304 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
| 305 |
+
|
| 306 |
+
# post-process
|
| 307 |
+
sample = self.conv_norm_out(sample)
|
| 308 |
+
sample = self.conv_act(sample)
|
| 309 |
+
sample = self.conv_out(sample)
|
| 310 |
+
|
| 311 |
+
if self.latent_log_var == "uniform":
|
| 312 |
+
last_channel = sample[:, -1:, ...]
|
| 313 |
+
num_dims = sample.dim()
|
| 314 |
+
|
| 315 |
+
if num_dims == 4:
|
| 316 |
+
# For shape (B, C, H, W)
|
| 317 |
+
repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1)
|
| 318 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 319 |
+
elif num_dims == 5:
|
| 320 |
+
# For shape (B, C, F, H, W)
|
| 321 |
+
repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
|
| 322 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 323 |
+
else:
|
| 324 |
+
raise ValueError(f"Invalid input shape: {sample.shape}")
|
| 325 |
+
|
| 326 |
+
return sample
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class Decoder(nn.Module):
|
| 330 |
+
r"""
|
| 331 |
+
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 335 |
+
The number of input channels.
|
| 336 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 337 |
+
The number of output channels.
|
| 338 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 339 |
+
The number of output channels for each block.
|
| 340 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 341 |
+
The number of layers per block.
|
| 342 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 343 |
+
The number of groups for normalization.
|
| 344 |
+
patch_size (`int`, *optional*, defaults to 1):
|
| 345 |
+
The patch size to use. Should be a power of 2.
|
| 346 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 347 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
def __init__(
|
| 351 |
+
self,
|
| 352 |
+
dims,
|
| 353 |
+
in_channels: int = 3,
|
| 354 |
+
out_channels: int = 3,
|
| 355 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 356 |
+
layers_per_block: int = 2,
|
| 357 |
+
norm_num_groups: int = 32,
|
| 358 |
+
patch_size: int = 1,
|
| 359 |
+
norm_layer: str = "group_norm",
|
| 360 |
+
patch_size_t: Optional[int] = None,
|
| 361 |
+
add_channel_padding: Optional[bool] = False,
|
| 362 |
+
):
|
| 363 |
+
super().__init__()
|
| 364 |
+
self.patch_size = patch_size
|
| 365 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
| 366 |
+
self.add_channel_padding = add_channel_padding
|
| 367 |
+
self.layers_per_block = layers_per_block
|
| 368 |
+
if add_channel_padding:
|
| 369 |
+
out_channels = out_channels * self.patch_size**3
|
| 370 |
+
else:
|
| 371 |
+
out_channels = out_channels * self.patch_size_t * self.patch_size**2
|
| 372 |
+
self.out_channels = out_channels
|
| 373 |
+
|
| 374 |
+
self.conv_in = make_conv_nd(
|
| 375 |
+
dims,
|
| 376 |
+
in_channels,
|
| 377 |
+
block_out_channels[-1],
|
| 378 |
+
kernel_size=3,
|
| 379 |
+
stride=1,
|
| 380 |
+
padding=1,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
self.mid_block = None
|
| 384 |
+
self.up_blocks = nn.ModuleList([])
|
| 385 |
+
|
| 386 |
+
self.mid_block = UNetMidBlock3D(
|
| 387 |
+
dims=dims,
|
| 388 |
+
in_channels=block_out_channels[-1],
|
| 389 |
+
num_layers=self.layers_per_block,
|
| 390 |
+
resnet_eps=1e-6,
|
| 391 |
+
resnet_groups=norm_num_groups,
|
| 392 |
+
norm_layer=norm_layer,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 396 |
+
output_channel = reversed_block_out_channels[0]
|
| 397 |
+
for i in range(len(reversed_block_out_channels)):
|
| 398 |
+
prev_output_channel = output_channel
|
| 399 |
+
output_channel = reversed_block_out_channels[i]
|
| 400 |
+
|
| 401 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 402 |
+
|
| 403 |
+
up_block = UpDecoderBlock3D(
|
| 404 |
+
dims=dims,
|
| 405 |
+
num_layers=self.layers_per_block + 1,
|
| 406 |
+
in_channels=prev_output_channel,
|
| 407 |
+
out_channels=output_channel,
|
| 408 |
+
add_upsample=not is_final_block and 2 ** (len(block_out_channels) - i - 1) > patch_size,
|
| 409 |
+
resnet_eps=1e-6,
|
| 410 |
+
resnet_groups=norm_num_groups,
|
| 411 |
+
norm_layer=norm_layer,
|
| 412 |
+
)
|
| 413 |
+
self.up_blocks.append(up_block)
|
| 414 |
+
|
| 415 |
+
if norm_layer == "group_norm":
|
| 416 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
| 417 |
+
elif norm_layer == "pixel_norm":
|
| 418 |
+
self.conv_norm_out = PixelNorm()
|
| 419 |
+
|
| 420 |
+
self.conv_act = nn.SiLU()
|
| 421 |
+
self.conv_out = make_conv_nd(dims, block_out_channels[0], out_channels, 3, padding=1)
|
| 422 |
+
|
| 423 |
+
self.gradient_checkpointing = False
|
| 424 |
+
|
| 425 |
+
def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
| 426 |
+
r"""The forward method of the `Decoder` class."""
|
| 427 |
+
assert target_shape is not None, "target_shape must be provided"
|
| 428 |
+
upsample_in_time = sample.shape[2] < target_shape[2]
|
| 429 |
+
|
| 430 |
+
sample = self.conv_in(sample)
|
| 431 |
+
|
| 432 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
| 433 |
+
|
| 434 |
+
checkpoint_fn = (
|
| 435 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 436 |
+
if self.gradient_checkpointing and self.training
|
| 437 |
+
else lambda x: x
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
| 441 |
+
sample = sample.to(upscale_dtype)
|
| 442 |
+
|
| 443 |
+
for up_block in self.up_blocks:
|
| 444 |
+
sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)
|
| 445 |
+
|
| 446 |
+
# post-process
|
| 447 |
+
sample = self.conv_norm_out(sample)
|
| 448 |
+
sample = self.conv_act(sample)
|
| 449 |
+
sample = self.conv_out(sample)
|
| 450 |
+
|
| 451 |
+
# un-patchify
|
| 452 |
+
patch_size_t = self.patch_size_t if upsample_in_time else 1
|
| 453 |
+
sample = unpatchify(
|
| 454 |
+
sample,
|
| 455 |
+
patch_size_hw=self.patch_size,
|
| 456 |
+
patch_size_t=patch_size_t,
|
| 457 |
+
add_channel_padding=self.add_channel_padding,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
return sample
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class DownEncoderBlock3D(nn.Module):
|
| 464 |
+
def __init__(
|
| 465 |
+
self,
|
| 466 |
+
dims: Union[int, Tuple[int, int]],
|
| 467 |
+
in_channels: int,
|
| 468 |
+
out_channels: int,
|
| 469 |
+
dropout: float = 0.0,
|
| 470 |
+
num_layers: int = 1,
|
| 471 |
+
resnet_eps: float = 1e-6,
|
| 472 |
+
resnet_groups: int = 32,
|
| 473 |
+
add_downsample: bool = True,
|
| 474 |
+
downsample_padding: int = 1,
|
| 475 |
+
norm_layer: str = "group_norm",
|
| 476 |
+
):
|
| 477 |
+
super().__init__()
|
| 478 |
+
res_blocks = []
|
| 479 |
+
|
| 480 |
+
for i in range(num_layers):
|
| 481 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 482 |
+
res_blocks.append(
|
| 483 |
+
ResnetBlock3D(
|
| 484 |
+
dims=dims,
|
| 485 |
+
in_channels=in_channels,
|
| 486 |
+
out_channels=out_channels,
|
| 487 |
+
eps=resnet_eps,
|
| 488 |
+
groups=resnet_groups,
|
| 489 |
+
dropout=dropout,
|
| 490 |
+
norm_layer=norm_layer,
|
| 491 |
+
)
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
| 495 |
+
|
| 496 |
+
if add_downsample:
|
| 497 |
+
self.downsample = Downsample3D(dims, out_channels, out_channels=out_channels, padding=downsample_padding)
|
| 498 |
+
else:
|
| 499 |
+
self.downsample = Identity()
|
| 500 |
+
|
| 501 |
+
def forward(self, hidden_states: torch.FloatTensor, downsample_in_time) -> torch.FloatTensor:
|
| 502 |
+
for resnet in self.res_blocks:
|
| 503 |
+
hidden_states = resnet(hidden_states)
|
| 504 |
+
|
| 505 |
+
hidden_states = self.downsample(hidden_states, downsample_in_time=downsample_in_time)
|
| 506 |
+
|
| 507 |
+
return hidden_states
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class UNetMidBlock3D(nn.Module):
|
| 511 |
+
"""
|
| 512 |
+
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
| 513 |
+
|
| 514 |
+
Args:
|
| 515 |
+
in_channels (`int`): The number of input channels.
|
| 516 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
| 517 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
| 518 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
| 519 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
| 520 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
| 521 |
+
|
| 522 |
+
Returns:
|
| 523 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
| 524 |
+
in_channels, height, width)`.
|
| 525 |
+
|
| 526 |
+
"""
|
| 527 |
+
|
| 528 |
+
def __init__(
|
| 529 |
+
self,
|
| 530 |
+
dims: Union[int, Tuple[int, int]],
|
| 531 |
+
in_channels: int,
|
| 532 |
+
dropout: float = 0.0,
|
| 533 |
+
num_layers: int = 1,
|
| 534 |
+
resnet_eps: float = 1e-6,
|
| 535 |
+
resnet_groups: int = 32,
|
| 536 |
+
norm_layer: str = "group_norm",
|
| 537 |
+
):
|
| 538 |
+
super().__init__()
|
| 539 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 540 |
+
|
| 541 |
+
self.res_blocks = nn.ModuleList(
|
| 542 |
+
[
|
| 543 |
+
ResnetBlock3D(
|
| 544 |
+
dims=dims,
|
| 545 |
+
in_channels=in_channels,
|
| 546 |
+
out_channels=in_channels,
|
| 547 |
+
eps=resnet_eps,
|
| 548 |
+
groups=resnet_groups,
|
| 549 |
+
dropout=dropout,
|
| 550 |
+
norm_layer=norm_layer,
|
| 551 |
+
)
|
| 552 |
+
for _ in range(num_layers)
|
| 553 |
+
]
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 557 |
+
for resnet in self.res_blocks:
|
| 558 |
+
hidden_states = resnet(hidden_states)
|
| 559 |
+
|
| 560 |
+
return hidden_states
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class UpDecoderBlock3D(nn.Module):
|
| 564 |
+
def __init__(
|
| 565 |
+
self,
|
| 566 |
+
dims: Union[int, Tuple[int, int]],
|
| 567 |
+
in_channels: int,
|
| 568 |
+
out_channels: int,
|
| 569 |
+
resolution_idx: Optional[int] = None,
|
| 570 |
+
dropout: float = 0.0,
|
| 571 |
+
num_layers: int = 1,
|
| 572 |
+
resnet_eps: float = 1e-6,
|
| 573 |
+
resnet_groups: int = 32,
|
| 574 |
+
add_upsample: bool = True,
|
| 575 |
+
norm_layer: str = "group_norm",
|
| 576 |
+
):
|
| 577 |
+
super().__init__()
|
| 578 |
+
res_blocks = []
|
| 579 |
+
|
| 580 |
+
for i in range(num_layers):
|
| 581 |
+
input_channels = in_channels if i == 0 else out_channels
|
| 582 |
+
|
| 583 |
+
res_blocks.append(
|
| 584 |
+
ResnetBlock3D(
|
| 585 |
+
dims=dims,
|
| 586 |
+
in_channels=input_channels,
|
| 587 |
+
out_channels=out_channels,
|
| 588 |
+
eps=resnet_eps,
|
| 589 |
+
groups=resnet_groups,
|
| 590 |
+
dropout=dropout,
|
| 591 |
+
norm_layer=norm_layer,
|
| 592 |
+
)
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
| 596 |
+
|
| 597 |
+
if add_upsample:
|
| 598 |
+
self.upsample = Upsample3D(dims=dims, channels=out_channels, out_channels=out_channels)
|
| 599 |
+
else:
|
| 600 |
+
self.upsample = Identity()
|
| 601 |
+
|
| 602 |
+
self.resolution_idx = resolution_idx
|
| 603 |
+
|
| 604 |
+
def forward(self, hidden_states: torch.FloatTensor, upsample_in_time=True) -> torch.FloatTensor:
|
| 605 |
+
for resnet in self.res_blocks:
|
| 606 |
+
hidden_states = resnet(hidden_states)
|
| 607 |
+
|
| 608 |
+
hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)
|
| 609 |
+
|
| 610 |
+
return hidden_states
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
class ResnetBlock3D(nn.Module):
|
| 614 |
+
r"""
|
| 615 |
+
A Resnet block.
|
| 616 |
+
|
| 617 |
+
Parameters:
|
| 618 |
+
in_channels (`int`): The number of channels in the input.
|
| 619 |
+
out_channels (`int`, *optional*, default to be `None`):
|
| 620 |
+
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
| 621 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
| 622 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
| 623 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
| 624 |
+
"""
|
| 625 |
+
|
| 626 |
+
def __init__(
|
| 627 |
+
self,
|
| 628 |
+
dims: Union[int, Tuple[int, int]],
|
| 629 |
+
in_channels: int,
|
| 630 |
+
out_channels: Optional[int] = None,
|
| 631 |
+
conv_shortcut: bool = False,
|
| 632 |
+
dropout: float = 0.0,
|
| 633 |
+
groups: int = 32,
|
| 634 |
+
eps: float = 1e-6,
|
| 635 |
+
norm_layer: str = "group_norm",
|
| 636 |
+
):
|
| 637 |
+
super().__init__()
|
| 638 |
+
self.in_channels = in_channels
|
| 639 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 640 |
+
self.out_channels = out_channels
|
| 641 |
+
self.use_conv_shortcut = conv_shortcut
|
| 642 |
+
|
| 643 |
+
if norm_layer == "group_norm":
|
| 644 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
| 645 |
+
elif norm_layer == "pixel_norm":
|
| 646 |
+
self.norm1 = PixelNorm()
|
| 647 |
+
|
| 648 |
+
self.non_linearity = nn.SiLU()
|
| 649 |
+
|
| 650 |
+
self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 651 |
+
|
| 652 |
+
if norm_layer == "group_norm":
|
| 653 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
| 654 |
+
elif norm_layer == "pixel_norm":
|
| 655 |
+
self.norm2 = PixelNorm()
|
| 656 |
+
|
| 657 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 658 |
+
|
| 659 |
+
self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 660 |
+
|
| 661 |
+
self.conv_shortcut = (
|
| 662 |
+
make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
|
| 663 |
+
if in_channels != out_channels
|
| 664 |
+
else nn.Identity()
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
def forward(
|
| 668 |
+
self,
|
| 669 |
+
input_tensor: torch.FloatTensor,
|
| 670 |
+
) -> torch.FloatTensor:
|
| 671 |
+
hidden_states = input_tensor
|
| 672 |
+
|
| 673 |
+
hidden_states = self.norm1(hidden_states)
|
| 674 |
+
|
| 675 |
+
hidden_states = self.non_linearity(hidden_states)
|
| 676 |
+
|
| 677 |
+
hidden_states = self.conv1(hidden_states)
|
| 678 |
+
|
| 679 |
+
hidden_states = self.norm2(hidden_states)
|
| 680 |
+
|
| 681 |
+
hidden_states = self.non_linearity(hidden_states)
|
| 682 |
+
|
| 683 |
+
hidden_states = self.dropout(hidden_states)
|
| 684 |
+
|
| 685 |
+
hidden_states = self.conv2(hidden_states)
|
| 686 |
+
|
| 687 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
| 688 |
+
|
| 689 |
+
output_tensor = input_tensor + hidden_states
|
| 690 |
+
|
| 691 |
+
return output_tensor
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
class Downsample3D(nn.Module):
|
| 695 |
+
def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
|
| 696 |
+
super().__init__()
|
| 697 |
+
stride: int = 2
|
| 698 |
+
self.padding = padding
|
| 699 |
+
self.in_channels = in_channels
|
| 700 |
+
self.dims = dims
|
| 701 |
+
self.conv = make_conv_nd(
|
| 702 |
+
dims=dims,
|
| 703 |
+
in_channels=in_channels,
|
| 704 |
+
out_channels=out_channels,
|
| 705 |
+
kernel_size=kernel_size,
|
| 706 |
+
stride=stride,
|
| 707 |
+
padding=padding,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
def forward(self, x, downsample_in_time=True):
|
| 711 |
+
conv = self.conv
|
| 712 |
+
if self.padding == 0:
|
| 713 |
+
if self.dims == 2:
|
| 714 |
+
padding = (0, 1, 0, 1)
|
| 715 |
+
else:
|
| 716 |
+
padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
|
| 717 |
+
|
| 718 |
+
x = functional.pad(x, padding, mode="constant", value=0)
|
| 719 |
+
|
| 720 |
+
if self.dims == (2, 1) and not downsample_in_time:
|
| 721 |
+
return conv(x, skip_time_conv=True)
|
| 722 |
+
|
| 723 |
+
return conv(x)
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
class Upsample3D(nn.Module):
|
| 727 |
+
"""
|
| 728 |
+
An upsampling layer for 3D tensors of shape (B, C, D, H, W).
|
| 729 |
+
|
| 730 |
+
:param channels: channels in the inputs and outputs.
|
| 731 |
+
"""
|
| 732 |
+
|
| 733 |
+
def __init__(self, dims, channels, out_channels=None):
|
| 734 |
+
super().__init__()
|
| 735 |
+
self.dims = dims
|
| 736 |
+
self.channels = channels
|
| 737 |
+
self.out_channels = out_channels or channels
|
| 738 |
+
self.conv = make_conv_nd(dims, channels, out_channels, kernel_size=3, padding=1, bias=True)
|
| 739 |
+
|
| 740 |
+
def forward(self, x, upsample_in_time):
|
| 741 |
+
if self.dims == 2:
|
| 742 |
+
x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
|
| 743 |
+
else:
|
| 744 |
+
time_scale_factor = 2 if upsample_in_time else 1
|
| 745 |
+
# print("before:", x.shape)
|
| 746 |
+
b, c, d, h, w = x.shape
|
| 747 |
+
x = rearrange(x, "b c d h w -> (b d) c h w")
|
| 748 |
+
# height and width interpolate
|
| 749 |
+
x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
|
| 750 |
+
_, _, h, w = x.shape
|
| 751 |
+
|
| 752 |
+
if not upsample_in_time and self.dims == (2, 1):
|
| 753 |
+
x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
|
| 754 |
+
return self.conv(x, skip_time_conv=True)
|
| 755 |
+
|
| 756 |
+
# Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension
|
| 757 |
+
x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
|
| 758 |
+
|
| 759 |
+
# (b h w) c 1 d
|
| 760 |
+
new_d = x.shape[-1] * time_scale_factor
|
| 761 |
+
x = functional.interpolate(x, (1, new_d), mode="nearest")
|
| 762 |
+
# (b h w) c 1 new_d
|
| 763 |
+
x = rearrange(x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d)
|
| 764 |
+
# b c d h w
|
| 765 |
+
|
| 766 |
+
# x = functional.interpolate(
|
| 767 |
+
# x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
| 768 |
+
# )
|
| 769 |
+
# print("after:", x.shape)
|
| 770 |
+
|
| 771 |
+
return self.conv(x)
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
| 775 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
| 776 |
+
return x
|
| 777 |
+
if x.dim() == 4:
|
| 778 |
+
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
|
| 779 |
+
elif x.dim() == 5:
|
| 780 |
+
x = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
|
| 781 |
+
else:
|
| 782 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 783 |
+
|
| 784 |
+
if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
|
| 785 |
+
channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
|
| 786 |
+
padding_zeros = torch.zeros(
|
| 787 |
+
x.shape[0],
|
| 788 |
+
channels_to_pad,
|
| 789 |
+
x.shape[2],
|
| 790 |
+
x.shape[3],
|
| 791 |
+
x.shape[4],
|
| 792 |
+
device=x.device,
|
| 793 |
+
dtype=x.dtype,
|
| 794 |
+
)
|
| 795 |
+
x = torch.cat([padding_zeros, x], dim=1)
|
| 796 |
+
|
| 797 |
+
return x
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
| 801 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
| 802 |
+
return x
|
| 803 |
+
|
| 804 |
+
if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
|
| 805 |
+
channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
|
| 806 |
+
x = x[:, :channels_to_keep, :, :, :]
|
| 807 |
+
|
| 808 |
+
if x.dim() == 4:
|
| 809 |
+
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
|
| 810 |
+
elif x.dim() == 5:
|
| 811 |
+
x = rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
|
| 812 |
+
|
| 813 |
+
return x
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def create_video_autoencoder_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (2, 1),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [128, 256, 512, 512],  # Number of output channels of each encoder / decoder inner block
        "patch_size": 1,
    }

    return config


def create_video_autoencoder_pathify4x4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (2, 1),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512] * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "latent_log_var": "uniform",
    }

    return config


def create_video_autoencoder_pathify4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": 2,  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512] * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "norm_layer": "pixel_norm",
    }

    return config


def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_pathify4x4x4_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = VideoAutoencoder.from_config(config)

    print(video_autoencoder)

    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # E.g., 2 videos, each with 3 color channels, 8 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 8, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")
    reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Calculate the reconstruction loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform the backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Call the demo function to execute the forward and backward pass
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()
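The test_vae_patchify_unpatchify round trip above relies on patchify/unpatchify defined earlier in this file. As a minimal illustrative sketch (not the repository implementation; the ordering of patch elements along the channel axis is an assumption), a 4x4x4 space-time patchification can be expressed as a single einops rearrange, and the round trip is exact because it is only a permutation of elements:

import torch
from einops import rearrange

def patchify_sketch(x, patch_size_hw=4, patch_size_t=4):
    # (B, C, F, H, W) -> (B, C * pt * ph * pw, F // pt, H // ph, W // pw)
    return rearrange(
        x, "b c (f pt) (h ph) (w pw) -> b (c pt ph pw) f h w",
        pt=patch_size_t, ph=patch_size_hw, pw=patch_size_hw,
    )

def unpatchify_sketch(x, patch_size_hw=4, patch_size_t=4):
    # Inverse of patchify_sketch: fold the patch elements back into F, H, W
    return rearrange(
        x, "b (c pt ph pw) f h w -> b c (f pt) (h ph) (w pw)",
        pt=patch_size_t, ph=patch_size_hw, pw=patch_size_hw,
    )

x = torch.randn(2, 3, 8, 64, 64)
assert torch.allclose(x, unpatchify_sketch(patchify_sketch(x)))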
xora/models/transformers/embeddings.py
ADDED
@@ -0,0 +1,125 @@
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
import math

import numpy as np
import torch
from einops import rearrange
from torch import nn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
    """
    Create 3D sin-cos positional embeddings for a flattened (f h w) token grid.

    grid: index grid of shape (3, f*h*w) holding (frame, row, column) indices
    return: pos_embed of shape (f*h*w, embed_dim)
    """
    grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
    grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
    grid = grid.reshape([3, 1, w, h, f])
    pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
    pos_embed = pos_embed.transpose(1, 0, 2, 3)
    return rearrange(pos_embed, "h w f c -> (f h w) c")


def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 3 != 0:
        raise ValueError("embed_dim must be divisible by 3")

    # use a third of the dimensions to encode each of grid_f, grid_h, grid_w
    emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0])  # (H*W*T, D/3)
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1])  # (H*W*T, D/3)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2])  # (H*W*T, D/3)

    emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1)  # (H*W*T, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos_shape = pos.shape

    pos = pos.reshape(-1)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    out = out.reshape([*pos_shape, -1])[0]

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (M, D)
    return emb


class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings
    to them.

    Args:
        embed_dim (int): Dimension of the positional embedding.
        max_seq_length (int): Maximum sequence length to apply positional embeddings to.
    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x
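A quick shape-level sanity check of the helpers in this new file (illustrative only; it assumes the module is importable as xora.models.transformers.embeddings):

import torch
from xora.models.transformers.embeddings import (
    get_timestep_embedding,
    SinusoidalPositionalEmbedding,
)

# Sinusoidal timestep embeddings: one [embedding_dim] vector per (possibly fractional) timestep
emb = get_timestep_embedding(torch.tensor([0.0, 250.5, 999.0]), embedding_dim=128)
print(emb.shape)  # torch.Size([3, 128])

# Additive positional embedding over the sequence dimension
pos = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=128)
tokens = torch.randn(2, 10, 64)
print(pos(tokens).shape)  # torch.Size([2, 10, 64])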
xora/models/transformers/transformer3d.py
CHANGED
@@ -1,7 +1,7 @@
 # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
 import math
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Literal

 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -9,10 +9,13 @@ from diffusers.models.embeddings import PixArtAlphaTextProjection
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNormSingle
 from diffusers.utils import BaseOutput, is_torch_version
+from diffusers.utils import logging
 from torch import nn

 from xora.models.transformers.attention import BasicTransformerBlock
+from xora.models.transformers.embeddings import get_3d_sincos_pos_embed

+logger = logging.get_logger(__name__)

 @dataclass
 class Transformer3DModelOutput(BaseOutput):
@@ -143,6 +146,61 @@ class Transformer3DModel(ModelMixin, ConfigMixin):

         self.gradient_checkpointing = False

+    def set_use_tpu_flash_attention(self):
+        r"""
+        Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
+        attention kernel.
+        """
+        logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
+        # if using TPU -> configure components to use TPU flash attention
+        if dist_util.acceleration_type() == dist_util.AccelerationType.TPU:
+            self.use_tpu_flash_attention = True
+            # push config down to the attention modules
+            for block in self.transformer_blocks:
+                block.set_use_tpu_flash_attention()
+
+    def initialize(self, embedding_std: float, mode: Literal["xora", "pixart"]):
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+
+        self.apply(_basic_init)
+
+        # Initialize timestep embedding MLP:
+        nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std)
+        nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std)
+        nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
+
+        if hasattr(self.adaln_single.emb, "resolution_embedder"):
+            nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_1.weight, std=embedding_std)
+            nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_2.weight, std=embedding_std)
+        if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
+            nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight, std=embedding_std)
+            nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight, std=embedding_std)
+
+        # Initialize caption embedding MLP:
+        nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
+        nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
+
+        # Zero-out adaLN modulation layers in PixArt blocks:
+        for block in self.transformer_blocks:
+            if mode == "xora":
+                nn.init.constant_(block.attn1.to_out[0].weight, 0)
+                nn.init.constant_(block.attn1.to_out[0].bias, 0)

+            nn.init.constant_(block.attn2.to_out[0].weight, 0)
+            nn.init.constant_(block.attn2.to_out[0].bias, 0)
+
+            if mode == "xora":
+                nn.init.constant_(block.ff.net[2].weight, 0)
+                nn.init.constant_(block.ff.net[2].bias, 0)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.proj_out.weight, 0)
+        nn.init.constant_(self.proj_out.bias, 0)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
@@ -287,10 +345,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         if self.timestep_scale_multiplier:
             timestep = self.timestep_scale_multiplier * timestep

-        if self.positional_embedding_type == "
+        if self.positional_embedding_type == "absolute":
+            pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(hidden_states.device)
+            if self.project_to_2d_pos:
+                pos_embed = self.to_2d_proj(pos_embed_3d)
+            hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
+            freqs_cis = None
+        elif self.positional_embedding_type == "rope":
             freqs_cis = self.precompute_freqs_cis(indices_grid)
-        else:
-            raise NotImplementedError("Only rope pos embed supported.")

         batch_size = hidden_states.shape[0]
         timestep, embedded_timestep = self.adaln_single(
@@ -358,3 +420,14 @@

         return Transformer3DModelOutput(sample=hidden_states)

+    def get_absolute_pos_embed(self, grid):
+        grid_np = grid[0].cpu().numpy()
+        embed_dim_3d = math.ceil((self.inner_dim / 2) * 3) if self.project_to_2d_pos else self.inner_dim
+        pos_embed = get_3d_sincos_pos_embed(  # (f h w)
+            embed_dim_3d,
+            grid_np,
+            h=int(max(grid_np[1]) + 1),
+            w=int(max(grid_np[2]) + 1),
+            f=int(max(grid_np[0] + 1)),
+        )
+        return torch.from_numpy(pos_embed).float().unsqueeze(0)
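For the new "absolute" positional-embedding path, get_absolute_pos_embed expects indices_grid to carry per-token (frame, row, column) indices. A hedged sketch of how such a grid could be built for a 2x4x4 token volume (the batching and dtype conventions here are assumptions, not taken from the diff):

import torch

f, h, w = 2, 4, 4
grid_f, grid_h, grid_w = torch.meshgrid(
    torch.arange(f), torch.arange(h), torch.arange(w), indexing="ij"
)
# Shape (1, 3, f*h*w): rows are (frame, row, column) indices, flattened in (f h w) order
indices_grid = torch.stack([grid_f, grid_h, grid_w], dim=0).flatten(1).unsqueeze(0).float()
print(indices_grid.shape)  # torch.Size([1, 3, 32])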
xora/pipelines/pipeline_video_pixart_alpha.py
CHANGED
@@ -32,16 +32,106 @@ from xora.models.transformers.symmetric_patchifier import Patchifier
 from xora.models.autoencoders.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
 from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
 from xora.schedulers.rf import TimestepShifter
+from xora.utils.conditioning_method import ConditioningMethod

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-
 if is_bs4_available():
     from bs4 import BeautifulSoup

 if is_ftfy_available():
     import ftfy

+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import PixArtAlphaPipeline
+
+        >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+        >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+        >>> # Enable memory optimizations.
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> prompt = "A small cactus with a happy face in the Sahara desert."
+        >>> image = pipe(prompt).images[0]
+        ```
+"""
+
+ASPECT_RATIO_1024_BIN = {
+    "0.25": [512.0, 2048.0],
+    "0.28": [512.0, 1856.0],
+    "0.32": [576.0, 1792.0],
+    "0.33": [576.0, 1728.0],
+    "0.35": [576.0, 1664.0],
+    "0.4": [640.0, 1600.0],
+    "0.42": [640.0, 1536.0],
+    "0.48": [704.0, 1472.0],
+    "0.5": [704.0, 1408.0],
+    "0.52": [704.0, 1344.0],
+    "0.57": [768.0, 1344.0],
+    "0.6": [768.0, 1280.0],
+    "0.68": [832.0, 1216.0],
+    "0.72": [832.0, 1152.0],
+    "0.78": [896.0, 1152.0],
+    "0.82": [896.0, 1088.0],
+    "0.88": [960.0, 1088.0],
+    "0.94": [960.0, 1024.0],
+    "1.0": [1024.0, 1024.0],
+    "1.07": [1024.0, 960.0],
+    "1.13": [1088.0, 960.0],
+    "1.21": [1088.0, 896.0],
+    "1.29": [1152.0, 896.0],
+    "1.38": [1152.0, 832.0],
+    "1.46": [1216.0, 832.0],
+    "1.67": [1280.0, 768.0],
+    "1.75": [1344.0, 768.0],
+    "2.0": [1408.0, 704.0],
+    "2.09": [1472.0, 704.0],
+    "2.4": [1536.0, 640.0],
+    "2.5": [1600.0, 640.0],
+    "3.0": [1728.0, 576.0],
+    "4.0": [2048.0, 512.0],
+}
+
+ASPECT_RATIO_512_BIN = {
+    "0.25": [256.0, 1024.0],
+    "0.28": [256.0, 928.0],
+    "0.32": [288.0, 896.0],
+    "0.33": [288.0, 864.0],
+    "0.35": [288.0, 832.0],
+    "0.4": [320.0, 800.0],
+    "0.42": [320.0, 768.0],
+    "0.48": [352.0, 736.0],
+    "0.5": [352.0, 704.0],
+    "0.52": [352.0, 672.0],
+    "0.57": [384.0, 672.0],
+    "0.6": [384.0, 640.0],
+    "0.68": [416.0, 608.0],
+    "0.72": [416.0, 576.0],
+    "0.78": [448.0, 576.0],
+    "0.82": [448.0, 544.0],
+    "0.88": [480.0, 544.0],
+    "0.94": [480.0, 512.0],
+    "1.0": [512.0, 512.0],
+    "1.07": [512.0, 480.0],
+    "1.13": [544.0, 480.0],
+    "1.21": [544.0, 448.0],
+    "1.29": [576.0, 448.0],
+    "1.38": [576.0, 416.0],
+    "1.46": [608.0, 416.0],
+    "1.67": [640.0, 384.0],
+    "1.75": [672.0, 384.0],
+    "2.0": [704.0, 352.0],
+    "2.09": [736.0, 352.0],
+    "2.4": [768.0, 320.0],
+    "2.5": [800.0, 320.0],
+    "3.0": [864.0, 288.0],
+    "4.0": [1024.0, 256.0],
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,
@@ -520,14 +610,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(
-        self,
-        batch_size,
-        num_latent_channels,
-        num_patches,
-        dtype,
-        device,
-        generator,
-        latents=None,
+        self, batch_size, num_latent_channels, num_patches, dtype, device, generator, latents=None, latents_mask=None
     ):
         shape = (
             batch_size,
@@ -543,6 +626,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):

         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        elif latents_mask is not None:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = latents * latents_mask[..., None] + noise * (1 - latents_mask[..., None])
         else:
             latents = latents.to(device)

@@ -582,8 +668,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):

         return samples

-
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         height: int,
@@ -607,6 +693,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         clean_caption: bool = True,
+        media_items: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -736,8 +823,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

-        #
+        # 3b. Encode and prepare conditioning data
         self.video_scale_factor = self.video_scale_factor if is_video else 1
+        conditioning_method = kwargs.get("conditioning_method", None)
+        vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
+        init_latents, conditioning_mask = self.prepare_conditioning(
+            media_items, num_frames, height, width, conditioning_method, vae_per_channel_normalize
+        )
+
+        # 4. Prepare latents.
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         latent_num_frames = num_frames // self.video_scale_factor
@@ -752,7 +846,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
            dtype=prompt_embeds.dtype,
            device=device,
            generator=generator,
+           latents=init_latents,
+           latents_mask=conditioning_mask,
        )
+        if conditioning_mask is not None and is_video:
+            assert num_images_per_prompt == 1
+            conditioning_mask = torch.cat([conditioning_mask] * 2) if do_classifier_free_guidance else conditioning_mask

        # 5. Prepare timesteps
        retrieve_timesteps_kwargs = {}
@@ -790,7 +889,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
                elif len(current_timestep.shape) == 0:
                    current_timestep = current_timestep[None].to(latent_model_input.device)
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-               current_timestep = current_timestep.expand(latent_model_input.shape[0])
+               current_timestep = current_timestep.expand(latent_model_input.shape[0]).unsqueeze(-1)
                scale_grid = (
                    (1 / latent_frame_rates, self.vae_scale_factor, self.vae_scale_factor)
                    if self.transformer.use_rope
@@ -805,6 +904,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
                    device=latents.device,
                )

+               if conditioning_mask is not None:
+                   current_timestep = current_timestep * (1 - conditioning_mask)
+
                # predict noise model_output
                noise_pred = self.transformer(
                    latent_model_input.to(self.transformer.dtype),
@@ -819,13 +921,20 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                   current_timestep, _ = current_timestep.chunk(2)

                # learned sigma
                if self.transformer.config.out_channels // 2 == self.transformer.config.in_channels:
                    noise_pred = noise_pred.chunk(2, dim=1)[0]

                # compute previous image: x_t -> x_t-1
-               latents = self.scheduler.step(
+               latents = self.scheduler.step(
+                   noise_pred,
+                   t if current_timestep is None else current_timestep,
+                   latents,
+                   **extra_step_kwargs,
+                   return_dict=False,
+               )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -857,3 +966,62 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
            return (image,)

        return ImagePipelineOutput(images=image)
+
+    def prepare_conditioning(
+        self,
+        media_items: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        method: ConditioningMethod = ConditioningMethod.UNCONDITIONAL,
+        vae_per_channel_normalize: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Prepare the conditioning data for the video generation. If an input media item is provided, encode it
+        and set the conditioning_mask to indicate which tokens to condition on. Input media item should have
+        the same height and width as the generated video.
+
+        Args:
+            media_items (torch.Tensor): media items to condition on (images or videos)
+            num_frames (int): number of frames to generate
+            height (int): height of the generated video
+            width (int): width of the generated video
+            method (ConditioningMethod, optional): conditioning method to use. Defaults to ConditioningMethod.UNCONDITIONAL.
+            vae_per_channel_normalize (bool, optional): whether to normalize the input to the VAE per channel. Defaults to False.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: the conditioning latents and the conditioning mask
+        """
+        if media_items is None or method == ConditioningMethod.UNCONDITIONAL:
+            return None, None
+
+        assert media_items.ndim == 5
+        assert height == media_items.shape[-2] and width == media_items.shape[-1]
+
+        # Encode the input video and repeat to the required number of frame-tokens
+        init_latents = vae_encode(
+            media_items.to(dtype=self.vae.dtype, device=self.vae.device),
+            self.vae,
+            vae_per_channel_normalize=vae_per_channel_normalize,
+        ).float()
+
+        init_len, target_len = init_latents.shape[2], num_frames // self.video_scale_factor
+        if isinstance(self.vae, CausalVideoAutoencoder):
+            target_len += 1
+        init_latents = init_latents[:, :, :target_len]
+        if target_len > init_len:
+            repeat_factor = (target_len + init_len - 1) // init_len  # Ceiling division
+            init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[:, :, :target_len]
+
+        # Prepare the conditioning mask (1.0 = condition on this token)
+        b, n, f, h, w = init_latents.shape
+        conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
+        if method in [ConditioningMethod.FIRST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
+            conditioning_mask[:, :, 0] = 1.0
+        if method in [ConditioningMethod.LAST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
+            conditioning_mask[:, :, -1] = 1.0
+
+        # Patchify the init latents and the mask
+        conditioning_mask = self.patchifier.patchify(conditioning_mask).squeeze(-1)
+        init_latents = self.patchifier.patchify(latents=init_latents)
+        return init_latents, conditioning_mask
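A hedged usage sketch of the new image-to-video conditioning path, assuming a VideoPixArtAlphaPipeline instance named pipeline has already been constructed; argument names not visible in this diff (e.g. prompt) are assumptions, and the example frame counts/resolutions are arbitrary:

import torch
from xora.utils.conditioning_method import ConditioningMethod

# A single conditioning frame, shaped (batch, channels, frames, height, width),
# with the same height/width as the video to be generated.
first_frame = torch.randn(1, 3, 1, 512, 512)

images = pipeline(
    height=512,
    width=512,
    num_frames=57,
    prompt="a red sports car driving through the desert",  # assumed prompt kwarg
    media_items=first_frame,
    conditioning_method=ConditioningMethod.FIRST_FRAME,
    vae_per_channel_normalize=True,
).images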
xora/schedulers/rf.py
CHANGED
@@ -9,7 +9,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.utils import BaseOutput
 from torch import Tensor

-from
+from txt2img.common.torch_utils import append_dims


 def simple_diffusion_resolution_dependent_timestep_shift(
@@ -199,8 +199,17 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )

-
-
+        if timestep.ndim == 0:
+            # Global timestep
+            current_index = (self.timesteps - timestep).abs().argmin()
+            dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
+        else:
+            # Timestep per token
+            assert timestep.ndim == 2
+            current_index = (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
+            dt = self.delta_timesteps[current_index]
+            # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
+            dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]

         prev_sample = sample - dt * model_output
@@ -219,4 +228,4 @@
         sigmas = append_dims(sigmas, original_samples.ndim)
         alphas = 1 - sigmas
         noisy_samples = alphas * original_samples + sigmas * noise
-        return noisy_samples
+        return noisy_samples
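The new per-token branch (timestep.ndim == 2) is what the pipeline's conditioning mask feeds into: conditioned tokens are given timestep 0, and the torch.where special case then forces dt = 0 for them, so the scheduler step leaves those tokens untouched. A tiny numerical illustration (shapes follow the pipeline: (batch, 1) for the broadcast timestep, (batch, tokens) for the mask):

import torch

current_timestep = torch.full((1, 1), 0.7)                           # (batch, 1) after expand(...).unsqueeze(-1)
conditioning_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0, 0.0, 0.0]])   # (batch, tokens); 1.0 marks conditioned tokens
per_token_timestep = current_timestep * (1 - conditioning_mask)      # (batch, tokens), zeros where conditioned
print(per_token_timestep)
# tensor([[0.0000, 0.0000, 0.7000, 0.7000, 0.7000, 0.7000]])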
xora/utils/conditioning_method.py
ADDED
@@ -0,0 +1,7 @@
from enum import Enum

class ConditioningMethod(Enum):
    UNCONDITIONAL = "unconditional"
    FIRST_FRAME = "first_frame"
    LAST_FRAME = "last_frame"
    FIRST_AND_LAST_FRAME = "first_and_last_frame"
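Because the enum members carry string values, a conditioning method arriving as a plain config or CLI string round-trips cleanly:

from xora.utils.conditioning_method import ConditioningMethod

assert ConditioningMethod("first_frame") is ConditioningMethod.FIRST_FRAME
assert ConditioningMethod.FIRST_AND_LAST_FRAME.value == "first_and_last_frame"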
xora/utils/dist_util.py
ADDED
@@ -0,0 +1,11 @@
from enum import Enum

class AccelerationType(Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    MPS = "mps"

def execute_graph() -> None:
    if _acceleration_type == AccelerationType.TPU:
        xm.mark_step()
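As committed, this module refers to _acceleration_type and xm (torch_xla) that are not defined in the 11 lines above, and transformer3d.py additionally calls dist_util.acceleration_type(). A self-contained sketch of what those missing pieces might look like (purely illustrative, not part of the commit):

from enum import Enum

import torch

try:
    import torch_xla.core.xla_model as xm  # available only on TPU hosts
    _HAS_XLA = True
except ImportError:
    _HAS_XLA = False


class AccelerationType(Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    MPS = "mps"


def acceleration_type() -> AccelerationType:
    # Pick the best available accelerator; TPU wins if torch_xla is importable.
    if _HAS_XLA:
        return AccelerationType.TPU
    if torch.cuda.is_available():
        return AccelerationType.GPU
    if torch.backends.mps.is_available():
        return AccelerationType.MPS
    return AccelerationType.CPU


def execute_graph() -> None:
    # On TPU, force XLA to materialize the pending computation graph.
    if acceleration_type() == AccelerationType.TPU:
        xm.mark_step()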