vyloup committed
Commit 1ab61b7 · verified · 1 parent: 1be84b2

Upload 4 files

Files changed (4)
  1. README.md +10 -79
  2. app.py +490 -766
  3. inference.py +645 -240
  4. requirements.txt +13 -12
README.md CHANGED
@@ -1,82 +1,13 @@
1
  ---
2
- title: LTX-Video-Playground # Replace with your app's title
3
- emoji: 🚀 # Choose an emoji to represent your app
4
- colorFrom: blue # Choose a color to start the gradient (e.g., blue, red, green)
5
- colorTo: purple # Choose a color to end the gradient
6
- sdk: gradio # Specify the SDK, e.g., gradio or streamlit
7
- sdk_version: "4.44.1" # Specify the SDK version if needed
8
- app_file: app.py # Name of your main app file
9
- pinned: false # Set to true if you want to pin this Space
 
10
  ---
11
 
12
- <div align="center">
13
-
14
- # Xora️
15
-
16
- </div>
17
-
18
- This is the official repository for Xora.
19
-
20
- ## Table of Contents
21
-
22
- - [Introduction](#introduction)
23
- - [Installation](#installation)
24
- - [Inference](#inference)
25
- - [Inference Code](#inference-code)
26
- - [Acknowledgement](#acknowledgement)
27
-
28
- ## Introduction
29
-
30
- The performance of Diffusion Transformers is heavily influenced by the number of generated latent pixels (or tokens). In video generation, the token count becomes substantial as the number of frames increases. To address this, we designed a carefully optimized VAE that compresses videos into a smaller number of tokens while utilizing a deeper latent space. This approach enables our model to generate high-quality 768x512 videos at 24 FPS, achieving near real-time speeds.
31
-
32
- ## Installation
33
-
34
- # Setup
35
-
36
- The codebase currently uses Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
37
-
38
- ```bash
39
- git clone https://github.com/LightricksResearch/xora-core.git
40
- cd xora-core
41
-
42
- # create env
43
- python -m venv env
44
- source env/bin/activate
45
- python -m pip install -e .\[inference-script\]
46
- ```
47
-
48
- Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/Xora)
49
-
50
- ```python
51
- from huggingface_hub import snapshot_download
52
-
53
- model_path = 'PATH' # The local directory to save downloaded checkpoint
54
- snapshot_download("Lightricks/Xora", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
55
- ```
56
-
57
- ## Inference
58
-
59
- ### Inference Code
60
-
61
- To use our model, please follow the inference code in `inference.py` at [https://github.com/LightricksResearch/xora-core/blob/main/inference.py]():
62
-
63
- For text-to-video generation:
64
-
65
- ```bash
66
- python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH
67
- ```
68
-
69
- For image-to-video generation:
70
-
71
- ```python
72
- python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH
73
-
74
- ```
75
-
76
- ## Acknowledgement
77
-
78
- We are grateful for the following awesome projects when implementing Xora:
79
-
80
- - [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.
81
-
82
- [//]: # "## Citation"
 
1
  ---
2
+ title: LTX Video Fast
3
+ emoji: 🎥
4
+ colorFrom: yellow
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 5.42.0
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: ultra-fast video model, LTX 0.9.8 13B distilled
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
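
For reference, the updated `app.py` in this commit pulls the distilled checkpoint advertised above from the `Lightricks/LTX-Video` model repo at startup. A minimal sketch of that download step is shown below; the filename normally comes from the bundled `configs/ltxv-13b-0.9.8-distilled.yaml`, so the literal used here is illustrative rather than authoritative.

```python
# Minimal sketch of the startup download performed by the new app.py.
# The checkpoint filename is normally read from the pipeline YAML config;
# the literal below is an assumption for illustration.
from pathlib import Path
from huggingface_hub import hf_hub_download

models_dir = "downloaded_models_gradio_cpu_init"
Path(models_dir).mkdir(parents=True, exist_ok=True)

distilled_ckpt = hf_hub_download(
    repo_id="Lightricks/LTX-Video",
    filename="ltxv-13b-0.9.8-distilled.safetensors",
    local_dir=models_dir,
)
print(f"Distilled model path: {distilled_ckpt}")
```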
 
app.py CHANGED
@@ -1,797 +1,521 @@
1
- import spaces
2
- from functools import lru_cache
3
  import gradio as gr
4
- from gradio_toggle import Toggle
5
  import torch
6
- from huggingface_hub import hf_hub_download
7
- from transformers import CLIPProcessor, CLIPModel
8
-
9
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
10
- from ltx_video.models.transformers.transformer3d import Transformer3DModel
11
- from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
12
- from ltx_video.schedulers.rf import RectifiedFlowScheduler
13
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline as XoraVideoPipeline
14
- from transformers import T5EncoderModel, T5Tokenizer
15
- from ltx_video.utils.conditioning_method import ConditioningMethod
16
- from pathlib import Path
17
- import safetensors.torch
18
- import json
19
  import numpy as np
20
- import cv2
21
- from PIL import Image
22
- import tempfile
23
  import os
24
- import gc
25
- from openai import OpenAI
26
- import csv
27
- from datetime import datetime
28
-
 
29
  from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
30
 
31
 
32
- # Load Hugging Face token if needed
33
- hf_token = os.getenv("HF_TOKEN")
34
- openai_api_key = os.getenv("OPENAI_API_KEY")
35
- client = OpenAI(api_key=openai_api_key)
36
- system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
37
- system_prompt_i2v_path = "assets/system_prompt_i2v.txt"
38
- with open(system_prompt_t2v_path, "r") as f:
39
- system_prompt_t2v = f.read()
40
-
41
- with open(system_prompt_i2v_path, "r") as f:
42
- system_prompt_i2v = f.read()
43
-
44
- # Set model download directory within Hugging Face Spaces
45
- model_path = Path("/home/elevin/xora-core/assets/")
46
- cpkt_path = Path("/home/elevin/xora-core/assets/ltx-video-2b-v0.9.1.safetensors")
47
- if not os.path.exists(cpkt_path):
48
- hf_hub_download(repo_id="Lightricks/LTX-Video", filename="ltx-video-2b-v0.9.1.safetensors", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
49
- # Global variables to load components
50
- vae_dir = Path(model_path) / "vae"
51
- unet_dir = Path(model_path) / "unet"
52
- scheduler_dir = Path(model_path) / "scheduler"
53
-
54
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
-
56
- DATA_DIR = "/data"
57
- os.makedirs(DATA_DIR, exist_ok=True)
58
- LOG_FILE_PATH = os.path.join("/data", "user_requests.csv")
59
-
60
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
61
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
62
-
63
-
64
- if not os.path.exists(LOG_FILE_PATH):
65
- with open(LOG_FILE_PATH, "w", newline="") as f:
66
- writer = csv.writer(f)
67
- writer.writerow(
68
- [
69
- "timestamp",
70
- "request_type",
71
- "prompt",
72
- "negative_prompt",
73
- "height",
74
- "width",
75
- "num_frames",
76
- "frame_rate",
77
- "seed",
78
- "num_inference_steps",
79
- "guidance_scale",
80
- "is_enhanced",
81
- "clip_embedding",
82
- "original_resolution",
83
- ]
84
- )
85
-
86
-
87
- @lru_cache(maxsize=128)
88
- def log_request(
89
- request_type,
90
- prompt,
91
- negative_prompt,
92
- height,
93
- width,
94
- num_frames,
95
- frame_rate,
96
- seed,
97
- num_inference_steps,
98
- guidance_scale,
99
- is_enhanced,
100
- clip_embedding=None,
101
- original_resolution=None,
102
- ):
103
- """Log the user's request to a CSV file."""
104
- timestamp = datetime.now().isoformat()
105
- with open(LOG_FILE_PATH, "a", newline="") as f:
106
- try:
107
- writer = csv.writer(f)
108
- writer.writerow(
109
- [
110
- timestamp,
111
- request_type,
112
- prompt,
113
- negative_prompt,
114
- height,
115
- width,
116
- num_frames,
117
- frame_rate,
118
- seed,
119
- num_inference_steps,
120
- guidance_scale,
121
- is_enhanced,
122
- clip_embedding,
123
- original_resolution,
124
- ]
125
- )
126
- except Exception as e:
127
- print(f"Error logging request: {e}")
128
 
129
 
130
- def compute_clip_embedding(text=None, image=None):
 
 
 
 
131
  """
132
- Compute CLIP embedding for a given text or image.
133
- Args:
134
- text (str): Input text prompt.
135
- image (PIL.Image): Input image.
136
- Returns:
137
- list: CLIP embedding as a list of floats.
138
  """
139
- inputs = clip_processor(text=text, images=image, return_tensors="pt", padding=True)
140
- outputs = clip_model.get_text_features(**inputs) if text else clip_model.get_image_features(**inputs)
141
- embedding = outputs.detach().cpu().numpy().flatten().tolist()
142
- return embedding
143
-
144
-
145
- def load_vae(vae_dir):
146
- return CausalVideoAutoencoder.from_pretrained(cpkt_path).to(device=device, dtype=torch.bfloat16)
147
-
148
-
149
- def load_unet(unet_dir):
150
- return Transformer3DModel.from_pretrained(cpkt_path).to(device=device, dtype=torch.bfloat16)
151
-
152
-
153
- def load_scheduler(scheduler_dir):
154
- return RectifiedFlowScheduler.from_pretrained(cpkt_path)
 
155
 
156
 
157
- # Helper function for image processing
158
- def center_crop_and_resize(frame, target_height, target_width):
159
- h, w, _ = frame.shape
160
- aspect_ratio_target = target_width / target_height
161
- aspect_ratio_frame = w / h
162
- if aspect_ratio_frame > aspect_ratio_target:
163
- new_width = int(h * aspect_ratio_target)
164
- x_start = (w - new_width) // 2
165
- frame_cropped = frame[:, x_start : x_start + new_width]
166
- else:
167
- new_height = int(w / aspect_ratio_target)
168
- y_start = (h - new_height) // 2
169
- frame_cropped = frame[y_start : y_start + new_height, :]
170
- frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
171
- return frame_resized
172
-
173
-
174
- def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
175
- image = Image.open(image_path).convert("RGB")
176
- image_np = np.array(image)
177
- frame_resized = center_crop_and_resize(image_np, target_height, target_width)
178
- frame_tensor = torch.tensor(frame_resized).permute(2, 0, 1).float()
179
- frame_tensor = (frame_tensor / 127.5) - 1.0
180
- return frame_tensor.unsqueeze(0).unsqueeze(2)
181
-
182
-
183
- def enhance_prompt_if_enabled(prompt, enhance_toggle, type="t2v"):
184
- if not enhance_toggle:
185
- print("Enhance toggle is off, Prompt: ", prompt)
186
- return prompt
187
-
188
- system_prompt = system_prompt_t2v if type == "t2v" else system_prompt_i2v
189
- messages = [
190
- {"role": "system", "content": system_prompt},
191
- {"role": "user", "content": prompt},
192
- ]
193
 
194
- try:
195
- response = client.chat.completions.create(
196
- model="gpt-4o-mini",
197
- messages=messages,
198
- max_tokens=200,
199
- )
200
- print("Enhanced Prompt: ", response.choices[0].message.content.strip())
201
- return response.choices[0].message.content.strip()
202
- except Exception as e:
203
- print(f"Error: {e}")
204
- return prompt
205
-
206
-
207
- # Preset options for resolution and frame configuration
208
- preset_options = [
209
- {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
210
- {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
211
- {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
212
- {"label": "992x608, 65 frames", "width": 992, "height": 608, "num_frames": 65},
213
- {"label": "896x608, 73 frames", "width": 896, "height": 608, "num_frames": 73},
214
- {"label": "896x544, 81 frames", "width": 896, "height": 544, "num_frames": 81},
215
- {"label": "832x544, 89 frames", "width": 832, "height": 544, "num_frames": 89},
216
- {"label": "800x512, 97 frames", "width": 800, "height": 512, "num_frames": 97},
217
- {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
218
- {"label": "800x480, 105 frames", "width": 800, "height": 480, "num_frames": 105},
219
- {"label": "736x480, 113 frames", "width": 736, "height": 480, "num_frames": 113},
220
- {"label": "704x480, 121 frames", "width": 704, "height": 480, "num_frames": 121},
221
- {"label": "704x448, 129 frames", "width": 704, "height": 448, "num_frames": 129},
222
- {"label": "672x448, 137 frames", "width": 672, "height": 448, "num_frames": 137},
223
- {"label": "640x416, 153 frames", "width": 640, "height": 416, "num_frames": 153},
224
- {"label": "672x384, 161 frames", "width": 672, "height": 384, "num_frames": 161},
225
- {"label": "640x384, 169 frames", "width": 640, "height": 384, "num_frames": 169},
226
- {"label": "608x384, 177 frames", "width": 608, "height": 384, "num_frames": 177},
227
- {"label": "576x384, 185 frames", "width": 576, "height": 384, "num_frames": 185},
228
- {"label": "608x352, 193 frames", "width": 608, "height": 352, "num_frames": 193},
229
- {"label": "576x352, 201 frames", "width": 576, "height": 352, "num_frames": 201},
230
- {"label": "544x352, 209 frames", "width": 544, "height": 352, "num_frames": 209},
231
- {"label": "512x352, 225 frames", "width": 512, "height": 352, "num_frames": 225},
232
- {"label": "512x352, 233 frames", "width": 512, "height": 352, "num_frames": 233},
233
- {"label": "544x320, 241 frames", "width": 544, "height": 320, "num_frames": 241},
234
- {"label": "512x320, 249 frames", "width": 512, "height": 320, "num_frames": 249},
235
- {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
236
- ]
237
-
238
-
239
- # Function to toggle visibility of sliders based on preset selection
240
- def preset_changed(preset):
241
- if preset != "Custom":
242
- selected = next(item for item in preset_options if item["label"] == preset)
243
- return (
244
- selected["height"],
245
- selected["width"],
246
- selected["num_frames"],
247
- gr.update(visible=False),
248
- gr.update(visible=False),
249
- gr.update(visible=False),
250
- )
251
  else:
252
- return (
253
- None,
254
- None,
255
- None,
256
- gr.update(visible=True),
257
- gr.update(visible=True),
258
- gr.update(visible=True),
259
- )
260
-
261
-
262
- # Load models
263
- vae = load_vae(vae_dir)
264
- unet = load_unet(unet_dir)
265
- scheduler = load_scheduler(scheduler_dir)
266
- patchifier = SymmetricPatchifier(patch_size=1)
267
- text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(device)
268
- tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
269
-
270
- pipeline = XoraVideoPipeline(
271
- transformer=unet,
272
- patchifier=patchifier,
273
- text_encoder=text_encoder,
274
- tokenizer=tokenizer,
275
- scheduler=scheduler,
276
- vae=vae,
277
- ).to(device)
278
-
279
- @spaces.GPU(duration=120)
280
- def generate_video_from_text(
281
- prompt="",
282
- enhance_prompt_toggle=False,
283
- txt2vid_analytics_toggle=True,
284
- negative_prompt="",
285
- frame_rate=25,
286
- seed=646373,
287
- num_inference_steps=30,
288
- guidance_scale=3,
289
- height=512,
290
- width=768,
291
- num_frames=121,
292
- progress=gr.Progress(),
293
- stg_scale=1.0,
294
- stg_rescale=0.7,
295
- stg_mode="stg_a",
296
- stg_skip_layers="19",
297
- ):
298
- if len(prompt.strip()) < 50:
299
- raise gr.Error(
300
- "Prompt must be at least 50 characters long. Please provide more details for the best results.",
301
- duration=5,
302
- )
303
-
304
- if txt2vid_analytics_toggle:
305
- log_request(
306
- "txt2vid",
307
- prompt,
308
- negative_prompt,
309
- height,
310
- width,
311
- num_frames,
312
- frame_rate,
313
- seed,
314
- num_inference_steps,
315
- guidance_scale,
316
- enhance_prompt_toggle,
317
- )
318
-
319
- prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle, type="t2v")
320
-
321
- sample = {
322
  "prompt": prompt,
323
- "prompt_attention_mask": None,
324
  "negative_prompt": negative_prompt,
325
- "negative_prompt_attention_mask": None,
 
 
 
 
 
 
326
  "media_items": None,
327
  }
328
 
329
- generator = torch.Generator(device=device).manual_seed(seed)
 
331
- def gradio_progress_callback(self, step, timestep, kwargs):
332
- progress((step + 1) / num_inference_steps)
333
 
334
- try:
335
- with torch.no_grad():
336
- images = pipeline(
337
- num_inference_steps=num_inference_steps,
338
- num_images_per_prompt=1,
339
- guidance_scale=guidance_scale,
340
- generator=generator,
341
- output_type="pt",
342
- height=height,
343
- width=width,
344
- num_frames=num_frames,
345
- frame_rate=frame_rate,
346
- **sample,
347
- is_video=True,
348
- vae_per_channel_normalize=True,
349
- conditioning_method=ConditioningMethod.UNCONDITIONAL,
350
- mixed_precision=True,
351
- callback_on_step_end=gradio_progress_callback,
352
- stg_scale=stg_scale,
353
- do_rescaling=stg_rescale != 1,
354
- rescaling_scale=stg_rescale,
355
- skip_layer_strategy=SkipLayerStrategy.Attention if stg_mode == "stg_a" else SkipLayerStrategy.Residual,
356
- skip_block_list=[int(x.strip()) for x in stg_skip_layers.split(",")]
357
- ).images
358
- except Exception as e:
359
- raise gr.Error(
360
- f"An error occurred while generating the video. Please try again. Error: {e}",
361
- duration=5,
362
- )
363
- finally:
364
- torch.cuda.empty_cache()
365
- gc.collect()
366
-
367
- output_path = tempfile.mktemp(suffix=".mp4")
368
- print(images.shape)
369
- video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
370
  video_np = (video_np * 255).astype(np.uint8)
371
- height, width = video_np.shape[1:3]
372
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
373
- for frame in video_np[..., ::-1]:
374
- out.write(frame)
375
- out.release()
376
- # Explicitly delete tensors and clear cache
377
- del images
378
- del video_np
379
- torch.cuda.empty_cache()
380
- return output_path
381
-
382
- @spaces.GPU(duration=120)
383
- def generate_video_from_image(
384
- image_path,
385
- prompt="",
386
- enhance_prompt_toggle=False,
387
- img2vid_analytics_toggle=True,
388
- negative_prompt="",
389
- frame_rate=25,
390
- seed=646373,
391
- num_inference_steps=30,
392
- guidance_scale=3,
393
- height=512,
394
- width=768,
395
- num_frames=121,
396
- progress=gr.Progress(),
397
- stg_scale=1.0,
398
- stg_rescale=0.7,
399
- stg_mode="stg_a",
400
- stg_skip_layers="19",
401
- ):
402
-
403
- print("Height: ", height)
404
- print("Width: ", width)
405
- print("Num Frames: ", num_frames)
406
-
407
- if len(prompt.strip()) < 50:
408
- raise gr.Error(
409
- "Prompt must be at least 50 characters long. Please provide more details for the best results.",
410
- duration=5,
411
- )
412
-
413
- if not image_path:
414
- raise gr.Error("Please provide an input image.", duration=5)
415
-
416
- if img2vid_analytics_toggle:
417
- with Image.open(image_path) as img:
418
- original_resolution = f"{img.width}x{img.height}" # Format as "widthxheight"
419
- clip_embedding = compute_clip_embedding(image=img)
420
-
421
- log_request(
422
- "img2vid",
423
- prompt,
424
- negative_prompt,
425
- height,
426
- width,
427
- num_frames,
428
- frame_rate,
429
- seed,
430
- num_inference_steps,
431
- guidance_scale,
432
- enhance_prompt_toggle,
433
- json.dumps(clip_embedding),
434
- original_resolution,
435
- )
436
-
437
- media_items = load_image_to_tensor_with_resize(image_path, height, width).to(device).detach()
438
-
439
- prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle, type="i2v")
440
-
441
- sample = {
442
- "prompt": prompt,
443
- "prompt_attention_mask": None,
444
- "negative_prompt": negative_prompt,
445
- "negative_prompt_attention_mask": None,
446
- "media_items": media_items,
447
- }
448
-
449
- generator = torch.Generator(device=device).manual_seed(seed)
450
-
451
- def gradio_progress_callback(self, step, timestep, kwargs):
452
- progress((step + 1) / num_inference_steps)
453
 
 
 
 
 
454
  try:
455
- with torch.no_grad():
456
- images = pipeline(
457
- num_inference_steps=num_inference_steps,
458
- num_images_per_prompt=1,
459
- guidance_scale=guidance_scale,
460
- generator=generator,
461
- output_type="pt",
462
- height=height,
463
- width=width,
464
- num_frames=num_frames,
465
- frame_rate=frame_rate,
466
- **sample,
467
- is_video=True,
468
- vae_per_channel_normalize=True,
469
- conditioning_method=ConditioningMethod.FIRST_FRAME,
470
- mixed_precision=True,
471
- callback_on_step_end=gradio_progress_callback,
472
- stg_scale=stg_scale,
473
- do_rescaling=stg_rescale != 1,
474
- rescaling_scale=stg_rescale,
475
- skip_layer_strategy=SkipLayerStrategy.Attention if stg_mode == "stg_a" else SkipLayerStrategy.Residual,
476
- skip_block_list=[int(x.strip()) for x in stg_skip_layers.split(",")]
477
- ).images
478
-
479
- output_path = tempfile.mktemp(suffix=".mp4")
480
- video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
481
- video_np = (video_np * 255).astype(np.uint8)
482
- height, width = video_np.shape[1:3]
483
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
484
- for frame in video_np[..., ::-1]:
485
- out.write(frame)
486
- out.release()
487
  except Exception as e:
488
- raise gr.Error(
489
- f"An error occurred while generating the video. Please try again. Error: {e}",
490
- duration=5,
491
- )
492
-
493
- finally:
494
- torch.cuda.empty_cache()
495
- gc.collect()
496
-
497
- return output_path
498
-
499
-
500
- def create_advanced_options():
501
- with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
502
- seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
503
- inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=1, maximum=50, step=1, value=30)
504
- guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=3.0)
505
-
506
- height_slider = gr.Slider(
507
- label="4.4 Height",
508
- minimum=256,
509
- maximum=1024,
510
- step=64,
511
- value=512,
512
- visible=False,
513
- )
514
- width_slider = gr.Slider(
515
- label="4.5 Width",
516
- minimum=256,
517
- maximum=1024,
518
- step=64,
519
- value=768,
520
- visible=False,
521
- )
522
- num_frames_slider = gr.Slider(
523
- label="4.5 Number of Frames",
524
- minimum=1,
525
- maximum=200,
526
- step=1,
527
- value=121,
528
- visible=False,
529
- )
530
-
531
- return [
532
- seed,
533
- inference_steps,
534
- guidance_scale,
535
- height_slider,
536
- width_slider,
537
- num_frames_slider,
538
- ]
539
-
540
-
541
- # Define the Gradio interface with tabs
542
- with gr.Blocks(theme=gr.themes.Soft()) as iface:
543
- with gr.Row(elem_id="title-row"):
544
- gr.Markdown(
545
- """
546
- <div style="text-align: center; margin-bottom: 1em">
547
- <h1 style="font-size: 2.5em; font-weight: 600; margin: 0.5em 0;">Video Generation with LTX Video</h1>
548
- </div>
549
- """
550
- )
551
- with gr.Row(elem_id="title-row"):
552
- gr.HTML( # add technical report link
553
- """
554
- <div style="display:flex;column-gap:4px;">
555
- <a href="https://github.com/Lightricks/LTX-Video">
556
- <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
557
- </a>
558
- <a href="https://github.com/Lightricks/ComfyUI-LTXVideo">
559
- <img src='https://img.shields.io/badge/GitHub-ComfyUI-blue'>
560
- </a>
561
- <a href="http://www.lightricks.com/ltxv">
562
- <img src="https://img.shields.io/badge/Project-Page-green" alt="Follow me on HF">
563
- </a>
564
- <a href="https://huggingface.co/spaces/Lightricks/LTX-Video-Playground?duplicate=true">
565
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
566
- </a>
567
- <a href="https://huggingface.co/Lightricks">
568
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
569
- </a>
570
- </div>
571
- """
572
- )
573
- with gr.Accordion(" 📖 Tips for Best Results", open=False, elem_id="instructions-accordion"):
574
- gr.Markdown(
575
- """
576
- 📝 Prompt Engineering
577
-
578
- When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words.
579
- For best results, build your prompts using this structure:
580
-
581
- - Start with main action in a single sentence
582
- - Add specific details about movements and gestures
583
- - Describe character/object appearances precisely
584
- - Include background and environment details
585
- - Specify camera angles and movements
586
- - Describe lighting and colors
587
- - Note any changes or sudden events
588
-
589
- See examples for more inspiration.
590
-
591
- 🎮 Parameter Guide
592
-
593
- - Resolution Preset: Higher resolutions for detailed scenes, lower for faster generation and simpler scenes
594
- - Seed: Save seed values to recreate specific styles or compositions you like
595
- - Guidance Scale: 3-3.5 are the recommended values
596
- - Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
597
- """
598
- )
599
-
600
- with gr.Tabs():
601
- # Text to Video Tab
602
- with gr.TabItem("Text to Video"):
603
- with gr.Row():
604
- with gr.Column():
605
- txt2vid_prompt = gr.Textbox(
606
- label="Step 1: Enter Your Prompt",
607
- placeholder="Describe the video you want to generate (minimum 50 characters)...",
608
- value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.",
609
- lines=5,
610
- )
611
- txt2vid_analytics_toggle = Toggle(
612
- label="I agree to share my usage data anonymously to help improve the model features.",
613
- value=True,
614
- interactive=True,
615
- )
616
-
617
- txt2vid_enhance_toggle = Toggle(
618
- label="Enhance Prompt",
619
- value=False,
620
- interactive=True,
621
- )
622
-
623
- txt2vid_negative_prompt = gr.Textbox(
624
- label="Step 2: Enter Negative Prompt",
625
- placeholder="Describe what you don't want in the video...",
626
- value="low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
627
- lines=2,
628
- )
629
-
630
- txt2vid_preset = gr.Dropdown(
631
- choices=[p["label"] for p in preset_options],
632
- value="768x512, 97 frames",
633
- label="Step 3.1: Choose Resolution Preset",
634
- )
635
-
636
- txt2vid_frame_rate = gr.Slider(
637
- label="Step 3.2: Frame Rate",
638
- minimum=21,
639
- maximum=30,
640
- step=1,
641
- value=25,
642
- )
643
-
644
- txt2vid_advanced = create_advanced_options()
645
- txt2vid_generate = gr.Button(
646
- "Step 5: Generate Video",
647
- variant="primary",
648
- size="lg",
649
- )
650
-
651
- with gr.Column():
652
- txt2vid_output = gr.Video(label="Generated Output")
653
-
654
- with gr.Row():
655
- gr.Examples(
656
- examples=[
657
- [
658
- "A young woman in a traditional Mongolian dress is peeking through a sheer white curtain, her face showing a mix of curiosity and apprehension. The woman has long black hair styled in two braids, adorned with white beads, and her eyes are wide with a hint of surprise. Her dress is a vibrant blue with intricate gold embroidery, and she wears a matching headband with a similar design. The background is a simple white curtain, which creates a sense of mystery and intrigue.ith long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair’s face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
659
- "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
660
- "assets/t2v_2.mp4",
661
- ],
662
- [
663
- "A young man with blond hair wearing a yellow jacket stands in a forest and looks around. He has light skin and his hair is styled with a middle part. He looks to the left and then to the right, his gaze lingering in each direction. The camera angle is low, looking up at the man, and remains stationary throughout the video. The background is slightly out of focus, with green trees and the sun shining brightly behind the man. The lighting is natural and warm, with the sun creating a lens flare that moves across the man’s face. The scene is captured in real-life footage.",
664
- "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
665
- "assets/t2v_1.mp4",
666
- ],
667
- [
668
- "A cyclist races along a winding mountain road. Clad in aerodynamic gear, he pedals intensely, sweat glistening on his brow. The camera alternates between close-ups of his determined expression and wide shots of the breathtaking landscape. Pine trees blur past, and the sky is a crisp blue. The scene is invigorating and competitive.",
669
- "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
670
- "assets/t2v_0.mp4",
671
- ],
672
- ],
673
- inputs=[txt2vid_prompt, txt2vid_negative_prompt, txt2vid_output],
674
- label="Example Text-to-Video Generations",
675
- )
676
-
677
- # Image to Video Tab
678
- with gr.TabItem("Image to Video"):
679
- with gr.Row():
680
- with gr.Column():
681
- img2vid_image = gr.Image(
682
- type="filepath",
683
- label="Step 1: Upload Input Image",
684
- elem_id="image_upload",
685
- )
686
- img2vid_prompt = gr.Textbox(
687
- label="Step 2: Enter Your Prompt",
688
- placeholder="Describe how you want to animate the image (minimum 50 characters)...",
689
- value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.",
690
- lines=5,
691
- )
692
- img2vid_analytics_toggle = Toggle(
693
- label="I agree to share my usage data anonymously to help improve the model features.",
694
- value=True,
695
- interactive=True,
696
- )
697
- img2vid_enhance_toggle = Toggle(
698
- label="Enhance Prompt",
699
- value=False,
700
- interactive=True,
701
- )
702
- img2vid_negative_prompt = gr.Textbox(
703
- label="Step 3: Enter Negative Prompt",
704
- placeholder="Describe what you don't want in the video...",
705
- value="low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
706
- lines=2,
707
- )
708
-
709
- img2vid_preset = gr.Dropdown(
710
- choices=[p["label"] for p in preset_options],
711
- value="768x512, 97 frames",
712
- label="Step 3.1: Choose Resolution Preset",
713
- )
714
-
715
- img2vid_frame_rate = gr.Slider(
716
- label="Step 3.2: Frame Rate",
717
- minimum=21,
718
- maximum=30,
719
- step=1,
720
- value=25,
721
- )
722
-
723
- img2vid_advanced = create_advanced_options()
724
- img2vid_generate = gr.Button("Step 6: Generate Video", variant="primary", size="lg")
725
-
726
- with gr.Column():
727
- img2vid_output = gr.Video(label="Generated Output")
728
-
729
- with gr.Row():
730
- gr.Examples(
731
- examples=[
732
- [
733
- "assets/i2v_i2.png",
734
- "A woman stirs a pot of boiling water on a white electric burner. Her hands, with purple nail polish, hold a wooden spoon and move it in a circular motion within a white pot filled with bubbling water. The pot sits on a white electric burner with black buttons and a digital display. The burner is positioned on a white countertop with a red and white checkered cloth partially visible in the bottom right corner. The camera angle is a direct overhead shot, remaining stationary throughout the scene. The lighting is bright and even, illuminating the scene with a neutral white light. The scene is real-life footage.",
735
- "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
736
- "assets/i2v_2.mp4",
737
- ],
738
- [
739
- "assets/i2v_i0.png",
740
- "A woman in a long, flowing dress stands in a field, her back to the camera, gazing towards the horizon; her hair is long and light, cascading down her back; she stands beneath the sprawling branches of a large oak tree; to her left, a classic American car is parked on the dry grass; in the distance, a wrecked car lies on its side; the sky above is a dramatic canvas of bright white clouds against a darker sky; the entire image is in black and white, emphasizing the contrast of light and shadow. The woman is walking slowly towards the car.",
741
- "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
742
- "assets/i2v_0.mp4",
743
- ],
744
- [
745
- "assets/i2v_i1.png",
746
- "A pair of hands shapes a piece of clay on a pottery wheel, gradually forming a cone shape. The hands, belonging to a person out of frame, are covered in clay and gently press a ball of clay onto the center of a spinning pottery wheel. The hands move in a circular motion, gradually forming a cone shape at the top of the clay. The camera is positioned directly above the pottery wheel, providing a bird’s-eye view of the clay being shaped. The lighting is bright and even, illuminating the clay and the hands working on it. The scene is captured in real-life footage.",
747
- "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
748
- "assets/i2v_1.mp4",
749
- ],
750
- ],
751
- inputs=[
752
- img2vid_image,
753
- img2vid_prompt,
754
- img2vid_negative_prompt,
755
- img2vid_output,
756
- ],
757
- label="Example Image-to-Video Generations",
758
- )
759
-
760
- # [Previous event handlers remain the same]
761
- txt2vid_preset.change(fn=preset_changed, inputs=[txt2vid_preset], outputs=txt2vid_advanced[3:])
762
-
763
- txt2vid_generate.click(
764
- fn=generate_video_from_text,
765
- inputs=[
766
- txt2vid_prompt,
767
- txt2vid_enhance_toggle,
768
- txt2vid_analytics_toggle,
769
- txt2vid_negative_prompt,
770
- txt2vid_frame_rate,
771
- *txt2vid_advanced,
772
- ],
773
- outputs=txt2vid_output,
774
- concurrency_limit=1,
775
- concurrency_id="generate_video",
776
  )
777
 
778
- img2vid_preset.change(fn=preset_changed, inputs=[img2vid_preset], outputs=img2vid_advanced[3:])
779
-
780
- img2vid_generate.click(
781
- fn=generate_video_from_image,
782
- inputs=[
783
- img2vid_image,
784
- img2vid_prompt,
785
- img2vid_enhance_toggle,
786
- img2vid_analytics_toggle,
787
- img2vid_negative_prompt,
788
- img2vid_frame_rate,
789
- *img2vid_advanced,
790
- ],
791
- outputs=img2vid_output,
792
- concurrency_limit=1,
793
- concurrency_id="generate_video",
794
  )
 
795
 
796
  if __name__ == "__main__":
797
- iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False)
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import torch
3
+ import spaces
4
  import numpy as np
5
+ import random
 
 
6
  import os
7
+ import yaml
8
+ from pathlib import Path
9
+ import imageio
10
+ import tempfile
11
+ from PIL import Image
12
+ from huggingface_hub import hf_hub_download
13
+ import shutil
14
+
15
+ from inference import (
16
+ create_ltx_video_pipeline,
17
+ create_latent_upsampler,
18
+ load_image_to_tensor_with_resize_and_crop,
19
+ seed_everething,
20
+ get_device,
21
+ calculate_padding,
22
+ load_media_file
23
+ )
24
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
25
  from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
26
 
27
+ config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml"
28
+ with open(config_file_path, "r") as file:
29
+ PIPELINE_CONFIG_YAML = yaml.safe_load(file)
30
+
31
+ LTX_REPO = "Lightricks/LTX-Video"
32
+ MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
33
+ MAX_NUM_FRAMES = 257
34
+
35
+ FPS = 30.0
36
+
37
+ # --- Global variables for loaded models ---
38
+ pipeline_instance = None
39
+ latent_upsampler_instance = None
40
+ models_dir = "downloaded_models_gradio_cpu_init"
41
+ Path(models_dir).mkdir(parents=True, exist_ok=True)
42
+
43
+ print("Downloading models (if not present)...")
44
+ distilled_model_actual_path = hf_hub_download(
45
+ repo_id=LTX_REPO,
46
+ filename=PIPELINE_CONFIG_YAML["checkpoint_path"],
47
+ local_dir=models_dir,
48
+ local_dir_use_symlinks=False
49
+ )
50
+ PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
51
+ print(f"Distilled model path: {distilled_model_actual_path}")
52
+
53
+ SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
54
+ spatial_upscaler_actual_path = hf_hub_download(
55
+ repo_id=LTX_REPO,
56
+ filename=SPATIAL_UPSCALER_FILENAME,
57
+ local_dir=models_dir,
58
+ local_dir_use_symlinks=False
59
+ )
60
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
61
+ print(f"Spatial upscaler model path: {spatial_upscaler_actual_path}")
62
+
63
+ print("Creating LTX Video pipeline on CPU...")
64
+ pipeline_instance = create_ltx_video_pipeline(
65
+ ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
66
+ precision=PIPELINE_CONFIG_YAML["precision"],
67
+ text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
68
+ sampler=PIPELINE_CONFIG_YAML["sampler"],
69
+ device="cpu",
70
+ enhance_prompt=False,
71
+ prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
72
+ prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
73
+ )
74
+ print("LTX Video pipeline created on CPU.")
75
+
76
+ if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
77
+ print("Creating latent upsampler on CPU...")
78
+ latent_upsampler_instance = create_latent_upsampler(
79
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"],
80
+ device="cpu"
81
+ )
82
+ print("Latent upsampler created on CPU.")
83
 
84
+ target_inference_device = "cuda"
85
+ print(f"Target inference device: {target_inference_device}")
86
+ pipeline_instance.to(target_inference_device)
87
+ if latent_upsampler_instance:
88
+ latent_upsampler_instance.to(target_inference_device)
89
 
90
 
91
+ # --- Helper function for dimension calculation ---
92
+ MIN_DIM_SLIDER = 256 # As defined in the sliders minimum attribute
93
+ TARGET_FIXED_SIDE = 768 # Desired fixed side length as per requirement
94
+
95
+ def calculate_new_dimensions(orig_w, orig_h):
96
  """
97
+ Calculates new dimensions for height and width sliders based on original media dimensions.
98
+ Ensures one side is TARGET_FIXED_SIDE, the other is scaled proportionally,
99
+ both are multiples of 32, and within [MIN_DIM_SLIDER, MAX_IMAGE_SIZE].
 
 
 
100
  """
101
+ if orig_w == 0 or orig_h == 0:
102
+ # Default to TARGET_FIXED_SIDE square if original dimensions are invalid
103
+ return int(TARGET_FIXED_SIDE), int(TARGET_FIXED_SIDE)
104
+
105
+ if orig_w >= orig_h: # Landscape or square
106
+ new_h = TARGET_FIXED_SIDE
107
+ aspect_ratio = orig_w / orig_h
108
+ new_w_ideal = new_h * aspect_ratio
109
+
110
+ # Round to nearest multiple of 32
111
+ new_w = round(new_w_ideal / 32) * 32
112
+
113
+ # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
114
+ new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
115
+ # Ensure new_h is also clamped (TARGET_FIXED_SIDE should be within these bounds if configured correctly)
116
+ new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
117
+ else: # Portrait
118
+ new_w = TARGET_FIXED_SIDE
119
+ aspect_ratio = orig_h / orig_w # Use H/W ratio for portrait scaling
120
+ new_h_ideal = new_w * aspect_ratio
121
+
122
+ # Round to nearest multiple of 32
123
+ new_h = round(new_h_ideal / 32) * 32
124
+
125
+ # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
126
+ new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
127
+ # Ensure new_w is also clamped
128
+ new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
129
+
130
+ return int(new_h), int(new_w)
131
+
132
+ def get_duration(prompt, negative_prompt, input_image_filepath, input_video_filepath,
133
+ height_ui, width_ui, mode,
134
+ duration_ui, # Removed ui_steps
135
+ ui_frames_to_use,
136
+ seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag,
137
+ progress):
138
+ if duration_ui > 7:
139
+ return 75
140
+ else:
141
+ return 60
142
+
143
+ @spaces.GPU(duration=get_duration)
144
+ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_filepath=None,
145
+ height_ui=512, width_ui=704, mode="text-to-video",
146
+ duration_ui=2.0,
147
+ ui_frames_to_use=9,
148
+ seed_ui=42, randomize_seed=True, ui_guidance_scale=3.0, improve_texture_flag=True,
149
+ progress=gr.Progress(track_tqdm=True)):
150
+ """
151
+ Generate high-quality videos using LTX Video model with support for text-to-video, image-to-video, and video-to-video modes.
152
 
153
+ Args:
154
+ prompt (str): Text description of the desired video content. Required for all modes.
155
+ negative_prompt (str): Text describing what to avoid in the generated video. Optional, can be empty string.
156
+ input_image_filepath (str or None): Path to input image file. Required for image-to-video mode, None for other modes.
157
+ input_video_filepath (str or None): Path to input video file. Required for video-to-video mode, None for other modes.
158
+ height_ui (int): Height of the output video in pixels, must be divisible by 32. Default: 512.
159
+ width_ui (int): Width of the output video in pixels, must be divisible by 32. Default: 704.
160
+ mode (str): Generation mode. Required. One of "text-to-video", "image-to-video", or "video-to-video". Default: "text-to-video".
161
+ duration_ui (float): Duration of the output video in seconds. Range: 0.3 to 8.5. Default: 2.0.
162
+ ui_frames_to_use (int): Number of frames to use from input video. Only used in video-to-video mode. Must be N*8+1. Default: 9.
163
+ seed_ui (int): Random seed for reproducible generation. Range: 0 to 2^32-1. Default: 42.
164
+ randomize_seed (bool): Whether to use a random seed instead of seed_ui. Default: True.
165
+ ui_guidance_scale (float): CFG scale controlling prompt influence. Range: 1.0 to 10.0. Higher values = stronger prompt influence. Default: 3.0.
166
+ improve_texture_flag (bool): Whether to use multi-scale generation for better texture quality. Slower but higher quality. Default: True.
167
+ progress (gr.Progress): Progress tracker for the generation process. Optional, used for UI updates.
168
 
169
+ Returns:
170
+ tuple: A tuple containing (output_video_path, used_seed) where output_video_path is the path to the generated video file and used_seed is the actual seed used for generation.
171
+ """
172
 
173
+ # Validate mode-specific required parameters
174
+ if mode == "image-to-video":
175
+ if not input_image_filepath:
176
+ raise gr.Error("input_image_filepath is required for image-to-video mode")
177
+ elif mode == "video-to-video":
178
+ if not input_video_filepath:
179
+ raise gr.Error("input_video_filepath is required for video-to-video mode")
180
+ elif mode == "text-to-video":
181
+ # No additional file inputs required for text-to-video
182
+ pass
183
  else:
184
+ raise gr.Error(f"Invalid mode: {mode}. Must be one of: text-to-video, image-to-video, video-to-video")
185
+
186
+ if randomize_seed:
187
+ seed_ui = random.randint(0, 2**32 - 1)
188
+ seed_everething(int(seed_ui))
189
+
190
+ target_frames_ideal = duration_ui * FPS
191
+ target_frames_rounded = round(target_frames_ideal)
192
+ if target_frames_rounded < 1:
193
+ target_frames_rounded = 1
194
+
195
+ n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
196
+ actual_num_frames = int(n_val * 8 + 1)
197
+
198
+ actual_num_frames = max(9, actual_num_frames)
199
+ actual_num_frames = min(MAX_NUM_FRAMES, actual_num_frames)
200
+
201
+ actual_height = int(height_ui)
202
+ actual_width = int(width_ui)
203
+
204
+ height_padded = ((actual_height - 1) // 32 + 1) * 32
205
+ width_padded = ((actual_width - 1) // 32 + 1) * 32
206
+ num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
207
+ if num_frames_padded != actual_num_frames:
208
+ print(f"Warning: actual_num_frames ({actual_num_frames}) and num_frames_padded ({num_frames_padded}) differ. Using num_frames_padded for pipeline.")
209
+
210
+ padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
211
+
212
+ call_kwargs = {
213
  "prompt": prompt,
 
214
  "negative_prompt": negative_prompt,
215
+ "height": height_padded,
216
+ "width": width_padded,
217
+ "num_frames": num_frames_padded,
218
+ "frame_rate": int(FPS),
219
+ "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
220
+ "output_type": "pt",
221
+ "conditioning_items": None,
222
  "media_items": None,
223
+ "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
224
+ "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
225
+ "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"],
226
+ "image_cond_noise_scale": 0.15,
227
+ "is_video": True,
228
+ "vae_per_channel_normalize": True,
229
+ "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
230
+ "offload_to_cpu": False,
231
+ "enhance_prompt": False,
232
  }
233
 
234
+ stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values")
235
+ if stg_mode_str.lower() in ["stg_av", "attention_values"]:
236
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionValues
237
+ elif stg_mode_str.lower() in ["stg_as", "attention_skip"]:
238
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionSkip
239
+ elif stg_mode_str.lower() in ["stg_r", "residual"]:
240
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.Residual
241
+ elif stg_mode_str.lower() in ["stg_t", "transformer_block"]:
242
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.TransformerBlock
243
+ else:
244
+ raise ValueError(f"Invalid stg_mode: {stg_mode_str}")
245
 
246
+ if mode == "image-to-video" and input_image_filepath:
247
+ try:
248
+ media_tensor = load_image_to_tensor_with_resize_and_crop(
249
+ input_image_filepath, actual_height, actual_width
250
+ )
251
+ media_tensor = torch.nn.functional.pad(media_tensor, padding_values)
252
+ call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_inference_device), 0, 1.0)]
253
+ except Exception as e:
254
+ print(f"Error loading image {input_image_filepath}: {e}")
255
+ raise gr.Error(f"Could not load image: {e}")
256
+ elif mode == "video-to-video" and input_video_filepath:
257
+ try:
258
+ call_kwargs["media_items"] = load_media_file(
259
+ media_path=input_video_filepath,
260
+ height=actual_height,
261
+ width=actual_width,
262
+ max_frames=int(ui_frames_to_use),
263
+ padding=padding_values
264
+ ).to(target_inference_device)
265
+ except Exception as e:
266
+ print(f"Error loading video {input_video_filepath}: {e}")
267
+ raise gr.Error(f"Could not load video: {e}")
268
+
269
+ print(f"Moving models to {target_inference_device} for inference (if not already there)...")
270
+
271
+ active_latent_upsampler = None
272
+ if improve_texture_flag and latent_upsampler_instance:
273
+ active_latent_upsampler = latent_upsampler_instance
274
+
275
+ result_images_tensor = None
276
+ if improve_texture_flag:
277
+ if not active_latent_upsampler:
278
+ raise gr.Error("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")
279
+
280
+ multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
281
+
282
+ first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
283
+ first_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
284
+ # num_inference_steps will be derived from len(timesteps) in the pipeline
285
+ first_pass_args.pop("num_inference_steps", None)
286
+
287
+
288
+ second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
289
+ second_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
290
+ # num_inference_steps will be derived from len(timesteps) in the pipeline
291
+ second_pass_args.pop("num_inference_steps", None)
292
+
293
+ multi_scale_call_kwargs = call_kwargs.copy()
294
+ multi_scale_call_kwargs.update({
295
+ "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
296
+ "first_pass": first_pass_args,
297
+ "second_pass": second_pass_args,
298
+ })
299
+
300
+ print(f"Calling multi-scale pipeline (eff. HxW: {actual_height}x{actual_width}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
301
+ result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
302
+ else:
303
+ single_pass_call_kwargs = call_kwargs.copy()
304
+ first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
305
+
306
+ single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
307
+ single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
308
+ single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
309
+ single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
310
+ single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
311
+
312
+ # Remove keys that might conflict or are not used in single pass / handled by above
313
+ single_pass_call_kwargs.pop("num_inference_steps", None)
314
+ single_pass_call_kwargs.pop("first_pass", None)
315
+ single_pass_call_kwargs.pop("second_pass", None)
316
+ single_pass_call_kwargs.pop("downscale_factor", None)
317
+
318
+ print(f"Calling base pipeline (padded HxW: {height_padded}x{width_padded}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
319
+ result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
320
+
321
+ if result_images_tensor is None:
322
+ raise gr.Error("Generation failed.")
323
+
324
+ pad_left, pad_right, pad_top, pad_bottom = padding_values
325
+ slice_h_end = -pad_bottom if pad_bottom > 0 else None
326
+ slice_w_end = -pad_right if pad_right > 0 else None
327
+
328
+ result_images_tensor = result_images_tensor[
329
+ :, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
330
+ ]
331
 
332
+ video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
333
+
334
+ video_np = np.clip(video_np, 0, 1)
335
  video_np = (video_np * 255).astype(np.uint8)
336
 
337
+ temp_dir = tempfile.mkdtemp()
338
+ timestamp = random.randint(10000,99999)
339
+ output_video_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
340
+
341
  try:
342
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
343
+ for frame_idx in range(video_np.shape[0]):
344
+ progress(frame_idx / video_np.shape[0], desc="Saving video")
345
+ video_writer.append_data(video_np[frame_idx])
346
  except Exception as e:
347
+ print(f"Error saving video with macro_block_size=1: {e}")
348
+ try:
349
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264', quality=8) as video_writer:
350
+ for frame_idx in range(video_np.shape[0]):
351
+ progress(frame_idx / video_np.shape[0], desc="Saving video (fallback ffmpeg)")
352
+ video_writer.append_data(video_np[frame_idx])
353
+ except Exception as e2:
354
+ print(f"Fallback video saving error: {e2}")
355
+ raise gr.Error(f"Failed to save video: {e2}")
356
+
357
+ return output_video_path, seed_ui
358
+
359
+ def update_task_image():
360
+ return "image-to-video"
361
+
362
+ def update_task_text():
363
+ return "text-to-video"
364
+
365
+ def update_task_video():
366
+ return "video-to-video"
367
+
368
+ # --- Gradio UI Definition ---
369
+ css="""
370
+ #col-container {
371
+ margin: 0 auto;
372
+ max-width: 900px;
373
+ }
374
+ """
375
+
376
+ with gr.Blocks(css=css) as demo:
377
+ gr.Markdown("# LTX Video 0.9.8 13B Distilled")
378
+ gr.Markdown("Fast high quality video generation.**Update (17/07):** now with the new v0.9.8 for improved prompt understanding and detail generation" )
379
+ gr.Markdown("[Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.8-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](https://huggingface.co/Lightricks/LTX-Video-0.9.8-13B-distilled#diffusers-🧨)")
380
+ with gr.Row():
381
+ with gr.Column():
382
+ with gr.Tab("image-to-video") as image_tab:
383
+ video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
384
+ image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "webcam", "clipboard"])
385
+ i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3)
386
+ i2v_button = gr.Button("Generate Image-to-Video", variant="primary")
387
+ with gr.Tab("text-to-video") as text_tab:
388
+ image_n_hidden = gr.Textbox(label="image_n", visible=False, value=None)
389
+ video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
390
+ t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3)
391
+ t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
392
+ with gr.Tab("video-to-video", visible=False) as video_tab:
393
+ image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
394
+ video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"]) # type defaults to filepath
395
+ frames_to_use = gr.Slider(label="Frames to use from input video", minimum=9, maximum=MAX_NUM_FRAMES, value=9, step=8, info="Number of initial frames to use for conditioning/transformation. Must be N*8+1.")
396
+ v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3)
397
+ v2v_button = gr.Button("Generate Video-to-Video", variant="primary")
398
+
399
+ duration_input = gr.Slider(
400
+ label="Video Duration (seconds)",
401
+ minimum=0.3,
402
+ maximum=8.5,
403
+ value=2,
404
+ step=0.1,
405
+ info=f"Target video duration (0.3s to 8.5s)"
406
+ )
407
+ improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True,visible=False, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
408
+
409
+ with gr.Column():
410
+ output_video = gr.Video(label="Generated Video", interactive=False)
411
+ # gr.DeepLinkButton()
412
+
413
+ with gr.Accordion("Advanced settings", open=False):
414
+ mode = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="task", value="image-to-video", visible=False)
415
+ negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
416
+ with gr.Row():
417
+ seed_input = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=2**32-1)
418
+ randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
419
+ with gr.Row(visible=False):
420
+ guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1, info="Controls how much the prompt influences the output. Higher values = stronger influence.")
421
+ with gr.Row():
422
+ height_input = gr.Slider(label="Height", value=512, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
423
+ width_input = gr.Slider(label="Width", value=704, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
424
+
425
+
426
+ # --- Event handlers for updating dimensions on upload ---
427
+ def handle_image_upload_for_dims(image_filepath, current_h, current_w):
428
+ if not image_filepath: # Image cleared or no image initially
429
+ # Keep current slider values if image is cleared or no input
430
+ return gr.update(value=current_h), gr.update(value=current_w)
431
+ try:
432
+ img = Image.open(image_filepath)
433
+ orig_w, orig_h = img.size
434
+ new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
435
+ return gr.update(value=new_h), gr.update(value=new_w)
436
+ except Exception as e:
437
+ print(f"Error processing image for dimension update: {e}")
438
+ # Keep current slider values on error
439
+ return gr.update(value=current_h), gr.update(value=current_w)
440
+
441
+ def handle_video_upload_for_dims(video_filepath, current_h, current_w):
442
+ if not video_filepath: # Video cleared or no video initially
443
+ return gr.update(value=current_h), gr.update(value=current_w)
444
+ try:
445
+ # Ensure video_filepath is a string for os.path.exists and imageio
446
+ video_filepath_str = str(video_filepath)
447
+ if not os.path.exists(video_filepath_str):
448
+ print(f"Video file path does not exist for dimension update: {video_filepath_str}")
449
+ return gr.update(value=current_h), gr.update(value=current_w)
450
+
451
+ orig_w, orig_h = -1, -1
452
+ with imageio.get_reader(video_filepath_str) as reader:
453
+ meta = reader.get_meta_data()
454
+ if 'size' in meta:
455
+ orig_w, orig_h = meta['size']
456
+ else:
457
+ # Fallback: read first frame if 'size' not in metadata
458
+ try:
459
+ first_frame = reader.get_data(0)
460
+ # Shape is (h, w, c) for frames
461
+ orig_h, orig_w = first_frame.shape[0], first_frame.shape[1]
462
+ except Exception as e_frame:
463
+ print(f"Could not get video size from metadata or first frame: {e_frame}")
464
+ return gr.update(value=current_h), gr.update(value=current_w)
465
+
466
+ if orig_w == -1 or orig_h == -1: # If dimensions couldn't be determined
467
+ print(f"Could not determine dimensions for video: {video_filepath_str}")
468
+ return gr.update(value=current_h), gr.update(value=current_w)
469
+
470
+ new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
471
+ return gr.update(value=new_h), gr.update(value=new_w)
472
+ except Exception as e:
473
+ # Log type of video_filepath for debugging if it's not a path-like string
474
+ print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
475
+ return gr.update(value=current_h), gr.update(value=current_w)
476
+
477
+
478
+ image_i2v.upload(
479
+ fn=handle_image_upload_for_dims,
480
+ inputs=[image_i2v, height_input, width_input],
481
+ outputs=[height_input, width_input]
482
+ )
483
+ video_v2v.upload(
484
+ fn=handle_video_upload_for_dims,
485
+ inputs=[video_v2v, height_input, width_input],
486
+ outputs=[height_input, width_input]
 
 
487
  )
488
 
489
+ image_tab.select(
490
+ fn=update_task_image,
491
+ outputs=[mode]
492
+ )
493
+ text_tab.select(
494
+ fn=update_task_text,
495
+ outputs=[mode]
 
 
496
  )
497
+
498
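+ # All three tasks call the same generate() function; the hidden textboxes stand in for the image/video inputs a given task does not use.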
+ t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
499
+ height_input, width_input, mode,
500
+ duration_input, frames_to_use,
501
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
502
+
503
+ i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
504
+ height_input, width_input, mode,
505
+ duration_input, frames_to_use,
506
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
507
+
508
+ v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
509
+ height_input, width_input, mode,
510
+ duration_input, frames_to_use,
511
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
512
+
513
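+ # Wire each tab's button to the shared generate() function; api_name exposes each task as a named API endpoint.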
+ t2v_button.click(fn=generate, inputs=t2v_inputs, outputs=[output_video, seed_input], api_name="text_to_video")
514
+ i2v_button.click(fn=generate, inputs=i2v_inputs, outputs=[output_video, seed_input], api_name="image_to_video")
515
+ v2v_button.click(fn=generate, inputs=v2v_inputs, outputs=[output_video, seed_input], api_name="video_to_video")
516
 
517
  if __name__ == "__main__":
518
+ if os.path.exists(models_dir) and os.path.isdir(models_dir):
519
+ print(f"Model directory: {Path(models_dir).resolve()}")
520
+
521
+ demo.queue().launch(debug=True, share=False, mcp_server=True)
inference.py CHANGED
@@ -1,162 +1,206 @@
1
- import torch
2
- from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
3
- from xora.models.transformers.transformer3d import Transformer3DModel
4
- from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
- from xora.schedulers.rf import RectifiedFlowScheduler
6
- from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
7
- from pathlib import Path
8
- from transformers import T5EncoderModel, T5Tokenizer
9
- import safetensors.torch
10
- import json
11
  import argparse
12
- from xora.utils.conditioning_method import ConditioningMethod
13
  import os
 
 
 
14
  import numpy as np
 
15
  import cv2
 
16
  from PIL import Image
17
- import random
18
-
19
- RECOMMENDED_RESOLUTIONS = [
20
- (704, 1216, 41),
21
- (704, 1088, 49),
22
- (640, 1056, 57),
23
- (608, 992, 65),
24
- (608, 896, 73),
25
- (544, 896, 81),
26
- (544, 832, 89),
27
- (512, 800, 97),
28
- (512, 768, 97),
29
- (480, 800, 105),
30
- (480, 736, 113),
31
- (480, 704, 121),
32
- (448, 704, 129),
33
- (448, 672, 137),
34
- (416, 640, 153),
35
- (384, 672, 161),
36
- (384, 640, 169),
37
- (384, 608, 177),
38
- (384, 576, 185),
39
- (352, 608, 193),
40
- (352, 576, 201),
41
- (352, 544, 209),
42
- (352, 512, 225),
43
- (352, 512, 233),
44
- (320, 544, 241),
45
- (320, 512, 249),
46
- (320, 512, 257),
47
- ]
48
-
49
-
50
- def load_vae(vae_dir):
51
- vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
52
- vae_config_path = vae_dir / "config.json"
53
- with open(vae_config_path, "r") as f:
54
- vae_config = json.load(f)
55
- vae = CausalVideoAutoencoder.from_config(vae_config)
56
- vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
57
- vae.load_state_dict(vae_state_dict)
58
  if torch.cuda.is_available():
59
- vae = vae.cuda()
60
- return vae.to(torch.bfloat16)
 
61
 
62
 
63
- def load_unet(unet_dir):
64
- unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
65
- unet_config_path = unet_dir / "config.json"
66
- transformer_config = Transformer3DModel.load_config(unet_config_path)
67
- transformer = Transformer3DModel.from_config(transformer_config)
68
- unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
69
- transformer.load_state_dict(unet_state_dict, strict=True)
70
  if torch.cuda.is_available():
71
- transformer = transformer.cuda()
72
- return transformer
73
-
74
-
75
- def load_scheduler(scheduler_dir):
76
- scheduler_config_path = scheduler_dir / "scheduler_config.json"
77
- scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
78
- return RectifiedFlowScheduler.from_config(scheduler_config)
79
-
 
80
 
81
- def center_crop_and_resize(frame, target_height, target_width):
82
- h, w, _ = frame.shape
83
  aspect_ratio_target = target_width / target_height
84
- aspect_ratio_frame = w / h
85
  if aspect_ratio_frame > aspect_ratio_target:
86
- new_width = int(h * aspect_ratio_target)
87
- x_start = (w - new_width) // 2
88
- frame_cropped = frame[:, x_start : x_start + new_width]
 
89
  else:
90
- new_height = int(w / aspect_ratio_target)
91
- y_start = (h - new_height) // 2
92
- frame_cropped = frame[y_start : y_start + new_height, :]
93
- frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
94
- return frame_resized
95
-
96
-
97
- def load_video_to_tensor_with_resize(video_path, target_height, target_width):
98
- cap = cv2.VideoCapture(video_path)
99
- frames = []
100
- while True:
101
- ret, frame = cap.read()
102
- if not ret:
103
- break
104
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
105
- if target_height is not None:
106
- frame_resized = center_crop_and_resize(
107
- frame_rgb, target_height, target_width
108
- )
109
- else:
110
- frame_resized = frame_rgb
111
- frames.append(frame_resized)
112
- cap.release()
113
- video_np = (np.array(frames) / 127.5) - 1.0
114
- video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
115
- return video_tensor
116
-
117
-
118
- def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
119
- image = Image.open(image_path).convert("RGB")
120
- image_np = np.array(image)
121
- frame_resized = center_crop_and_resize(image_np, target_height, target_width)
122
- frame_tensor = torch.tensor(frame_resized).permute(2, 0, 1).float()
123
  frame_tensor = (frame_tensor / 127.5) - 1.0
124
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
125
  return frame_tensor.unsqueeze(0).unsqueeze(2)
126
 
127
 
 
 
128
  def main():
129
  parser = argparse.ArgumentParser(
130
  description="Load models from separate directories and run the pipeline."
131
  )
132
 
133
  # Directories
134
- parser.add_argument(
135
- "--ckpt_dir",
136
- type=str,
137
- required=True,
138
- help="Path to the directory containing unet, vae, and scheduler subdirectories",
139
- )
140
- parser.add_argument(
141
- "--input_video_path",
142
- type=str,
143
- help="Path to the input video file (first frame used)",
144
- )
145
- parser.add_argument(
146
- "--input_image_path", type=str, help="Path to the input image file"
147
- )
148
  parser.add_argument(
149
  "--output_path",
150
  type=str,
151
  default=None,
152
- help="Path to save output video, if None will save in working directory.",
153
  )
154
  parser.add_argument("--seed", type=int, default="171198")
155
 
156
  # Pipeline parameters
157
- parser.add_argument(
158
- "--num_inference_steps", type=int, default=40, help="Number of inference steps"
159
- )
160
  parser.add_argument(
161
  "--num_images_per_prompt",
162
  type=int,
@@ -164,21 +208,21 @@ def main():
164
  help="Number of images per prompt",
165
  )
166
  parser.add_argument(
167
- "--guidance_scale",
168
  type=float,
169
- default=3,
170
- help="Guidance scale for the pipeline",
171
  )
172
  parser.add_argument(
173
  "--height",
174
  type=int,
175
- default=None,
176
  help="Height of the output video frames. Optional if an input image provided.",
177
  )
178
  parser.add_argument(
179
  "--width",
180
  type=int,
181
- default=None,
182
  help="Width of the output video frames. If None will infer from input image.",
183
  )
184
  parser.add_argument(
@@ -188,13 +232,18 @@ def main():
188
  help="Number of frames to generate in the output video",
189
  )
190
  parser.add_argument(
191
- "--frame_rate", type=int, default=25, help="Frame rate for the output video"
192
  )
193
-
194
  parser.add_argument(
195
- "--bfloat16",
196
- action="store_true",
197
- help="Denoise in bfloat16",
 
 
198
  )
199
 
200
  # Prompts
@@ -209,161 +258,517 @@ def main():
209
  default="worst quality, inconsistent motion, blurry, jittery, distorted",
210
  help="Negative prompt for undesired features",
211
  )
 
212
  parser.add_argument(
213
- "--custom_resolution",
214
  action="store_true",
215
- default=False,
216
- help="Enable custom resolution (not in recommneded resolutions) if specified (default: False)",
217
  )
218
 
219
- args = parser.parse_args()
 
 
220
 
221
- if args.input_image_path is None and args.input_video_path is None:
222
- assert (
223
- args.height is not None and args.width is not None
224
- ), "Must enter height and width for text to image generation."
225
-
226
- # Load media (video or image)
227
- if args.input_video_path:
228
- media_items = load_video_to_tensor_with_resize(
229
- args.input_video_path, args.height, args.width
230
- ).unsqueeze(0)
231
- elif args.input_image_path:
232
- media_items = load_image_to_tensor_with_resize(
233
- args.input_image_path, args.height, args.width
234
- )
 
 
235
  else:
236
- media_items = None
237
-
238
- height = args.height if args.height else media_items.shape[-2]
239
- width = args.width if args.width else media_items.shape[-1]
240
- assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
241
- assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
242
- assert (
243
- height,
244
- width,
245
- args.num_frames,
246
- ) in RECOMMENDED_RESOLUTIONS or args.custom_resolution, f"The selected resolution + num frames combination is not supported, results would be suboptimal. Supported (h,w,f) are: {RECOMMENDED_RESOLUTIONS}. Use --custom_resolution to enable working with this resolution."
247
-
248
- # Paths for the separate mode directories
249
- ckpt_dir = Path(args.ckpt_dir)
250
- unet_dir = ckpt_dir / "unet"
251
- vae_dir = ckpt_dir / "vae"
252
- scheduler_dir = ckpt_dir / "scheduler"
253
-
254
- # Load models
255
- vae = load_vae(vae_dir)
256
- unet = load_unet(unet_dir)
257
- scheduler = load_scheduler(scheduler_dir)
258
- patchifier = SymmetricPatchifier(patch_size=1)
259
  text_encoder = T5EncoderModel.from_pretrained(
260
- "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
261
  )
262
- if torch.cuda.is_available():
263
- text_encoder = text_encoder.to("cuda")
264
  tokenizer = T5Tokenizer.from_pretrained(
265
- "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
266
  )
267
 
268
- if args.bfloat16 and unet.dtype != torch.bfloat16:
269
- unet = unet.to(torch.bfloat16)
 
 
270
 
271
  # Use submodels for the pipeline
272
  submodel_dict = {
273
- "transformer": unet,
274
  "patchifier": patchifier,
275
  "text_encoder": text_encoder,
276
  "tokenizer": tokenizer,
277
  "scheduler": scheduler,
278
  "vae": vae,
 
 
279
  }
280
 
281
- pipeline = XoraVideoPipeline(**submodel_dict)
282
- if torch.cuda.is_available():
283
- pipeline = pipeline.to("cuda")
 
 
284
 
285
  # Prepare input for the pipeline
286
  sample = {
287
- "prompt": args.prompt,
288
  "prompt_attention_mask": None,
289
- "negative_prompt": args.negative_prompt,
290
  "negative_prompt_attention_mask": None,
291
- "media_items": media_items,
292
  }
293
 
294
- random.seed(args.seed)
295
- np.random.seed(args.seed)
296
- torch.manual_seed(args.seed)
297
- if torch.cuda.is_available():
298
- torch.cuda.manual_seed(args.seed)
299
-
300
- generator = torch.Generator(
301
- device="cuda" if torch.cuda.is_available() else "cpu"
302
- ).manual_seed(args.seed)
303
 
304
  images = pipeline(
305
- num_inference_steps=args.num_inference_steps,
306
- num_images_per_prompt=args.num_images_per_prompt,
307
- guidance_scale=args.guidance_scale,
308
  generator=generator,
309
  output_type="pt",
310
  callback_on_step_end=None,
311
- height=height,
312
- width=width,
313
- num_frames=args.num_frames,
314
- frame_rate=args.frame_rate,
315
  **sample,
 
 
316
  is_video=True,
317
  vae_per_channel_normalize=True,
318
- conditioning_method=(
319
- ConditioningMethod.FIRST_FRAME
320
- if media_items is not None
321
- else ConditioningMethod.UNCONDITIONAL
322
- ),
323
- mixed_precision=not args.bfloat16,
324
  ).images
325
 
326
- # Save output video
327
- def get_unique_filename(base, ext, dir=".", index_range=1000):
328
- for i in range(index_range):
329
- filename = os.path.join(dir, f"{base}_{i}{ext}")
330
- if not os.path.exists(filename):
331
- return filename
332
- raise FileExistsError(
333
- f"Could not find a unique filename after {index_range} attempts."
334
- )
335
 
336
  for i in range(images.shape[0]):
337
  # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
338
  video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
339
  # Unnormalizing images to [0, 255] range
340
  video_np = (video_np * 255).astype(np.uint8)
341
- fps = args.frame_rate
342
  height, width = video_np.shape[1:3]
 
343
  if video_np.shape[0] == 1:
344
- output_filename = (
345
- args.output_path
346
- if args.output_path is not None
347
- else get_unique_filename(f"image_output_{i}", ".png", ".")
 
 
 
348
  )
349
- cv2.imwrite(
350
- output_filename, video_np[0][..., ::-1]
351
- ) # Save single frame as image
352
  else:
353
- output_filename = (
354
- args.output_path
355
- if args.output_path is not None
356
- else get_unique_filename(f"video_output_{i}", ".mp4", ".")
 
 
 
357
  )
358
 
359
- out = cv2.VideoWriter(
360
- output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
 
 
361
  )
 
 
362
 
363
- for frame in video_np[..., ::-1]:
364
- out.write(frame)
365
- out.release()
 
 
366
 
367
 
368
  if __name__ == "__main__":
369
- main()
 
 
1
  import argparse
 
2
  import os
3
+ import random
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from diffusers.utils import logging
7
+ from typing import Optional, List, Union
8
+ import yaml
9
+
10
+ import imageio
11
+ import json
12
  import numpy as np
13
+ import torch
14
  import cv2
15
+ from safetensors import safe_open
16
  from PIL import Image
17
+ from transformers import (
18
+ T5EncoderModel,
19
+ T5Tokenizer,
20
+ AutoModelForCausalLM,
21
+ AutoProcessor,
22
+ AutoTokenizer,
23
+ )
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
27
+ CausalVideoAutoencoder,
28
+ )
29
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
30
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
31
+ from ltx_video.pipelines.pipeline_ltx_video import (
32
+ ConditioningItem,
33
+ LTXVideoPipeline,
34
+ LTXMultiScalePipeline,
35
+ )
36
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
37
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
38
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
39
+ import ltx_video.pipelines.crf_compressor as crf_compressor
40
+
41
+ MAX_HEIGHT = 720
42
+ MAX_WIDTH = 1280
43
+ MAX_NUM_FRAMES = 257
44
+
45
+ logger = logging.get_logger("LTX-Video")
46
+
47
+
48
+ def get_total_gpu_memory():
 
 
49
  if torch.cuda.is_available():
50
+ total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
51
+ return total_memory
52
+ return 0
53
 
54
 
55
+ def get_device():
 
 
56
  if torch.cuda.is_available():
57
+ return "cuda"
58
+ elif torch.backends.mps.is_available():
59
+ return "mps"
60
+ return "cpu"
61
+
62
+
63
+ def load_image_to_tensor_with_resize_and_crop(
64
+ image_input: Union[str, Image.Image],
65
+ target_height: int = 512,
66
+ target_width: int = 768,
67
+ just_crop: bool = False,
68
+ ) -> torch.Tensor:
69
+ """Load and process an image into a tensor.
70
+
71
+ Args:
72
+ image_input: Either a file path (str) or a PIL Image object
73
+ target_height: Desired height of output tensor
74
+ target_width: Desired width of output tensor
75
+ just_crop: If True, only crop the image to the target size without resizing
76
+ """
77
+ if isinstance(image_input, str):
78
+ image = Image.open(image_input).convert("RGB")
79
+ elif isinstance(image_input, Image.Image):
80
+ image = image_input
81
+ else:
82
+ raise ValueError("image_input must be either a file path or a PIL Image object")
83
 
84
+ input_width, input_height = image.size
 
85
  aspect_ratio_target = target_width / target_height
86
+ aspect_ratio_frame = input_width / input_height
87
  if aspect_ratio_frame > aspect_ratio_target:
88
+ new_width = int(input_height * aspect_ratio_target)
89
+ new_height = input_height
90
+ x_start = (input_width - new_width) // 2
91
+ y_start = 0
92
  else:
93
+ new_width = input_width
94
+ new_height = int(input_width / aspect_ratio_target)
95
+ x_start = 0
96
+ y_start = (input_height - new_height) // 2
97
+
98
+ image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
99
+ if not just_crop:
100
+ image = image.resize((target_width, target_height))
101
+
102
+ image = np.array(image)
103
+ image = cv2.GaussianBlur(image, (3, 3), 0)
104
+ frame_tensor = torch.from_numpy(image).float()
105
+ frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
106
+ frame_tensor = frame_tensor.permute(2, 0, 1)
 
 
107
  frame_tensor = (frame_tensor / 127.5) - 1.0
108
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
109
  return frame_tensor.unsqueeze(0).unsqueeze(2)
110
 
111
 
112
+ def calculate_padding(
113
+ source_height: int, source_width: int, target_height: int, target_width: int
114
+ ) -> tuple[int, int, int, int]:
115
+
116
+ # Calculate total padding needed
117
+ pad_height = target_height - source_height
118
+ pad_width = target_width - source_width
119
+
120
+ # Calculate padding for each side
121
+ pad_top = pad_height // 2
122
+ pad_bottom = pad_height - pad_top # Handles odd padding
123
+ pad_left = pad_width // 2
124
+ pad_right = pad_width - pad_left # Handles odd padding
125
+
126
+ # Return padded tensor
127
+ # Padding format is (left, right, top, bottom)
128
+ padding = (pad_left, pad_right, pad_top, pad_bottom)
129
+ return padding
130
+
131
+
132
+ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
133
+ # Remove non-letters and convert to lowercase
134
+ clean_text = "".join(
135
+ char.lower() for char in text if char.isalpha() or char.isspace()
136
+ )
137
+
138
+ # Split into words
139
+ words = clean_text.split()
140
+
141
+ # Build result string keeping track of length
142
+ result = []
143
+ current_length = 0
144
+
145
+ for word in words:
146
+ # Add word length plus 1 for underscore (except for first word)
147
+ new_length = current_length + len(word)
148
+
149
+ if new_length <= max_len:
150
+ result.append(word)
151
+ current_length += len(word)
152
+ else:
153
+ break
154
+
155
+ return "-".join(result)
156
+
157
+
158
+ # Generate output video name
159
+ def get_unique_filename(
160
+ base: str,
161
+ ext: str,
162
+ prompt: str,
163
+ seed: int,
164
+ resolution: tuple[int, int, int],
165
+ dir: Path,
166
+ endswith=None,
167
+ index_range=1000,
168
+ ) -> Path:
169
+ base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
170
+ for i in range(index_range):
171
+ filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
172
+ if not os.path.exists(filename):
173
+ return filename
174
+ raise FileExistsError(
175
+ f"Could not find a unique filename after {index_range} attempts."
176
+ )
177
+
178
+
179
+ def seed_everething(seed: int):
180
+ random.seed(seed)
181
+ np.random.seed(seed)
182
+ torch.manual_seed(seed)
183
+ if torch.cuda.is_available():
184
+ torch.cuda.manual_seed(seed)
185
+ if torch.backends.mps.is_available():
186
+ torch.mps.manual_seed(seed)
187
+
188
+
189
  def main():
190
  parser = argparse.ArgumentParser(
191
  description="Load models from separate directories and run the pipeline."
192
  )
193
 
194
  # Directories
 
 
195
  parser.add_argument(
196
  "--output_path",
197
  type=str,
198
  default=None,
199
+ help="Path to the folder to save output video, if None will save in outputs/ directory.",
200
  )
201
  parser.add_argument("--seed", type=int, default="171198")
202
 
203
  # Pipeline parameters
 
 
 
204
  parser.add_argument(
205
  "--num_images_per_prompt",
206
  type=int,
 
208
  help="Number of images per prompt",
209
  )
210
  parser.add_argument(
211
+ "--image_cond_noise_scale",
212
  type=float,
213
+ default=0.15,
214
+ help="Amount of noise to add to the conditioned image",
215
  )
216
  parser.add_argument(
217
  "--height",
218
  type=int,
219
+ default=704,
220
  help="Height of the output video frames. Optional if an input image provided.",
221
  )
222
  parser.add_argument(
223
  "--width",
224
  type=int,
225
+ default=1216,
226
  help="Width of the output video frames. If None will infer from input image.",
227
  )
228
  parser.add_argument(
 
232
  help="Number of frames to generate in the output video",
233
  )
234
  parser.add_argument(
235
+ "--frame_rate", type=int, default=30, help="Frame rate for the output video"
236
  )
 
237
  parser.add_argument(
238
+ "--device",
239
+ default=None,
240
+ help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
241
+ )
242
+ parser.add_argument(
243
+ "--pipeline_config",
244
+ type=str,
245
+ default="configs/ltxv-13b-0.9.7-dev.yaml",
246
+ help="The path to the config file for the pipeline, which contains the parameters for the pipeline",
247
  )
248
 
249
  # Prompts
 
258
  default="worst quality, inconsistent motion, blurry, jittery, distorted",
259
  help="Negative prompt for undesired features",
260
  )
261
+
262
  parser.add_argument(
263
+ "--offload_to_cpu",
264
  action="store_true",
265
+ help="Offloading unnecessary computations to CPU.",
 
266
  )
267
 
268
+ # video-to-video arguments:
269
+ parser.add_argument(
270
+ "--input_media_path",
271
+ type=str,
272
+ default=None,
273
+ help="Path to the input video (or imaage) to be modified using the video-to-video pipeline",
274
+ )
275
 
276
+ # Conditioning arguments
277
+ parser.add_argument(
278
+ "--conditioning_media_paths",
279
+ type=str,
280
+ nargs="*",
281
+ help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
282
+ )
283
+ parser.add_argument(
284
+ "--conditioning_strengths",
285
+ type=float,
286
+ nargs="*",
287
+ help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
288
+ )
289
+ parser.add_argument(
290
+ "--conditioning_start_frames",
291
+ type=int,
292
+ nargs="*",
293
+ help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
294
+ )
295
+
296
+ args = parser.parse_args()
297
+ logger.warning(f"Running generation with arguments: {args}")
298
+ infer(**vars(args))
299
+
300
+
301
+ def create_ltx_video_pipeline(
302
+ ckpt_path: str,
303
+ precision: str,
304
+ text_encoder_model_name_or_path: str,
305
+ sampler: Optional[str] = None,
306
+ device: Optional[str] = None,
307
+ enhance_prompt: bool = False,
308
+ prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
309
+ prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
310
+ ) -> LTXVideoPipeline:
311
+ ckpt_path = Path(ckpt_path)
312
+ assert os.path.exists(
313
+ ckpt_path
314
+ ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
315
+
316
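+ # The model config (including any allowed_inference_steps restriction) is stored in the checkpoint's safetensors metadata.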
+ with safe_open(ckpt_path, framework="pt") as f:
317
+ metadata = f.metadata()
318
+ config_str = metadata.get("config")
319
+ configs = json.loads(config_str)
320
+ allowed_inference_steps = configs.get("allowed_inference_steps", None)
321
+
322
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
323
+ transformer = Transformer3DModel.from_pretrained(ckpt_path)
324
+
325
+ # Use constructor if sampler is specified, otherwise use from_pretrained
326
+ if sampler == "from_checkpoint" or not sampler:
327
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
328
  else:
329
+ scheduler = RectifiedFlowScheduler(
330
+ sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
331
+ )
332
+
 
 
333
  text_encoder = T5EncoderModel.from_pretrained(
334
+ text_encoder_model_name_or_path, subfolder="text_encoder"
335
  )
336
+ patchifier = SymmetricPatchifier(patch_size=1)
 
337
  tokenizer = T5Tokenizer.from_pretrained(
338
+ text_encoder_model_name_or_path, subfolder="tokenizer"
339
  )
340
 
341
+ transformer = transformer.to(device)
342
+ vae = vae.to(device)
343
+ text_encoder = text_encoder.to(device)
344
+
345
+ if enhance_prompt:
346
+ prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
347
+ prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
348
+ )
349
+ prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
350
+ prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
351
+ )
352
+ prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
353
+ prompt_enhancer_llm_model_name_or_path,
354
+ torch_dtype="bfloat16",
355
+ )
356
+ prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
357
+ prompt_enhancer_llm_model_name_or_path,
358
+ )
359
+ else:
360
+ prompt_enhancer_image_caption_model = None
361
+ prompt_enhancer_image_caption_processor = None
362
+ prompt_enhancer_llm_model = None
363
+ prompt_enhancer_llm_tokenizer = None
364
+
365
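+ # The VAE always runs in bfloat16; the transformer and text encoder are cast only when precision == "bfloat16".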
+ vae = vae.to(torch.bfloat16)
366
+ if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
367
+ transformer = transformer.to(torch.bfloat16)
368
+ text_encoder = text_encoder.to(torch.bfloat16)
369
 
370
  # Use submodels for the pipeline
371
  submodel_dict = {
372
+ "transformer": transformer,
373
  "patchifier": patchifier,
374
  "text_encoder": text_encoder,
375
  "tokenizer": tokenizer,
376
  "scheduler": scheduler,
377
  "vae": vae,
378
+ "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
379
+ "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
380
+ "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
381
+ "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
382
+ "allowed_inference_steps": allowed_inference_steps,
383
  }
384
 
385
+ pipeline = LTXVideoPipeline(**submodel_dict)
386
+ pipeline = pipeline.to(device)
387
+ return pipeline
388
+
389
+
390
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
391
+ latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
392
+ latent_upsampler.to(device)
393
+ latent_upsampler.eval()
394
+ return latent_upsampler
395
+
396
+
397
+ def infer(
398
+ output_path: Optional[str],
399
+ seed: int,
400
+ pipeline_config: str,
401
+ image_cond_noise_scale: float,
402
+ height: Optional[int],
403
+ width: Optional[int],
404
+ num_frames: int,
405
+ frame_rate: int,
406
+ prompt: str,
407
+ negative_prompt: str,
408
+ offload_to_cpu: bool,
409
+ input_media_path: Optional[str] = None,
410
+ conditioning_media_paths: Optional[List[str]] = None,
411
+ conditioning_strengths: Optional[List[float]] = None,
412
+ conditioning_start_frames: Optional[List[int]] = None,
413
+ device: Optional[str] = None,
414
+ **kwargs,
415
+ ):
416
+ # check if pipeline_config is a file
417
+ if not os.path.isfile(pipeline_config):
418
+ raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
419
+ with open(pipeline_config, "r") as f:
420
+ pipeline_config = yaml.safe_load(f)
421
+
422
+ models_dir = "MODEL_DIR"
423
+
424
+ ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
425
+ if not os.path.isfile(ltxv_model_name_or_path):
426
+ ltxv_model_path = hf_hub_download(
427
+ repo_id="Lightricks/LTX-Video",
428
+ filename=ltxv_model_name_or_path,
429
+ local_dir=models_dir,
430
+ repo_type="model",
431
+ )
432
+ else:
433
+ ltxv_model_path = ltxv_model_name_or_path
434
+
435
+ spatial_upscaler_model_name_or_path = pipeline_config.get(
436
+ "spatial_upscaler_model_path"
437
+ )
438
+ if spatial_upscaler_model_name_or_path and not os.path.isfile(
439
+ spatial_upscaler_model_name_or_path
440
+ ):
441
+ spatial_upscaler_model_path = hf_hub_download(
442
+ repo_id="Lightricks/LTX-Video",
443
+ filename=spatial_upscaler_model_name_or_path,
444
+ local_dir=models_dir,
445
+ repo_type="model",
446
+ )
447
+ else:
448
+ spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
449
+
450
+ if kwargs.get("input_image_path", None):
451
+ logger.warning(
452
+ "Please use conditioning_media_paths instead of input_image_path."
453
+ )
454
+ assert not conditioning_media_paths and not conditioning_start_frames
455
+ conditioning_media_paths = [kwargs["input_image_path"]]
456
+ conditioning_start_frames = [0]
457
+
458
+ # Validate conditioning arguments
459
+ if conditioning_media_paths:
460
+ # Use default strengths of 1.0
461
+ if not conditioning_strengths:
462
+ conditioning_strengths = [1.0] * len(conditioning_media_paths)
463
+ if not conditioning_start_frames:
464
+ raise ValueError(
465
+ "If `conditioning_media_paths` is provided, "
466
+ "`conditioning_start_frames` must also be provided"
467
+ )
468
+ if len(conditioning_media_paths) != len(conditioning_strengths) or len(
469
+ conditioning_media_paths
470
+ ) != len(conditioning_start_frames):
471
+ raise ValueError(
472
+ "`conditioning_media_paths`, `conditioning_strengths`, "
473
+ "and `conditioning_start_frames` must have the same length"
474
+ )
475
+ if any(s < 0 or s > 1 for s in conditioning_strengths):
476
+ raise ValueError("All conditioning strengths must be between 0 and 1")
477
+ if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
478
+ raise ValueError(
479
+ f"All conditioning start frames must be between 0 and {num_frames-1}"
480
+ )
481
+
482
+ seed_everething(seed)
483
+ if offload_to_cpu and not torch.cuda.is_available():
484
+ logger.warning(
485
+ "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
486
+ )
487
+ offload_to_cpu = False
488
+ else:
489
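+ # Only offload when the GPU has less than 30 GB of total memory.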
+ offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30
490
+
491
+ output_dir = (
492
+ Path(output_path)
493
+ if output_path
494
+ else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
495
+ )
496
+ output_dir.mkdir(parents=True, exist_ok=True)
497
+
498
+ # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
499
+ height_padded = ((height - 1) // 32 + 1) * 32
500
+ width_padded = ((width - 1) // 32 + 1) * 32
501
+ num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
502
+
503
+ padding = calculate_padding(height, width, height_padded, width_padded)
504
+
505
+ logger.warning(
506
+ f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
507
+ )
508
+
509
+ prompt_enhancement_words_threshold = pipeline_config[
510
+ "prompt_enhancement_words_threshold"
511
+ ]
512
+
513
+ prompt_word_count = len(prompt.split())
514
+ enhance_prompt = (
515
+ prompt_enhancement_words_threshold > 0
516
+ and prompt_word_count < prompt_enhancement_words_threshold
517
+ )
518
+
519
+ if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
520
+ logger.info(
521
+ f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
522
+ )
523
+
524
+ precision = pipeline_config["precision"]
525
+ text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
526
+ sampler = pipeline_config["sampler"]
527
+ prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
528
+ "prompt_enhancer_image_caption_model_name_or_path"
529
+ ]
530
+ prompt_enhancer_llm_model_name_or_path = pipeline_config[
531
+ "prompt_enhancer_llm_model_name_or_path"
532
+ ]
533
+
534
+ pipeline = create_ltx_video_pipeline(
535
+ ckpt_path=ltxv_model_path,
536
+ precision=precision,
537
+ text_encoder_model_name_or_path=text_encoder_model_name_or_path,
538
+ sampler=sampler,
539
+ device=kwargs.get("device", get_device()),
540
+ enhance_prompt=enhance_prompt,
541
+ prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
542
+ prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
543
+ )
544
+
545
+ if pipeline_config.get("pipeline_type", None) == "multi-scale":
546
+ if not spatial_upscaler_model_path:
547
+ raise ValueError(
548
+ "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
549
+ )
550
+ latent_upsampler = create_latent_upsampler(
551
+ spatial_upscaler_model_path, pipeline.device
552
+ )
553
+ pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
554
+
555
+ media_item = None
556
+ if input_media_path:
557
+ media_item = load_media_file(
558
+ media_path=input_media_path,
559
+ height=height,
560
+ width=width,
561
+ max_frames=num_frames_padded,
562
+ padding=padding,
563
+ )
564
+
565
+ conditioning_items = (
566
+ prepare_conditioning(
567
+ conditioning_media_paths=conditioning_media_paths,
568
+ conditioning_strengths=conditioning_strengths,
569
+ conditioning_start_frames=conditioning_start_frames,
570
+ height=height,
571
+ width=width,
572
+ num_frames=num_frames,
573
+ padding=padding,
574
+ pipeline=pipeline,
575
+ )
576
+ if conditioning_media_paths
577
+ else None
578
+ )
579
+
580
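+ # Map the spatiotemporal guidance (STG) mode from the config to the corresponding SkipLayerStrategy.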
+ stg_mode = pipeline_config.get("stg_mode", "attention_values")
581
+ pipeline_config.pop("stg_mode", None)  # drop it so it is not forwarded to the pipeline call below; pop avoids a KeyError when the key is absent
582
+ if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
583
+ skip_layer_strategy = SkipLayerStrategy.AttentionValues
584
+ elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
585
+ skip_layer_strategy = SkipLayerStrategy.AttentionSkip
586
+ elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
587
+ skip_layer_strategy = SkipLayerStrategy.Residual
588
+ elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
589
+ skip_layer_strategy = SkipLayerStrategy.TransformerBlock
590
+ else:
591
+ raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
592
 
593
  # Prepare input for the pipeline
594
  sample = {
595
+ "prompt": prompt,
596
  "prompt_attention_mask": None,
597
+ "negative_prompt": negative_prompt,
598
  "negative_prompt_attention_mask": None,
 
599
  }
600
 
601
+ device = device or get_device()
602
+ generator = torch.Generator(device=device).manual_seed(seed)
 
 
603
 
604
  images = pipeline(
605
+ **pipeline_config,
606
+ skip_layer_strategy=skip_layer_strategy,
 
607
  generator=generator,
608
  output_type="pt",
609
  callback_on_step_end=None,
610
+ height=height_padded,
611
+ width=width_padded,
612
+ num_frames=num_frames_padded,
613
+ frame_rate=frame_rate,
614
  **sample,
615
+ media_items=media_item,
616
+ conditioning_items=conditioning_items,
617
  is_video=True,
618
  vae_per_channel_normalize=True,
619
+ image_cond_noise_scale=image_cond_noise_scale,
620
+ mixed_precision=(precision == "mixed_precision"),
621
+ offload_to_cpu=offload_to_cpu,
622
+ device=device,
623
+ enhance_prompt=enhance_prompt,
 
624
  ).images
625
 
626
+ # Crop the padded images to the desired resolution and number of frames
627
+ (pad_left, pad_right, pad_top, pad_bottom) = padding
628
+ pad_bottom = -pad_bottom
629
+ pad_right = -pad_right
630
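+ # Negative pads act as slice end indices; a zero pad is replaced with the full dimension, since slicing to 0 would give an empty tensor.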
+ if pad_bottom == 0:
631
+ pad_bottom = images.shape[3]
632
+ if pad_right == 0:
633
+ pad_right = images.shape[4]
634
+ images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
635
 
636
  for i in range(images.shape[0]):
637
  # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
638
  video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
639
  # Unnormalizing images to [0, 255] range
640
  video_np = (video_np * 255).astype(np.uint8)
641
+ fps = frame_rate
642
  height, width = video_np.shape[1:3]
643
+ # In case a single image is generated
644
  if video_np.shape[0] == 1:
645
+ output_filename = get_unique_filename(
646
+ f"image_output_{i}",
647
+ ".png",
648
+ prompt=prompt,
649
+ seed=seed,
650
+ resolution=(height, width, num_frames),
651
+ dir=output_dir,
652
  )
653
+ imageio.imwrite(output_filename, video_np[0])
 
 
654
  else:
655
+ output_filename = get_unique_filename(
656
+ f"video_output_{i}",
657
+ ".mp4",
658
+ prompt=prompt,
659
+ seed=seed,
660
+ resolution=(height, width, num_frames),
661
+ dir=output_dir,
662
  )
663
 
664
+ # Write video
665
+ with imageio.get_writer(output_filename, fps=fps) as video:
666
+ for frame in video_np:
667
+ video.append_data(frame)
668
+
669
+ logger.warning(f"Output saved to {output_filename}")
670
+
671
+
672
+ def prepare_conditioning(
673
+ conditioning_media_paths: List[str],
674
+ conditioning_strengths: List[float],
675
+ conditioning_start_frames: List[int],
676
+ height: int,
677
+ width: int,
678
+ num_frames: int,
679
+ padding: tuple[int, int, int, int],
680
+ pipeline: LTXVideoPipeline,
681
+ ) -> Optional[List[ConditioningItem]]:
682
+ """Prepare conditioning items based on input media paths and their parameters.
683
+
684
+ Args:
685
+ conditioning_media_paths: List of paths to conditioning media (images or videos)
686
+ conditioning_strengths: List of conditioning strengths for each media item
687
+ conditioning_start_frames: List of frame indices where each item should be applied
688
+ height: Height of the output frames
689
+ width: Width of the output frames
690
+ num_frames: Number of frames in the output video
691
+ padding: Padding to apply to the frames
692
+ pipeline: LTXVideoPipeline object used for condition video trimming
693
+
694
+ Returns:
695
+ A list of ConditioningItem objects.
696
+ """
697
+ conditioning_items = []
698
+ for path, strength, start_frame in zip(
699
+ conditioning_media_paths, conditioning_strengths, conditioning_start_frames
700
+ ):
701
+ num_input_frames = orig_num_input_frames = get_media_num_frames(path)
702
+ if hasattr(pipeline, "trim_conditioning_sequence") and callable(
703
+ getattr(pipeline, "trim_conditioning_sequence")
704
+ ):
705
+ num_input_frames = pipeline.trim_conditioning_sequence(
706
+ start_frame, orig_num_input_frames, num_frames
707
  )
708
+ if num_input_frames < orig_num_input_frames:
709
+ logger.warning(
710
+ f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
711
+ )
712
+
713
+ media_tensor = load_media_file(
714
+ media_path=path,
715
+ height=height,
716
+ width=width,
717
+ max_frames=num_input_frames,
718
+ padding=padding,
719
+ just_crop=True,
720
+ )
721
+ conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
722
+ return conditioning_items
723
+
724
 
725
+ def get_media_num_frames(media_path: str) -> int:
726
+ is_video = any(
727
+ media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
728
+ )
729
+ num_frames = 1
730
+ if is_video:
731
+ reader = imageio.get_reader(media_path)
732
+ num_frames = reader.count_frames()
733
+ reader.close()
734
+ return num_frames
735
+
736
+
737
+ def load_media_file(
738
+ media_path: str,
739
+ height: int,
740
+ width: int,
741
+ max_frames: int,
742
+ padding: tuple[int, int, int, int],
743
+ just_crop: bool = False,
744
+ ) -> torch.Tensor:
745
+ is_video = any(
746
+ media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
747
+ )
748
+ if is_video:
749
+ reader = imageio.get_reader(media_path)
750
+ num_input_frames = min(reader.count_frames(), max_frames)
751
+
752
+ # Read and preprocess the relevant frames from the video file.
753
+ frames = []
754
+ for i in range(num_input_frames):
755
+ frame = Image.fromarray(reader.get_data(i))
756
+ frame_tensor = load_image_to_tensor_with_resize_and_crop(
757
+ frame, height, width, just_crop=just_crop
758
+ )
759
+ frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
760
+ frames.append(frame_tensor)
761
+ reader.close()
762
+
763
+ # Stack frames along the temporal dimension
764
+ media_tensor = torch.cat(frames, dim=2)
765
+ else: # Input image
766
+ media_tensor = load_image_to_tensor_with_resize_and_crop(
767
+ media_path, height, width, just_crop=just_crop
768
+ )
769
+ media_tensor = torch.nn.functional.pad(media_tensor, padding)
770
+ return media_tensor
771
 
772
 
773
  if __name__ == "__main__":
774
+ main()
requirements.txt CHANGED
@@ -1,14 +1,15 @@
1
- huggingface_hub<=0.25
2
- torch
3
- diffusers==0.28.2
4
- transformers==4.44.2
5
- sentencepiece>=0.1.96
6
  accelerate
7
- einops
8
- matplotlib
 
 
9
  opencv-python
10
- beautifulsoup4
11
- ftfy
12
- gradio
13
- openai
14
- gradio_toggle
 
 
1
  accelerate
2
+ transformers
3
+ sentencepiece
4
+ pillow
5
+ numpy
6
+ torchvision
7
+ huggingface_hub
8
+ spaces
9
  opencv-python
10
+ imageio
11
+ imageio-ffmpeg
12
+ einops
13
+ timm
14
+ av
15
+ git+https://github.com/huggingface/diffusers.git@main