Spaces:
Paused
Paused
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
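LTX-Video text-to-video pipeline built on diffusers components.

It additionally handles transformers whose linear layers are `Q8Linear` (from the optional `q8_kernels`
package): for those, the prompt attention mask is converted before being passed to the transformer.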
""" | |
import inspect | |
from typing import Any, Callable, Dict, List, Optional, Union | |
import numpy as np | |
import torch | |
from transformers import T5EncoderModel, T5TokenizerFast | |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback | |
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin | |
from diffusers.models.autoencoders import AutoencoderKLLTXVideo | |
from diffusers.models.transformers import LTXVideoTransformer3DModel | |
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring | |
from diffusers.utils.torch_utils import randn_tensor | |
from diffusers.video_processor import VideoProcessor | |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline | |
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput | |
# `q8_kernels` is an optional dependency. When it cannot be imported, `Q8Linear` is set to `None`
# and the `isinstance(..., Q8Linear)` check in the pipeline is skipped.
try:
    import q8_kernels  # noqa
    from q8_kernels.modules.linear import Q8Linear
except Exception:  # best-effort: any import-time failure disables the Q8 path, but Ctrl-C still propagates
    Q8Linear = None

# `xm.mark_step()` is called after every denoising step when torch_xla is available.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage example that `replace_example_docstring` splices into the `Examples:` placeholder of the
# pipeline's `__call__` docstring. The indentation matches the docstring it is inserted into.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import LTXPipeline
        >>> from diffusers.utils import export_to_video

        >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

        >>> video = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     width=704,
        ...     height=480,
        ...     num_frames=161,
        ...     num_inference_steps=50,
        ... ).frames[0]
        >>> export_to_video(video, "output.mp4", fps=24)
        ```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly map a sequence length to a shift value.

    The result lies on the straight line through the points `(base_seq_len, base_shift)` and
    `(max_seq_len, max_shift)`. Lengths outside `[base_seq_len, max_seq_len]` are extrapolated, not clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len (`int`, defaults to 256): Sequence length that maps to `base_shift`.
        max_seq_len (`int`, defaults to 4096): Sequence length that maps to `max_shift`.
        base_shift (`float`, defaults to 0.5): Shift at `base_seq_len`.
        max_shift (`float`, defaults to 1.16): Shift at `max_seq_len`.

    Returns:
        The interpolated shift value.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler via its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: custom `timesteps`, custom `sigmas`, or a plain `num_inference_steps`.
    Any extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Only used when
            neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            The device passed to `scheduler.set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's timestep spacing strategy. Cannot be combined with
            `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's timestep spacing strategy. Cannot be combined with
            `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` after the call and the number of inference steps
        (the schedule length when a custom schedule was supplied, otherwise `num_inference_steps` unchanged).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom schedule argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler build its own schedule.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: pick the keyword to forward and the wording used in the error message.
    if timesteps is not None:
        keyword, schedule, label = "timesteps", timesteps, "timestep"
    else:
        keyword, schedule, label = "sigmas", sigmas, "sigmas"

    accepted = inspect.signature(scheduler.set_timesteps).parameters
    if keyword not in accepted:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{keyword: schedule}, device=device, **kwargs)
    resolved = scheduler.timesteps
    return resolved, len(resolved)
class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation.

    Reference: https://github.com/Lightricks/LTX-Video

    Args:
        transformer ([`LTXVideoTransformer3DModel`]):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
        vae ([`AutoencoderKLLTXVideo`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKLLTXVideo,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: LTXVideoTransformer3DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

        # Each derived setting falls back to a fixed default when its component is missing or `None`.
        self.vae_spatial_compression_ratio = (
            self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
        )
        self.vae_temporal_compression_ratio = (
            self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
        )
        self.transformer_spatial_patch_size = (
            self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
        )
        # `None` default added so a missing transformer uses the fallback, like the sibling lookups above.
        self.transformer_temporal_patch_size = (
            self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
        )

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 128,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Tokenize `prompt` and run it through the T5 text encoder.

        Args:
            prompt (`str` or `List[str]`):
                Prompt or prompts to encode.
            num_videos_per_prompt (`int`, defaults to 1):
                Number of copies of each prompt's embedding and mask to return.
            max_sequence_length (`int`, defaults to 128):
                Prompts are padded/truncated to this many tokens. A warning is logged when text is truncated.
            device (`torch.device`, *optional*):
                Target device; defaults to the pipeline's execution device.
            dtype (`torch.dtype`, *optional*):
                Target dtype of the embeddings; defaults to the text encoder's dtype.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: Embeddings of shape `(batch_size * num_videos_per_prompt, seq_len,
            dim)` and the matching boolean attention mask of shape `(batch_size * num_videos_per_prompt, seq_len)`.
            The copies of each prompt are adjacent in both tensors.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.bool().to(device)

        # Re-tokenize without truncation purely to detect and report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        # Duplicate the mask with the same repeat-then-view pattern as the embeddings so that row `i` of the mask
        # belongs to row `i` of the embeddings. Tiling with `repeat(num_videos_per_prompt, 1)` would order the
        # rows as [p0, p1, p0, p1] while the embeddings are ordered [p0, p0, p1, p1].
        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, -1)

        return prompt_embeds, prompt_attention_mask

    # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 128,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask belonging to `prompt_embeds`; returned unchanged when `prompt_embeds` is given.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask belonging to `negative_prompt_embeds`; returned unchanged when
                `negative_prompt_embeds` is given.
            max_sequence_length (`int`, defaults to 128):
                Maximum number of tokens used when encoding `prompt` / `negative_prompt`.
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            `Tuple`: `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
            negative_prompt_attention_mask)`. The negative entries are whatever was passed in (possibly `None`)
            unless classifier-free guidance is enabled and no `negative_prompt_embeds` was supplied.

        Raises:
            TypeError: If `negative_prompt` and `prompt` have different types.
            ValueError: If `negative_prompt` and `prompt` have different batch sizes.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # An absent negative prompt is encoded as the empty string, once per batch element.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_on_step_end_tensor_inputs=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
    ):
        r"""
        Validate the arguments of `__call__` and raise `ValueError` on the first inconsistency found.

        Checks, in order: `height`/`width` divisible by 32; callback tensor names are in
        `_callback_tensor_inputs`; exactly one of `prompt` / `prompt_embeds` is given and `prompt` is a `str` or
        `list`; every pre-computed embedding comes with its attention mask; and, when both positive and negative
        embeddings are given, their shapes and their masks' shapes match.
        """
        if height % 32 != 0 or width % 32 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
                raise ValueError(
                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                    f" {negative_prompt_attention_mask.shape}."
                )

    @staticmethod
    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
        """Patchify `[B, C, F, H, W]` video latents into a `[B, S, D]` token sequence."""
        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape
        # [B, C, F // p_t, p_t, H // p, p, W // p, p].
        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of
        # input features
        batch_size, num_channels, num_frames, height, width = latents.shape
        post_patch_num_frames = num_frames // patch_size_t
        post_patch_height = height // patch_size
        post_patch_width = width // patch_size
        latents = latents.reshape(
            batch_size,
            -1,
            post_patch_num_frames,
            patch_size_t,
            post_patch_height,
            patch_size,
            post_patch_width,
            patch_size,
        )
        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
        return latents

    @staticmethod
    def _unpack_latents(
        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
    ) -> torch.Tensor:
        """Inverse of `_pack_latents`: turn a `[B, S, D]` token sequence back into `[B, C, F, H, W]` latents."""
        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature
        # dimensions) are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. `num_frames`,
        # `height` and `width` are the post-patch (token grid) sizes.
        batch_size = latents.size(0)
        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
        return latents

    @staticmethod
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
        """Return `(latents - mean) * scaling_factor / std`, with per-channel `mean` and `std`."""
        # Normalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * scaling_factor / latents_std
        return latents

    @staticmethod
    def _denormalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
        """Inverse of `_normalize_latents`: return `latents * std / scaling_factor + mean`."""
        # Denormalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = latents * latents_std / scaling_factor + latents_mean
        return latents

    def prepare_latents(
        self,
        batch_size: int = 1,
        num_channels_latents: int = 128,
        height: int = 512,
        width: int = 704,
        num_frames: int = 161,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Return the initial latents for the denoising loop.

        When `latents` is given it is only moved to `device`/`dtype` and returned as-is (it is not packed here).
        Otherwise Gaussian noise is sampled at the VAE-compressed resolution and packed into a token sequence with
        `_pack_latents`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        # Pixel-space sizes -> latent-space sizes.
        height = height // self.vae_spatial_compression_ratio
        width = width // self.vae_spatial_compression_ratio
        num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1

        shape = (batch_size, num_channels_latents, num_frames, height, width)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(
            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )
        return latents

    # The accessors below are read as attributes inside `__call__` (e.g. `if self.interrupt:`), so they must be
    # properties; as plain methods they would always be truthy.
    @property
    def guidance_scale(self):
        """Guidance scale set by the most recent `__call__`."""
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        """`True` when the current guidance scale is greater than 1."""
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        """Number of timesteps in the schedule of the most recent `__call__`."""
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        """`attention_kwargs` passed to the most recent `__call__`."""
        return self._attention_kwargs

    @property
    def interrupt(self):
        """When `True`, the remaining denoising steps are skipped."""
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 704,
        num_frames: int = 161,
        frame_rate: int = 25,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 3,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        decode_timestep: Union[float, List[float]] = 0.0,
        decode_noise_scale: Optional[Union[float, List[float]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 128,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. Only used when guidance is enabled
                (`guidance_scale > 1`) and `negative_prompt_embeds` is not passed.
            height (`int`, defaults to `512`):
                The height in pixels of the generated video. Must be divisible by 32.
            width (`int`, defaults to `704`):
                The width in pixels of the generated video. Must be divisible by 32.
            num_frames (`int`, defaults to `161`):
                The number of video frames to generate
            frame_rate (`int`, defaults to `25`):
                Frame rate used to derive the temporal RoPE interpolation scale passed to the transformer.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, a linear sigma schedule derived from
                `num_inference_steps` is used. Must be in descending order.
            guidance_scale (`float`, defaults to `3`):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
                usually at the expense of lower quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated
                from `negative_prompt` input argument.
            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            decode_timestep (`float`, defaults to `0.0`):
                The timestep at which generated video is decoded. Only used when the VAE config enables
                `timestep_conditioning`.
            decode_noise_scale (`float`, defaults to `None`):
                The interpolation factor between random noise and denoised latents at the decode timestep. Defaults
                to `decode_timestep` when `None`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Pass `"latent"` to
                skip decoding and return the packed latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to `128`):
                Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated videos.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Prepare text embeddings
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if self.do_classifier_free_guidance:
            # Unconditional half first; `noise_pred.chunk(2)` in the loop relies on this order.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        if Q8Linear is not None and isinstance(self.transformer.transformer_blocks[0].attn1.to_q, Q8Linear):
            # Convert the boolean mask into one integer per row: `argmin` yields the index of the first 0
            # (i.e. the first padded position). A result of 0 is replaced by `max_sequence_length`.
            # NOTE(review): presumably the Q8 attention kernels expect valid-token counts instead of a mask, and
            # 0 is taken to mean "row has no padding" - confirm against q8_kernels.
            prompt_attention_mask = prompt_attention_mask.to(torch.int64)
            prompt_attention_mask = prompt_attention_mask.argmin(-1).int().squeeze()
            prompt_attention_mask[prompt_attention_mask == 0] = max_sequence_length

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
        video_sequence_length = latent_num_frames * latent_height * latent_width
        # `retrieve_timesteps` rejects `timesteps` and `sigmas` together, so the default sigma schedule is only
        # built when the caller did not pass custom timesteps.
        sigmas = None if timesteps is not None else np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        mu = calculate_shift(
            video_sequence_length,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
            1 / latent_frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = latent_model_input.to(prompt_embeds.dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask,
                    num_frames=latent_num_frames,
                    height=latent_height,
                    width=latent_width,
                    rope_interpolation_scale=rope_interpolation_scale,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                # The scheduler step runs in float32 (latents were created as float32 above).
                noise_pred = noise_pred.float()

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # advance the progress bar on the last step and on every completed scheduler step after warmup
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            video = latents
        else:
            latents = self._unpack_latents(
                latents,
                latent_num_frames,
                latent_height,
                latent_width,
                self.transformer_spatial_patch_size,
                self.transformer_temporal_patch_size,
            )
            latents = self._denormalize_latents(
                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
            )
            latents = latents.to(prompt_embeds.dtype)

            if not self.vae.config.timestep_conditioning:
                timestep = None
            else:
                # `randn_tensor` (unlike `torch.randn`) accepts a list of generators and generators that live on
                # a different device than `device`.
                noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
                # Size scalar settings by the actual latent batch, which is
                # `batch_size * num_videos_per_prompt`, so they broadcast against `latents`.
                decode_batch_size = latents.shape[0]
                if not isinstance(decode_timestep, list):
                    decode_timestep = [decode_timestep] * decode_batch_size
                if decode_noise_scale is None:
                    decode_noise_scale = decode_timestep
                elif not isinstance(decode_noise_scale, list):
                    decode_noise_scale = [decode_noise_scale] * decode_batch_size

                timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
                decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
                    :, None, None, None, None
                ]
                # Blend the denoised latents with fresh noise before the timestep-conditioned decode.
                latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

            video = self.vae.decode(latents, timestep, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return LTXPipelineOutput(frames=video)