mehmetkeremturkcan committed
Commit 04dc2c4 · verified · 1 Parent(s): bdd1d35

Upload stg_ltx_i2v_pipeline.py

Files changed (1)
  1. stg_ltx_i2v_pipeline.py +595 -0
stg_ltx_i2v_pipeline.py ADDED
@@ -0,0 +1,595 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import types
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import T5EncoderModel, T5TokenizerFast

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.loaders import FromSingleFileMixin
from diffusers.models.attention_processor import Attention
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.models.transformers.transformer_ltx import apply_rotary_emb
from diffusers.pipelines.ltx.pipeline_ltx_image2video import LTXImageToVideoPipeline
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def forward_with_stg(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # STG-R (residual) block forward: the perturbed branch (batch indices 2:) is left untouched
    # while the unconditional and text branches go through the block, so the selected transformer
    # block acts as an identity map for the perturbed path.
    hidden_states_ptb = hidden_states[2:]
    encoder_hidden_states_ptb = encoder_hidden_states[2:]

    batch_size = hidden_states.size(0)
    norm_hidden_states = self.norm1(hidden_states)

    num_ada_params = self.scale_shift_table.shape[0]
    ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

    attn_hidden_states = self.attn1(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=None,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = hidden_states + attn_hidden_states * gate_msa

    attn_hidden_states = self.attn2(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=None,
        attention_mask=encoder_attention_mask,
    )
    hidden_states = hidden_states + attn_hidden_states
    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

    ff_output = self.ff(norm_hidden_states)
    hidden_states = hidden_states + ff_output * gate_mlp

    # Restore the untouched perturbed branch so that only the original paths are updated.
    hidden_states[2:] = hidden_states_ptb
    encoder_hidden_states[2:] = encoder_hidden_states_ptb

    return hidden_states

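# Illustrative sketch (an assumption for clarity, mirroring what LTXImageToVideoSTGPipeline.__call__
# below does for stg_mode="STG-R"): the rewritten forward is bound onto the selected transformer
# blocks with types.MethodType, e.g.
#
#     block = pipe.transformer.transformer_blocks[35]
#     block.forward = types.MethodType(forward_with_stg, block)
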
class STGLTXVideoAttentionProcessor2_0:
    r"""
    Attention processor implementing spatio-temporal guidance (STG-A) for the LTX model on top of scaled
    dot-product attention (requires PyTorch 2.0). The unconditional and text branches run the regular
    attention path with QK normalization and rotary embeddings, while the perturbed branch bypasses the
    attention computation and passes its value projections directly to the output projection.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "STGLTXVideoAttentionProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or higher."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The batch is laid out as [uncond, text, perturbed]; the first two entries follow the regular
        # attention path, while the perturbed entry skips attention (STG-A).
        hidden_states_uncond, hidden_states_text, hidden_states_perturb = hidden_states.chunk(3)
        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_text])

        emb_sin, emb_cos = image_rotary_emb
        emb_sin_uncond, emb_sin_text, emb_sin_perturb = emb_sin.chunk(3)
        emb_cos_uncond, emb_cos_text, emb_cos_perturb = emb_cos.chunk(3)
        emb_sin_org = torch.cat([emb_sin_uncond, emb_sin_text])
        emb_cos_org = torch.cat([emb_cos_uncond, emb_cos_text])

        image_rotary_emb_org = (emb_sin_org, emb_cos_org)
        image_rotary_emb_perturb = (emb_sin_perturb, emb_cos_perturb)

        # ---------------- Original path ---------------- #
        assert encoder_hidden_states is None
        batch_size, sequence_length, _ = (
            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if encoder_hidden_states is None:
            encoder_hidden_states_org = hidden_states_org

        query_org = attn.to_q(hidden_states_org)
        key_org = attn.to_k(encoder_hidden_states_org)
        value_org = attn.to_v(encoder_hidden_states_org)

        query_org = attn.norm_q(query_org)
        key_org = attn.norm_k(key_org)

        if image_rotary_emb is not None:
            query_org = apply_rotary_emb(query_org, image_rotary_emb_org)
            key_org = apply_rotary_emb(key_org, image_rotary_emb_org)

        query_org = query_org.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key_org = key_org.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value_org = value_org.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        hidden_states_org = F.scaled_dot_product_attention(
            query_org, key_org, value_org, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states_org = hidden_states_org.transpose(1, 2).flatten(2, 3)
        hidden_states_org = hidden_states_org.to(query_org.dtype)

        hidden_states_org = attn.to_out[0](hidden_states_org)
        hidden_states_org = attn.to_out[1](hidden_states_org)
        # ------------------------------------------------ #
        # --------------- Perturbation path -------------- #
        batch_size, sequence_length, _ = hidden_states_perturb.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if encoder_hidden_states is None:
            encoder_hidden_states_perturb = hidden_states_perturb

        query_perturb = attn.to_q(hidden_states_perturb)
        key_perturb = attn.to_k(encoder_hidden_states_perturb)
        value_perturb = attn.to_v(encoder_hidden_states_perturb)

        query_perturb = attn.norm_q(query_perturb)
        key_perturb = attn.norm_k(key_perturb)

        if image_rotary_emb is not None:
            query_perturb = apply_rotary_emb(query_perturb, image_rotary_emb_perturb)
            key_perturb = apply_rotary_emb(key_perturb, image_rotary_emb_perturb)

        query_perturb = query_perturb.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key_perturb = key_perturb.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value_perturb = value_perturb.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        # Skip the attention computation entirely for the perturbed branch and pass the values through.
        hidden_states_perturb = value_perturb

        hidden_states_perturb = hidden_states_perturb.transpose(1, 2).flatten(2, 3)
        hidden_states_perturb = hidden_states_perturb.to(query_perturb.dtype)

        hidden_states_perturb = attn.to_out[0](hidden_states_perturb)
        hidden_states_perturb = attn.to_out[1](hidden_states_perturb)
        # ------------------------------------------------ #

        hidden_states = torch.cat([hidden_states_org, hidden_states_perturb], dim=0)

        return hidden_states

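# Illustrative sketch (an assumption for clarity, equivalent to what replace_layer_processor below
# does for stg_mode="STG-A"): the processor is swapped onto the self-attention module of the chosen
# transformer block(s), e.g.
#
#     pipe.transformer.transformer_blocks[35].attn1.processor = STGLTXVideoAttentionProcessor2_0()
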
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

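# Worked example (illustrative only, not executed by the pipeline): with the defaults above,
# m = (1.16 - 0.5) / (4096 - 256) ≈ 1.72e-4 and b = 0.5 - m * 256 ≈ 0.456, so
# calculate_shift(2048) ≈ 2048 * 1.72e-4 + 0.456 ≈ 0.81; longer latent video sequences thus
# receive a proportionally larger timestep shift `mu`.
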
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

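# Minimal usage sketch (an assumption for clarity; the actual call happens inside
# LTXImageToVideoSTGPipeline.__call__ below with a resolution-dependent `mu`):
#
#     sigmas = np.linspace(1.0, 1 / 30, 30)
#     timesteps, num_inference_steps = retrieve_timesteps(
#         pipe.scheduler, num_inference_steps=30, device="cuda", sigmas=sigmas, mu=0.8
#     )
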
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

class LTXImageToVideoSTGPipeline(LTXImageToVideoPipeline):
    def extract_layers(self, file_path="./unet_info.txt"):
        # Collect the self-attention (attn1) modules of the transformer and log their names to `file_path`.
        layers = []
        with open(file_path, "w") as f:
            for name, module in self.transformer.named_modules():
                if "attn1" in name and "to" not in name and "add" not in name and "norm" not in name:
                    f.write(f"{name}\n")
                    layers.append((name, module))

        return layers

    def replace_layer_processor(self, layers, replace_processor, target_layers_idx=[]):
        # Swap in the STG attention processor for the selected layer indices (STG-A mode).
        for layer_idx in target_layers_idx:
            layers[layer_idx][1].processor = replace_processor

        return

    @property
    def do_spatio_temporal_guidance(self):
        return self._stg_scale > 0.0

    @torch.no_grad()
    def __call__(
        self,
        image: PipelineImageInput = None,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 704,
        num_frames: int = 161,
        frame_rate: int = 25,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 3,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 128,
        stg_mode: Optional[str] = "STG-R",
        stg_applied_layers_idx: Optional[List[int]] = [35],
        stg_scale: Optional[float] = 1.0,
        do_rescaling: Optional[bool] = False,
        decode_timestep: Union[float, List[float]] = 0.0,
        decode_noise_scale: Optional[Union[float, List[float]]] = None,
    ):
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        layers = self.extract_layers()

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        self._stg_scale = stg_scale
        self._guidance_scale = guidance_scale
        self._interrupt = False

        if self.do_spatio_temporal_guidance:
            if stg_mode == "STG-A":
                layers = self.extract_layers()
                replace_processor = STGLTXVideoAttentionProcessor2_0()
                self.replace_layer_processor(layers, replace_processor, stg_applied_layers_idx)
            elif stg_mode == "STG-R":
                for i in stg_applied_layers_idx:
                    self.transformer.transformer_blocks[i].forward = types.MethodType(
                        forward_with_stg, self.transformer.transformer_blocks[i]
                    )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Prepare text embeddings
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
        elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
            # Three-way batch for STG: [uncond, text, text-perturbed]
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat(
                [negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
            )

        # 4. Prepare latent variables
        if latents is None:
            image = self.video_processor.preprocess(image, height=height, width=width)
            image = image.to(device=device, dtype=prompt_embeds.dtype)

        num_channels_latents = self.transformer.config.in_channels
        latents, conditioning_mask = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )

        if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
        elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])

        # 5. Prepare timesteps
        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
        video_sequence_length = latent_num_frames * latent_height * latent_width
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        mu = calculate_shift(
            video_sequence_length,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
            1 / latent_frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
                    latent_model_input = torch.cat([latents] * 2)
                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
                    latent_model_input = torch.cat([latents] * 3)
                else:
                    # No guidance: run a single forward pass on the latents as-is.
                    latent_model_input = latents

                latent_model_input = latent_model_input.to(prompt_embeds.dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])
                timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask,
                    num_frames=latent_num_frames,
                    height=latent_height,
                    width=latent_width,
                    rope_interpolation_scale=rope_interpolation_scale,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.float()

                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    timestep, _ = timestep.chunk(2)
                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
                    # CFG plus spatio-temporal guidance: the prediction is pushed away from the
                    # perturbed prediction with weight `stg_scale`.
                    noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
                    noise_pred = (
                        noise_pred_uncond
                        + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                        + self._stg_scale * (noise_pred_text - noise_pred_perturb)
                    )
                    timestep, _, _ = timestep.chunk(3)

                if do_rescaling:
                    rescaling_scale = 0.7
                    factor = noise_pred_text.std() / noise_pred.std()
                    factor = rescaling_scale * factor + (1 - rescaling_scale)
                    noise_pred = noise_pred * factor

                # compute the previous noisy sample x_t -> x_t-1
                noise_pred = self._unpack_latents(
                    noise_pred,
                    latent_num_frames,
                    latent_height,
                    latent_width,
                    self.transformer_spatial_patch_size,
                    self.transformer_temporal_patch_size,
                )
                latents = self._unpack_latents(
                    latents,
                    latent_num_frames,
                    latent_height,
                    latent_width,
                    self.transformer_spatial_patch_size,
                    self.transformer_temporal_patch_size,
                )

                # Keep the first (conditioning) latent frame fixed and only denoise the remaining frames.
                noise_pred = noise_pred[:, :, 1:]
                noise_latents = latents[:, :, 1:]
                pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

                latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
                latents = self._pack_latents(
                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                )

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            video = latents
        else:
            latents = self._unpack_latents(
                latents,
                latent_num_frames,
                latent_height,
                latent_width,
                self.transformer_spatial_patch_size,
                self.transformer_temporal_patch_size,
            )
            latents = self._denormalize_latents(
                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
            )
            latents = latents.to(prompt_embeds.dtype)

            if not self.vae.config.timestep_conditioning:
                timestep = None
            else:
                noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
                if not isinstance(decode_timestep, list):
                    decode_timestep = [decode_timestep] * batch_size
                if decode_noise_scale is None:
                    decode_noise_scale = decode_timestep
                elif not isinstance(decode_noise_scale, list):
                    decode_noise_scale = [decode_noise_scale] * batch_size

                timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
                decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
                    :, None, None, None, None
                ]
                latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

            video = self.vae.decode(latents, timestep, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return LTXPipelineOutput(frames=video)
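
# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only; the checkpoint id, input image path and
# guidance settings below are assumptions, not part of this file):
#
#     from diffusers.utils import export_to_video, load_image
#
#     pipe = LTXImageToVideoSTGPipeline.from_pretrained(
#         "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
#     ).to("cuda")
#     image = load_image("path/to/first_frame.png")
#     video = pipe(
#         image=image,
#         prompt="A description of the desired motion",
#         stg_mode="STG-R",             # or "STG-A" to swap in the attention processor
#         stg_applied_layers_idx=[35],  # transformer block(s) to perturb
#         stg_scale=1.0,                # 0.0 disables spatio-temporal guidance
#         num_frames=161,
#     ).frames[0]
#     export_to_video(video, "output.mp4", fps=24)
# ---------------------------------------------------------------------------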