Spaces:
Runtime error

Upload 4 files
Browse files

- README.md +10 -79
- app.py +490 -766
- inference.py +645 -240
- requirements.txt +13 -12

README.md CHANGED
@@ -1,82 +1,13 @@
---
-title: LTX
-emoji:
-colorFrom:
-colorTo:
-sdk: gradio
-sdk_version:
-app_file: app.py
-pinned: false
---

-
-# Xora️
-
-</div>
-
-This is the official repository for Xora.
-
-## Table of Contents
-
-- [Introduction](#introduction)
-- [Installation](#installation)
-- [Inference](#inference)
-  - [Inference Code](#inference-code)
-- [Acknowledgement](#acknowledgement)
-
-## Introduction
-
-The performance of Diffusion Transformers is heavily influenced by the number of generated latent pixels (or tokens). In video generation, the token count becomes substantial as the number of frames increases. To address this, we designed a carefully optimized VAE that compresses videos into a smaller number of tokens while utilizing a deeper latent space. This approach enables our model to generate high-quality 768x512 videos at 24 FPS, achieving near real-time speeds.
-
-## Installation
-
-### Setup
-
-The codebase currently uses Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
-
-```bash
-git clone https://github.com/LightricksResearch/xora-core.git
-cd xora-core
-
-# create env
-python -m venv env
-source env/bin/activate
-python -m pip install -e .\[inference-script\]
-```
-
-Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/Xora):
-
-```python
-from huggingface_hub import snapshot_download
-
-model_path = 'PATH'  # The local directory to save the downloaded checkpoint
-snapshot_download("Lightricks/Xora", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
-```
-
-## Inference
-
-### Inference Code
-
-To use our model, follow the inference code in [`inference.py`](https://github.com/LightricksResearch/xora-core/blob/main/inference.py).
-
-For text-to-video generation:
-
-```bash
-python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH
-```
-
-For image-to-video generation:
-
-```bash
-python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH
-```
-
-## Acknowledgement
-
-We are grateful for the following awesome projects when implementing Xora:
-
-- [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.
-
-[//]: # "## Citation"
---
+title: LTX Video Fast
+emoji: 🎥
+colorFrom: yellow
+colorTo: pink
+sdk: gradio
+sdk_version: 5.42.0
+app_file: app.py
+pinned: false
+short_description: ultra-fast video model, LTX 0.9.8 13B distilled
---

+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
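For local experimentation, the distilled checkpoint this Space serves can be pulled from the Hub with `huggingface_hub`. A minimal sketch, not part of this commit; the repo id and filename are taken from the model link above and from the `app.py` changes below:

```python
from huggingface_hub import hf_hub_download

# Fetch the 13B distilled LTX-Video checkpoint referenced by this Space.
ckpt_path = hf_hub_download(
    repo_id="Lightricks/LTX-Video",
    filename="ltxv-13b-0.9.8-distilled.safetensors",
    local_dir="downloaded_models",  # any local directory
)
print(ckpt_path)
```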
app.py CHANGED
@@ -1,797 +1,521 @@
-import spaces
-from functools import lru_cache
import gradio as gr
-from gradio_toggle import Toggle
import torch
-
-from transformers import CLIPProcessor, CLIPModel
-
-from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
-from ltx_video.models.transformers.transformer3d import Transformer3DModel
-from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
-from ltx_video.schedulers.rf import RectifiedFlowScheduler
-from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline as XoraVideoPipeline
-from transformers import T5EncoderModel, T5Tokenizer
-from ltx_video.utils.conditioning_method import ConditioningMethod
-from pathlib import Path
-import safetensors.torch
-import json
import numpy as np
-import
-from PIL import Image
-import tempfile
import os
-import
-from
-import
-
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

-
-system_prompt_i2v_path = "assets/system_prompt_i2v.txt"
-with open(system_prompt_t2v_path, "r") as f:
-    system_prompt_t2v = f.read()
-
-with open(system_prompt_i2v_path, "r") as f:
-    system_prompt_i2v = f.read()
-
-# Set model download directory within Hugging Face Spaces
-model_path = Path("/home/elevin/xora-core/assets/")
-cpkt_path = Path("/home/elevin/xora-core/assets/ltx-video-2b-v0.9.1.safetensors")
-if not os.path.exists(cpkt_path):
-    hf_hub_download(repo_id="Lightricks/LTX-Video", filename="ltx-video-2b-v0.9.1.safetensors", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
-# Global variables to load components
-vae_dir = Path(model_path) / "vae"
-unet_dir = Path(model_path) / "unet"
-scheduler_dir = Path(model_path) / "scheduler"
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-DATA_DIR = "/data"
-os.makedirs(DATA_DIR, exist_ok=True)
-LOG_FILE_PATH = os.path.join("/data", "user_requests.csv")
-
-clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
-clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
-
-if not os.path.exists(LOG_FILE_PATH):
-    with open(LOG_FILE_PATH, "w", newline="") as f:
-        writer = csv.writer(f)
-        writer.writerow(
-            [
-                "timestamp",
-                "request_type",
-                "prompt",
-                "negative_prompt",
-                "height",
-                "width",
-                "num_frames",
-                "frame_rate",
-                "seed",
-                "num_inference_steps",
-                "guidance_scale",
-                "is_enhanced",
-                "clip_embedding",
-                "original_resolution",
-            ]
-        )
-
-@lru_cache(maxsize=128)
-def log_request(
-    request_type,
-    prompt,
-    negative_prompt,
-    height,
-    width,
-    num_frames,
-    frame_rate,
-    seed,
-    num_inference_steps,
-    guidance_scale,
-    is_enhanced,
-    clip_embedding=None,
-    original_resolution=None,
-):
-    """Log the user's request to a CSV file."""
-    timestamp = datetime.now().isoformat()
-    with open(LOG_FILE_PATH, "a", newline="") as f:
-        try:
-            writer = csv.writer(f)
-            writer.writerow(
-                [
-                    timestamp,
-                    request_type,
-                    prompt,
-                    negative_prompt,
-                    height,
-                    width,
-                    num_frames,
-                    frame_rate,
-                    seed,
-                    num_inference_steps,
-                    guidance_scale,
-                    is_enhanced,
-                    clip_embedding,
-                    original_resolution,
-                ]
-            )
-        except Exception as e:
-            print(f"Error logging request: {e}")

-
    """
-        image (PIL.Image): Input image.
-    Returns:
-        list: CLIP embedding as a list of floats.
    """
-
-    aspect_ratio_target = target_width / target_height
-    aspect_ratio_frame = w / h
-    if aspect_ratio_frame > aspect_ratio_target:
-        new_width = int(h * aspect_ratio_target)
-        x_start = (w - new_width) // 2
-        frame_cropped = frame[:, x_start : x_start + new_width]
-    else:
-        new_height = int(w / aspect_ratio_target)
-        y_start = (h - new_height) // 2
-        frame_cropped = frame[y_start : y_start + new_height, :]
-    frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
-    return frame_resized
-
-def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
-    image = Image.open(image_path).convert("RGB")
-    image_np = np.array(image)
-    frame_resized = center_crop_and_resize(image_np, target_height, target_width)
-    frame_tensor = torch.tensor(frame_resized).permute(2, 0, 1).float()
-    frame_tensor = (frame_tensor / 127.5) - 1.0
-    return frame_tensor.unsqueeze(0).unsqueeze(2)
-
-def enhance_prompt_if_enabled(prompt, enhance_toggle, type="t2v"):
-    if not enhance_toggle:
-        print("Enhance toggle is off, Prompt: ", prompt)
-        return prompt
-
-    system_prompt = system_prompt_t2v if type == "t2v" else system_prompt_i2v
-    messages = [
-        {"role": "system", "content": system_prompt},
-        {"role": "user", "content": prompt},
-    ]
-
-    return prompt
-
-# Preset options for resolution and frame configuration
-preset_options = [
-    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
-    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
-    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
-    {"label": "992x608, 65 frames", "width": 992, "height": 608, "num_frames": 65},
-    {"label": "896x608, 73 frames", "width": 896, "height": 608, "num_frames": 73},
-    {"label": "896x544, 81 frames", "width": 896, "height": 544, "num_frames": 81},
-    {"label": "832x544, 89 frames", "width": 832, "height": 544, "num_frames": 89},
-    {"label": "800x512, 97 frames", "width": 800, "height": 512, "num_frames": 97},
-    {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
-    {"label": "800x480, 105 frames", "width": 800, "height": 480, "num_frames": 105},
-    {"label": "736x480, 113 frames", "width": 736, "height": 480, "num_frames": 113},
-    {"label": "704x480, 121 frames", "width": 704, "height": 480, "num_frames": 121},
-    {"label": "704x448, 129 frames", "width": 704, "height": 448, "num_frames": 129},
-    {"label": "672x448, 137 frames", "width": 672, "height": 448, "num_frames": 137},
-    {"label": "640x416, 153 frames", "width": 640, "height": 416, "num_frames": 153},
-    {"label": "672x384, 161 frames", "width": 672, "height": 384, "num_frames": 161},
-    {"label": "640x384, 169 frames", "width": 640, "height": 384, "num_frames": 169},
-    {"label": "608x384, 177 frames", "width": 608, "height": 384, "num_frames": 177},
-    {"label": "576x384, 185 frames", "width": 576, "height": 384, "num_frames": 185},
-    {"label": "608x352, 193 frames", "width": 608, "height": 352, "num_frames": 193},
-    {"label": "576x352, 201 frames", "width": 576, "height": 352, "num_frames": 201},
-    {"label": "544x352, 209 frames", "width": 544, "height": 352, "num_frames": 209},
-    {"label": "512x352, 225 frames", "width": 512, "height": 352, "num_frames": 225},
-    {"label": "512x352, 233 frames", "width": 512, "height": 352, "num_frames": 233},
-    {"label": "544x320, 241 frames", "width": 544, "height": 320, "num_frames": 241},
-    {"label": "512x320, 249 frames", "width": 512, "height": 320, "num_frames": 249},
-    {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
-]
-
-# Function to toggle visibility of sliders based on preset selection
-def preset_changed(preset):
-    if preset != "Custom":
-        selected = next(item for item in preset_options if item["label"] == preset)
-        return (
-            selected["height"],
-            selected["width"],
-            selected["num_frames"],
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(visible=False),
-        )
    else:
-
-    prompt="",
-    enhance_prompt_toggle=False,
-    txt2vid_analytics_toggle=True,
-    negative_prompt="",
-    frame_rate=25,
-    seed=646373,
-    num_inference_steps=30,
-    guidance_scale=3,
-    height=512,
-    width=768,
-    num_frames=121,
-    progress=gr.Progress(),
-    stg_scale=1.0,
-    stg_rescale=0.7,
-    stg_mode="stg_a",
-    stg_skip_layers="19",
-):
-    if len(prompt.strip()) < 50:
-        raise gr.Error(
-            "Prompt must be at least 50 characters long. Please provide more details for the best results.",
-            duration=5,
-        )
-
-    if txt2vid_analytics_toggle:
-        log_request(
-            "txt2vid",
-            prompt,
-            negative_prompt,
-            height,
-            width,
-            num_frames,
-            frame_rate,
-            seed,
-            num_inference_steps,
-            guidance_scale,
-            enhance_prompt_toggle,
-        )
-
-    prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle, type="t2v")
-
-    sample = {
        "prompt": prompt,
-        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
-        "
        "media_items": None,
    }

-
-            num_inference_steps=num_inference_steps,
-            num_images_per_prompt=1,
-            guidance_scale=guidance_scale,
-            generator=generator,
-            output_type="pt",
-            height=height,
-            width=width,
-            num_frames=num_frames,
-            frame_rate=frame_rate,
-            **sample,
-            is_video=True,
-            vae_per_channel_normalize=True,
-            conditioning_method=ConditioningMethod.UNCONDITIONAL,
-            mixed_precision=True,
-            callback_on_step_end=gradio_progress_callback,
-            stg_scale=stg_scale,
-            do_rescaling=stg_rescale != 1,
-            rescaling_scale=stg_rescale,
-            skip_layer_strategy=SkipLayerStrategy.Attention if stg_mode == "stg_a" else SkipLayerStrategy.Residual,
-            skip_block_list=[int(x.strip()) for x in stg_skip_layers.split(",")]
-        ).images
-    except Exception as e:
-        raise gr.Error(
-            f"An error occurred while generating the video. Please try again. Error: {e}",
-            duration=5,
-        )
-    finally:
-        torch.cuda.empty_cache()
-        gc.collect()
-
-    output_path = tempfile.mktemp(suffix=".mp4")
-    print(images.shape)
-    video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
    video_np = (video_np * 255).astype(np.uint8)
-    height, width = video_np.shape[1:3]
-    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
-    for frame in video_np[..., ::-1]:
-        out.write(frame)
-    out.release()
-    # Explicitly delete tensors and clear cache
-    del images
-    del video_np
-    torch.cuda.empty_cache()
-    return output_path
-
-@spaces.GPU(duration=120)
-def generate_video_from_image(
-    image_path,
-    prompt="",
-    enhance_prompt_toggle=False,
-    img2vid_analytics_toggle=True,
-    negative_prompt="",
-    frame_rate=25,
-    seed=646373,
-    num_inference_steps=30,
-    guidance_scale=3,
-    height=512,
-    width=768,
-    num_frames=121,
-    progress=gr.Progress(),
-    stg_scale=1.0,
-    stg_rescale=0.7,
-    stg_mode="stg_a",
-    stg_skip_layers="19",
-):
-
-    print("Height: ", height)
-    print("Width: ", width)
-    print("Num Frames: ", num_frames)
-
-    if len(prompt.strip()) < 50:
-        raise gr.Error(
-            "Prompt must be at least 50 characters long. Please provide more details for the best results.",
-            duration=5,
-        )
-
-    if not image_path:
-        raise gr.Error("Please provide an input image.", duration=5)
-
-    if img2vid_analytics_toggle:
-        with Image.open(image_path) as img:
-            original_resolution = f"{img.width}x{img.height}"  # Format as "widthxheight"
-            clip_embedding = compute_clip_embedding(image=img)
-
-        log_request(
-            "img2vid",
-            prompt,
-            negative_prompt,
-            height,
-            width,
-            num_frames,
-            frame_rate,
-            seed,
-            num_inference_steps,
-            guidance_scale,
-            enhance_prompt_toggle,
-            json.dumps(clip_embedding),
-            original_resolution,
-        )
-
-    media_items = load_image_to_tensor_with_resize(image_path, height, width).to(device).detach()
-
-    prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle, type="i2v")
-
-    sample = {
-        "prompt": prompt,
-        "prompt_attention_mask": None,
-        "negative_prompt": negative_prompt,
-        "negative_prompt_attention_mask": None,
-        "media_items": media_items,
-    }
-
-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    def gradio_progress_callback(self, step, timestep, kwargs):
-        progress((step + 1) / num_inference_steps)

    try:
-        with
-
-            guidance_scale=guidance_scale,
-            generator=generator,
-            output_type="pt",
-            height=height,
-            width=width,
-            num_frames=num_frames,
-            frame_rate=frame_rate,
-            **sample,
-            is_video=True,
-            vae_per_channel_normalize=True,
-            conditioning_method=ConditioningMethod.FIRST_FRAME,
-            mixed_precision=True,
-            callback_on_step_end=gradio_progress_callback,
-            stg_scale=stg_scale,
-            do_rescaling=stg_rescale != 1,
-            rescaling_scale=stg_rescale,
-            skip_layer_strategy=SkipLayerStrategy.Attention if stg_mode == "stg_a" else SkipLayerStrategy.Residual,
-            skip_block_list=[int(x.strip()) for x in stg_skip_layers.split(",")]
-        ).images
-
-        output_path = tempfile.mktemp(suffix=".mp4")
-        video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
-        video_np = (video_np * 255).astype(np.uint8)
-        height, width = video_np.shape[1:3]
-        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
-        for frame in video_np[..., ::-1]:
-            out.write(frame)
-        out.release()
    except Exception as e:
-
-def
-
-    )
-
-    )
-
-    txt2vid_preset = gr.Dropdown(
-        choices=[p["label"] for p in preset_options],
-        value="768x512, 97 frames",
-        label="Step 3.1: Choose Resolution Preset",
-    )
-
-    txt2vid_frame_rate = gr.Slider(
-        label="Step 3.2: Frame Rate",
-        minimum=21,
-        maximum=30,
-        step=1,
-        value=25,
-    )
-
-    txt2vid_advanced = create_advanced_options()
-    txt2vid_generate = gr.Button(
-        "Step 5: Generate Video",
-        variant="primary",
-        size="lg",
-    )
-
-    with gr.Column():
-        txt2vid_output = gr.Video(label="Generated Output")
-
-    with gr.Row():
-        gr.Examples(
-            examples=[
-                [
-                    "A young woman in a traditional Mongolian dress is peeking through a sheer white curtain, her face showing a mix of curiosity and apprehension. The woman has long black hair styled in two braids, adorned with white beads, and her eyes are wide with a hint of surprise. Her dress is a vibrant blue with intricate gold embroidery, and she wears a matching headband with a similar design. The background is a simple white curtain, which creates a sense of mystery and intrigue.ith long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair’s face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
-                    "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
-                    "assets/t2v_2.mp4",
-                ],
-                [
-                    "A young man with blond hair wearing a yellow jacket stands in a forest and looks around. He has light skin and his hair is styled with a middle part. He looks to the left and then to the right, his gaze lingering in each direction. The camera angle is low, looking up at the man, and remains stationary throughout the video. The background is slightly out of focus, with green trees and the sun shining brightly behind the man. The lighting is natural and warm, with the sun creating a lens flare that moves across the man’s face. The scene is captured in real-life footage.",
-                    "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
-                    "assets/t2v_1.mp4",
-                ],
-                [
-                    "A cyclist races along a winding mountain road. Clad in aerodynamic gear, he pedals intensely, sweat glistening on his brow. The camera alternates between close-ups of his determined expression and wide shots of the breathtaking landscape. Pine trees blur past, and the sky is a crisp blue. The scene is invigorating and competitive.",
-                    "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
-                    "assets/t2v_0.mp4",
-                ],
-            ],
-            inputs=[txt2vid_prompt, txt2vid_negative_prompt, txt2vid_output],
-            label="Example Text-to-Video Generations",
-        )
-
-    # Image to Video Tab
-    with gr.TabItem("Image to Video"):
-        with gr.Row():
-            with gr.Column():
-                img2vid_image = gr.Image(
-                    type="filepath",
-                    label="Step 1: Upload Input Image",
-                    elem_id="image_upload",
-                )
-                img2vid_prompt = gr.Textbox(
-                    label="Step 2: Enter Your Prompt",
-                    placeholder="Describe how you want to animate the image (minimum 50 characters)...",
-                    value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.",
-                    lines=5,
-                )
-                img2vid_analytics_toggle = Toggle(
-                    label="I agree to share my usage data anonymously to help improve the model features.",
-                    value=True,
-                    interactive=True,
-                )
-                img2vid_enhance_toggle = Toggle(
-                    label="Enhance Prompt",
-                    value=False,
-                    interactive=True,
-                )
-                img2vid_negative_prompt = gr.Textbox(
-                    label="Step 3: Enter Negative Prompt",
-                    placeholder="Describe what you don't want in the video...",
-                    value="low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
-                    lines=2,
-                )
-
-                img2vid_preset = gr.Dropdown(
-                    choices=[p["label"] for p in preset_options],
-                    value="768x512, 97 frames",
-                    label="Step 3.1: Choose Resolution Preset",
-                )
-
-                img2vid_frame_rate = gr.Slider(
-                    label="Step 3.2: Frame Rate",
-                    minimum=21,
-                    maximum=30,
-                    step=1,
-                    value=25,
-                )
-
-                img2vid_advanced = create_advanced_options()
-                img2vid_generate = gr.Button("Step 6: Generate Video", variant="primary", size="lg")
-
-            with gr.Column():
-                img2vid_output = gr.Video(label="Generated Output")
-
-        with gr.Row():
-            gr.Examples(
-                examples=[
-                    [
-                        "assets/i2v_i2.png",
-                        "A woman stirs a pot of boiling water on a white electric burner. Her hands, with purple nail polish, hold a wooden spoon and move it in a circular motion within a white pot filled with bubbling water. The pot sits on a white electric burner with black buttons and a digital display. The burner is positioned on a white countertop with a red and white checkered cloth partially visible in the bottom right corner. The camera angle is a direct overhead shot, remaining stationary throughout the scene. The lighting is bright and even, illuminating the scene with a neutral white light. The scene is real-life footage.",
-                        "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
-                        "assets/i2v_2.mp4",
-                    ],
-                    [
-                        "assets/i2v_i0.png",
-                        "A woman in a long, flowing dress stands in a field, her back to the camera, gazing towards the horizon; her hair is long and light, cascading down her back; she stands beneath the sprawling branches of a large oak tree; to her left, a classic American car is parked on the dry grass; in the distance, a wrecked car lies on its side; the sky above is a dramatic canvas of bright white clouds against a darker sky; the entire image is in black and white, emphasizing the contrast of light and shadow. The woman is walking slowly towards the car.",
-                        "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
-                        "assets/i2v_0.mp4",
-                    ],
-                    [
-                        "assets/i2v_i1.png",
-                        "A pair of hands shapes a piece of clay on a pottery wheel, gradually forming a cone shape. The hands, belonging to a person out of frame, are covered in clay and gently press a ball of clay onto the center of a spinning pottery wheel. The hands move in a circular motion, gradually forming a cone shape at the top of the clay. The camera is positioned directly above the pottery wheel, providing a bird’s-eye view of the clay being shaped. The lighting is bright and even, illuminating the clay and the hands working on it. The scene is captured in real-life footage.",
-                        "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
-                        "assets/i2v_1.mp4",
-                    ],
-                ],
-                inputs=[
-                    img2vid_image,
-                    img2vid_prompt,
-                    img2vid_negative_prompt,
-                    img2vid_output,
-                ],
-                label="Example Image-to-Video Generations",
-            )
-
-    # [Previous event handlers remain the same]
-    txt2vid_preset.change(fn=preset_changed, inputs=[txt2vid_preset], outputs=txt2vid_advanced[3:])
-
-    txt2vid_generate.click(
-        fn=generate_video_from_text,
-        inputs=[
-            txt2vid_prompt,
-            txt2vid_enhance_toggle,
-            txt2vid_analytics_toggle,
-            txt2vid_negative_prompt,
-            txt2vid_frame_rate,
-            *txt2vid_advanced,
-        ],
-        outputs=txt2vid_output,
-        concurrency_limit=1,
-        concurrency_id="generate_video",
    )

-
-            img2vid_enhance_toggle,
-            img2vid_analytics_toggle,
-            img2vid_negative_prompt,
-            img2vid_frame_rate,
-            *img2vid_advanced,
-        ],
-        outputs=img2vid_output,
-        concurrency_limit=1,
-        concurrency_id="generate_video",
    )

if __name__ == "__main__":
-
import gradio as gr
import torch
+import spaces
import numpy as np
+import random
import os
+import yaml
+from pathlib import Path
+import imageio
+import tempfile
+from PIL import Image
+from huggingface_hub import hf_hub_download
+import shutil
+
+from inference import (
+    create_ltx_video_pipeline,
+    create_latent_upsampler,
+    load_image_to_tensor_with_resize_and_crop,
+    seed_everething,
+    get_device,
+    calculate_padding,
+    load_media_file
+)
+from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

+config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml"
+with open(config_file_path, "r") as file:
+    PIPELINE_CONFIG_YAML = yaml.safe_load(file)
+
+LTX_REPO = "Lightricks/LTX-Video"
+MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
+MAX_NUM_FRAMES = 257
+
+FPS = 30.0
+
+# --- Global variables for loaded models ---
+pipeline_instance = None
+latent_upsampler_instance = None
+models_dir = "downloaded_models_gradio_cpu_init"
+Path(models_dir).mkdir(parents=True, exist_ok=True)
+
+print("Downloading models (if not present)...")
+distilled_model_actual_path = hf_hub_download(
+    repo_id=LTX_REPO,
+    filename=PIPELINE_CONFIG_YAML["checkpoint_path"],
+    local_dir=models_dir,
+    local_dir_use_symlinks=False
+)
+PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
+print(f"Distilled model path: {distilled_model_actual_path}")
+
+SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
+spatial_upscaler_actual_path = hf_hub_download(
+    repo_id=LTX_REPO,
+    filename=SPATIAL_UPSCALER_FILENAME,
+    local_dir=models_dir,
+    local_dir_use_symlinks=False
+)
+PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
+print(f"Spatial upscaler model path: {spatial_upscaler_actual_path}")
+
+print("Creating LTX Video pipeline on CPU...")
+pipeline_instance = create_ltx_video_pipeline(
+    ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
+    precision=PIPELINE_CONFIG_YAML["precision"],
+    text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
+    sampler=PIPELINE_CONFIG_YAML["sampler"],
+    device="cpu",
+    enhance_prompt=False,
+    prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
+    prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
+)
+print("LTX Video pipeline created on CPU.")
+
+if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
+    print("Creating latent upsampler on CPU...")
+    latent_upsampler_instance = create_latent_upsampler(
+        PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"],
+        device="cpu"
+    )
+    print("Latent upsampler created on CPU.")

+target_inference_device = "cuda"
+print(f"Target inference device: {target_inference_device}")
+pipeline_instance.to(target_inference_device)
+if latent_upsampler_instance:
+    latent_upsampler_instance.to(target_inference_device)

+# --- Helper function for dimension calculation ---
+MIN_DIM_SLIDER = 256  # As defined in the sliders minimum attribute
+TARGET_FIXED_SIDE = 768  # Desired fixed side length as per requirement
+
+def calculate_new_dimensions(orig_w, orig_h):
    """
+    Calculates new dimensions for height and width sliders based on original media dimensions.
+    Ensures one side is TARGET_FIXED_SIDE, the other is scaled proportionally,
+    both are multiples of 32, and within [MIN_DIM_SLIDER, MAX_IMAGE_SIZE].
    """
+    if orig_w == 0 or orig_h == 0:
+        # Default to TARGET_FIXED_SIDE square if original dimensions are invalid
+        return int(TARGET_FIXED_SIDE), int(TARGET_FIXED_SIDE)
+
+    if orig_w >= orig_h:  # Landscape or square
+        new_h = TARGET_FIXED_SIDE
+        aspect_ratio = orig_w / orig_h
+        new_w_ideal = new_h * aspect_ratio
+
+        # Round to nearest multiple of 32
+        new_w = round(new_w_ideal / 32) * 32
+
+        # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
+        new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
+        # Ensure new_h is also clamped (TARGET_FIXED_SIDE should be within these bounds if configured correctly)
+        new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
+    else:  # Portrait
+        new_w = TARGET_FIXED_SIDE
+        aspect_ratio = orig_h / orig_w  # Use H/W ratio for portrait scaling
+        new_h_ideal = new_w * aspect_ratio
+
+        # Round to nearest multiple of 32
+        new_h = round(new_h_ideal / 32) * 32
+
+        # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
+        new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
+        # Ensure new_w is also clamped
+        new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
+
+    return int(new_h), int(new_w)
+
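To make the rounding and clamping above concrete, here is a small illustrative sketch (not part of the commit) of what the helper yields for a 1080p landscape input, assuming the constants defined in this file:

```python
# Worked example of the slider-dimension logic above (landscape 1920x1080 input,
# TARGET_FIXED_SIDE=768, MIN_DIM_SLIDER=256, MAX_IMAGE_SIZE=1280 as in this file).
orig_w, orig_h = 1920, 1080
new_h = 768                                          # fixed side for landscape input
new_w = round(new_h * (orig_w / orig_h) / 32) * 32   # 1365.3 -> 1376 (nearest multiple of 32)
new_w = max(256, min(new_w, 1280))                   # clamped to the slider maximum -> 1280
print(new_h, new_w)                                  # 768 1280
```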
133 |
+
height_ui, width_ui, mode,
|
134 |
+
duration_ui, # Removed ui_steps
|
135 |
+
ui_frames_to_use,
|
136 |
+
seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag,
|
137 |
+
progress):
|
138 |
+
if duration_ui > 7:
|
139 |
+
return 75
|
140 |
+
else:
|
141 |
+
return 60
|
142 |
+
|
143 |
+
@spaces.GPU(duration=get_duration)
|
144 |
+
def generate(prompt, negative_prompt, input_image_filepath=None, input_video_filepath=None,
|
145 |
+
height_ui=512, width_ui=704, mode="text-to-video",
|
146 |
+
duration_ui=2.0,
|
147 |
+
ui_frames_to_use=9,
|
148 |
+
seed_ui=42, randomize_seed=True, ui_guidance_scale=3.0, improve_texture_flag=True,
|
149 |
+
progress=gr.Progress(track_tqdm=True)):
|
150 |
+
"""
|
151 |
+
Generate high-quality videos using LTX Video model with support for text-to-video, image-to-video, and video-to-video modes.
|
152 |
|
153 |
+
Args:
|
154 |
+
prompt (str): Text description of the desired video content. Required for all modes.
|
155 |
+
negative_prompt (str): Text describing what to avoid in the generated video. Optional, can be empty string.
|
156 |
+
input_image_filepath (str or None): Path to input image file. Required for image-to-video mode, None for other modes.
|
157 |
+
input_video_filepath (str or None): Path to input video file. Required for video-to-video mode, None for other modes.
|
158 |
+
height_ui (int): Height of the output video in pixels, must be divisible by 32. Default: 512.
|
159 |
+
width_ui (int): Width of the output video in pixels, must be divisible by 32. Default: 704.
|
160 |
+
mode (str): Generation mode. Required. One of "text-to-video", "image-to-video", or "video-to-video". Default: "text-to-video".
|
161 |
+
duration_ui (float): Duration of the output video in seconds. Range: 0.3 to 8.5. Default: 2.0.
|
162 |
+
ui_frames_to_use (int): Number of frames to use from input video. Only used in video-to-video mode. Must be N*8+1. Default: 9.
|
163 |
+
seed_ui (int): Random seed for reproducible generation. Range: 0 to 2^32-1. Default: 42.
|
164 |
+
randomize_seed (bool): Whether to use a random seed instead of seed_ui. Default: True.
|
165 |
+
ui_guidance_scale (float): CFG scale controlling prompt influence. Range: 1.0 to 10.0. Higher values = stronger prompt influence. Default: 3.0.
|
166 |
+
improve_texture_flag (bool): Whether to use multi-scale generation for better texture quality. Slower but higher quality. Default: True.
|
167 |
+
progress (gr.Progress): Progress tracker for the generation process. Optional, used for UI updates.
|
168 |
|
169 |
+
Returns:
|
170 |
+
tuple: A tuple containing (output_video_path, used_seed) where output_video_path is the path to the generated video file and used_seed is the actual seed used for generation.
|
171 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
+
# Validate mode-specific required parameters
|
174 |
+
if mode == "image-to-video":
|
175 |
+
if not input_image_filepath:
|
176 |
+
raise gr.Error("input_image_filepath is required for image-to-video mode")
|
177 |
+
elif mode == "video-to-video":
|
178 |
+
if not input_video_filepath:
|
179 |
+
raise gr.Error("input_video_filepath is required for video-to-video mode")
|
180 |
+
elif mode == "text-to-video":
|
181 |
+
# No additional file inputs required for text-to-video
|
182 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
else:
|
184 |
+
raise gr.Error(f"Invalid mode: {mode}. Must be one of: text-to-video, image-to-video, video-to-video")
|
185 |
+
|
186 |
+
if randomize_seed:
|
187 |
+
seed_ui = random.randint(0, 2**32 - 1)
|
188 |
+
seed_everething(int(seed_ui))
|
189 |
+
|
190 |
+
target_frames_ideal = duration_ui * FPS
|
191 |
+
target_frames_rounded = round(target_frames_ideal)
|
192 |
+
if target_frames_rounded < 1:
|
193 |
+
target_frames_rounded = 1
|
194 |
+
|
195 |
+
n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
|
196 |
+
actual_num_frames = int(n_val * 8 + 1)
|
197 |
+
|
198 |
+
actual_num_frames = max(9, actual_num_frames)
|
199 |
+
actual_num_frames = min(MAX_NUM_FRAMES, actual_num_frames)
|
200 |
+
|
201 |
+
actual_height = int(height_ui)
|
202 |
+
actual_width = int(width_ui)
|
203 |
+
|
204 |
+
height_padded = ((actual_height - 1) // 32 + 1) * 32
|
205 |
+
width_padded = ((actual_width - 1) // 32 + 1) * 32
|
206 |
+
num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
|
207 |
+
if num_frames_padded != actual_num_frames:
|
208 |
+
print(f"Warning: actual_num_frames ({actual_num_frames}) and num_frames_padded ({num_frames_padded}) differ. Using num_frames_padded for pipeline.")
|
209 |
+
|
210 |
+
padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
|
211 |
+
|
212 |
+
call_kwargs = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
"prompt": prompt,
|
|
|
214 |
"negative_prompt": negative_prompt,
|
215 |
+
"height": height_padded,
|
216 |
+
"width": width_padded,
|
217 |
+
"num_frames": num_frames_padded,
|
218 |
+
"frame_rate": int(FPS),
|
219 |
+
"generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
|
220 |
+
"output_type": "pt",
|
221 |
+
"conditioning_items": None,
|
222 |
"media_items": None,
|
223 |
+
"decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
|
224 |
+
"decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
|
225 |
+
"stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"],
|
226 |
+
"image_cond_noise_scale": 0.15,
|
227 |
+
"is_video": True,
|
228 |
+
"vae_per_channel_normalize": True,
|
229 |
+
"mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
|
230 |
+
"offload_to_cpu": False,
|
231 |
+
"enhance_prompt": False,
|
232 |
}
|
233 |
|
234 |
+
stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values")
|
235 |
+
if stg_mode_str.lower() in ["stg_av", "attention_values"]:
|
236 |
+
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionValues
|
237 |
+
elif stg_mode_str.lower() in ["stg_as", "attention_skip"]:
|
238 |
+
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionSkip
|
239 |
+
elif stg_mode_str.lower() in ["stg_r", "residual"]:
|
240 |
+
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.Residual
|
241 |
+
elif stg_mode_str.lower() in ["stg_t", "transformer_block"]:
|
242 |
+
call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.TransformerBlock
|
243 |
+
else:
|
244 |
+
raise ValueError(f"Invalid stg_mode: {stg_mode_str}")
|
245 |
|
246 |
+
if mode == "image-to-video" and input_image_filepath:
|
247 |
+
try:
|
248 |
+
media_tensor = load_image_to_tensor_with_resize_and_crop(
|
249 |
+
input_image_filepath, actual_height, actual_width
|
250 |
+
)
|
251 |
+
media_tensor = torch.nn.functional.pad(media_tensor, padding_values)
|
252 |
+
call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_inference_device), 0, 1.0)]
|
253 |
+
except Exception as e:
|
254 |
+
print(f"Error loading image {input_image_filepath}: {e}")
|
255 |
+
raise gr.Error(f"Could not load image: {e}")
|
256 |
+
elif mode == "video-to-video" and input_video_filepath:
|
257 |
+
try:
|
258 |
+
call_kwargs["media_items"] = load_media_file(
|
259 |
+
media_path=input_video_filepath,
|
260 |
+
height=actual_height,
|
261 |
+
width=actual_width,
|
262 |
+
max_frames=int(ui_frames_to_use),
|
263 |
+
padding=padding_values
|
264 |
+
).to(target_inference_device)
|
265 |
+
except Exception as e:
|
266 |
+
print(f"Error loading video {input_video_filepath}: {e}")
|
267 |
+
raise gr.Error(f"Could not load video: {e}")
|
268 |
+
|
269 |
+
print(f"Moving models to {target_inference_device} for inference (if not already there)...")
|
270 |
+
|
271 |
+
active_latent_upsampler = None
|
272 |
+
if improve_texture_flag and latent_upsampler_instance:
|
273 |
+
active_latent_upsampler = latent_upsampler_instance
|
274 |
+
|
275 |
+
result_images_tensor = None
|
276 |
+
if improve_texture_flag:
|
277 |
+
if not active_latent_upsampler:
|
278 |
+
raise gr.Error("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")
|
279 |
+
|
280 |
+
multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
|
281 |
+
|
282 |
+
first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
|
283 |
+
first_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
|
284 |
+
# num_inference_steps will be derived from len(timesteps) in the pipeline
|
285 |
+
first_pass_args.pop("num_inference_steps", None)
|
286 |
+
|
287 |
+
|
288 |
+
second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
|
289 |
+
second_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
|
290 |
+
# num_inference_steps will be derived from len(timesteps) in the pipeline
|
291 |
+
second_pass_args.pop("num_inference_steps", None)
|
292 |
+
|
293 |
+
multi_scale_call_kwargs = call_kwargs.copy()
|
294 |
+
multi_scale_call_kwargs.update({
|
295 |
+
"downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
|
296 |
+
"first_pass": first_pass_args,
|
297 |
+
"second_pass": second_pass_args,
|
298 |
+
})
|
299 |
+
|
300 |
+
print(f"Calling multi-scale pipeline (eff. HxW: {actual_height}x{actual_width}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
|
301 |
+
result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
|
302 |
+
else:
|
303 |
+
single_pass_call_kwargs = call_kwargs.copy()
|
304 |
+
first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
|
305 |
+
|
306 |
+
single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
|
307 |
+
single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
|
308 |
+
single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
|
309 |
+
single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
|
310 |
+
single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
|
311 |
+
|
312 |
+
# Remove keys that might conflict or are not used in single pass / handled by above
|
313 |
+
single_pass_call_kwargs.pop("num_inference_steps", None)
|
314 |
+
single_pass_call_kwargs.pop("first_pass", None)
|
315 |
+
single_pass_call_kwargs.pop("second_pass", None)
|
316 |
+
single_pass_call_kwargs.pop("downscale_factor", None)
|
317 |
+
|
318 |
+
print(f"Calling base pipeline (padded HxW: {height_padded}x{width_padded}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
|
319 |
+
result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
|
320 |
+
|
321 |
+
if result_images_tensor is None:
|
322 |
+
raise gr.Error("Generation failed.")
|
323 |
+
|
324 |
+
pad_left, pad_right, pad_top, pad_bottom = padding_values
|
325 |
+
slice_h_end = -pad_bottom if pad_bottom > 0 else None
|
326 |
+
slice_w_end = -pad_right if pad_right > 0 else None
|
327 |
+
|
328 |
+
result_images_tensor = result_images_tensor[
|
329 |
+
:, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
|
330 |
+
]
|
331 |
|
332 |
+
video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
|
333 |
+
|
334 |
+
video_np = np.clip(video_np, 0, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
video_np = (video_np * 255).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
|
337 |
+
temp_dir = tempfile.mkdtemp()
|
338 |
+
timestamp = random.randint(10000,99999)
|
339 |
+
output_video_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
|
340 |
+
|
341 |
try:
|
342 |
+
with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
|
343 |
+
for frame_idx in range(video_np.shape[0]):
|
344 |
+
progress(frame_idx / video_np.shape[0], desc="Saving video")
|
345 |
+
video_writer.append_data(video_np[frame_idx])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
except Exception as e:
|
347 |
+
print(f"Error saving video with macro_block_size=1: {e}")
|
348 |
+
try:
|
349 |
+
with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264', quality=8) as video_writer:
|
350 |
+
for frame_idx in range(video_np.shape[0]):
|
351 |
+
progress(frame_idx / video_np.shape[0], desc="Saving video (fallback ffmpeg)")
|
352 |
+
video_writer.append_data(video_np[frame_idx])
|
353 |
+
except Exception as e2:
|
354 |
+
print(f"Fallback video saving error: {e2}")
|
355 |
+
raise gr.Error(f"Failed to save video: {e2}")
|
356 |
+
|
357 |
+
return output_video_path, seed_ui
|
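Outside the Gradio UI, the function above could be exercised directly; a minimal sketch (not part of the commit — argument values are illustrative and the call assumes the models loaded at module import):

```python
# Hypothetical direct call to generate() for a short text-to-video clip.
video_path, used_seed = generate(
    prompt="A majestic dragon flying over a medieval castle",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    mode="text-to-video",
    duration_ui=2.0,
    seed_ui=42,
    randomize_seed=False,  # keep the run reproducible
)
print(video_path, used_seed)
```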
358 |
+
|
359 |
+
def update_task_image():
|
360 |
+
return "image-to-video"
|
361 |
+
|
362 |
+
def update_task_text():
|
363 |
+
return "text-to-video"
|
364 |
+
|
365 |
+
def update_task_video():
|
366 |
+
return "video-to-video"
|
367 |
+
|
368 |
+
# --- Gradio UI Definition ---
|
369 |
+
css="""
|
370 |
+
#col-container {
|
371 |
+
margin: 0 auto;
|
372 |
+
max-width: 900px;
|
373 |
+
}
|
374 |
+
"""
|
375 |
+
|
376 |
+
with gr.Blocks(css=css) as demo:
|
377 |
+
gr.Markdown("# LTX Video 0.9.8 13B Distilled")
|
378 |
+
gr.Markdown("Fast high quality video generation.**Update (17/07):** now with the new v0.9.8 for improved prompt understanding and detail generation" )
|
379 |
+
gr.Markdown("[Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.8-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](https://huggingface.co/Lightricks/LTX-Video-0.9.8-13B-distilled#diffusers-🧨)")
|
380 |
+
with gr.Row():
|
381 |
+
with gr.Column():
|
382 |
+
with gr.Tab("image-to-video") as image_tab:
|
383 |
+
video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
|
384 |
+
image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "webcam", "clipboard"])
|
385 |
+
i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3)
|
386 |
+
i2v_button = gr.Button("Generate Image-to-Video", variant="primary")
|
387 |
+
with gr.Tab("text-to-video") as text_tab:
|
388 |
+
image_n_hidden = gr.Textbox(label="image_n", visible=False, value=None)
|
389 |
+
video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
|
390 |
+
t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3)
|
391 |
+
t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
|
392 |
+
with gr.Tab("video-to-video", visible=False) as video_tab:
|
393 |
+
image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
|
394 |
+
video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"]) # type defaults to filepath
|
395 |
+
frames_to_use = gr.Slider(label="Frames to use from input video", minimum=9, maximum=MAX_NUM_FRAMES, value=9, step=8, info="Number of initial frames to use for conditioning/transformation. Must be N*8+1.")
|
396 |
+
v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3)
|
397 |
+
v2v_button = gr.Button("Generate Video-to-Video", variant="primary")
|
398 |
+
|
399 |
+
duration_input = gr.Slider(
|
400 |
+
label="Video Duration (seconds)",
|
401 |
+
minimum=0.3,
|
402 |
+
maximum=8.5,
|
403 |
+
value=2,
|
404 |
+
step=0.1,
|
405 |
+
info=f"Target video duration (0.3s to 8.5s)"
|
406 |
+
)
|
407 |
+
improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True,visible=False, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
|
408 |
+
|
409 |
+
with gr.Column():
|
410 |
+
output_video = gr.Video(label="Generated Video", interactive=False)
|
411 |
+
# gr.DeepLinkButton()
|
412 |
+
|
413 |
+
with gr.Accordion("Advanced settings", open=False):
|
414 |
+
mode = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="task", value="image-to-video", visible=False)
|
415 |
+
negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
|
416 |
+
with gr.Row():
|
417 |
+
seed_input = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=2**32-1)
|
418 |
+
randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
|
419 |
+
with gr.Row(visible=False):
|
420 |
+
guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1, info="Controls how much the prompt influences the output. Higher values = stronger influence.")
|
421 |
+
with gr.Row():
|
422 |
+
height_input = gr.Slider(label="Height", value=512, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
|
423 |
+
width_input = gr.Slider(label="Width", value=704, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
|
424 |
+
|
425 |
+
|
426 |
+
+ # --- Event handlers for updating dimensions on upload ---
+ def handle_image_upload_for_dims(image_filepath, current_h, current_w):
+     if not image_filepath:  # Image cleared or no image initially
+         # Keep current slider values if image is cleared or no input
+         return gr.update(value=current_h), gr.update(value=current_w)
+     try:
+         img = Image.open(image_filepath)
+         orig_w, orig_h = img.size
+         new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
+         return gr.update(value=new_h), gr.update(value=new_w)
+     except Exception as e:
+         print(f"Error processing image for dimension update: {e}")
+         # Keep current slider values on error
+         return gr.update(value=current_h), gr.update(value=current_w)
+
+ def handle_video_upload_for_dims(video_filepath, current_h, current_w):
+     if not video_filepath:  # Video cleared or no video initially
+         return gr.update(value=current_h), gr.update(value=current_w)
+     try:
+         # Ensure video_filepath is a string for os.path.exists and imageio
+         video_filepath_str = str(video_filepath)
+         if not os.path.exists(video_filepath_str):
+             print(f"Video file path does not exist for dimension update: {video_filepath_str}")
+             return gr.update(value=current_h), gr.update(value=current_w)
+
+         orig_w, orig_h = -1, -1
+         with imageio.get_reader(video_filepath_str) as reader:
+             meta = reader.get_meta_data()
+             if 'size' in meta:
+                 orig_w, orig_h = meta['size']
+             else:
+                 # Fallback: read first frame if 'size' not in metadata
+                 try:
+                     first_frame = reader.get_data(0)
+                     # Shape is (h, w, c) for frames
+                     orig_h, orig_w = first_frame.shape[0], first_frame.shape[1]
+                 except Exception as e_frame:
+                     print(f"Could not get video size from metadata or first frame: {e_frame}")
+                     return gr.update(value=current_h), gr.update(value=current_w)
+
+         if orig_w == -1 or orig_h == -1:  # If dimensions couldn't be determined
+             print(f"Could not determine dimensions for video: {video_filepath_str}")
+             return gr.update(value=current_h), gr.update(value=current_w)
+
+         new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
+         return gr.update(value=new_h), gr.update(value=new_w)
+     except Exception as e:
+         # Log type of video_filepath for debugging if it's not a path-like string
+         print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
+         return gr.update(value=current_h), gr.update(value=current_w)
+
+
+ image_i2v.upload(
+     fn=handle_image_upload_for_dims,
+     inputs=[image_i2v, height_input, width_input],
+     outputs=[height_input, width_input]
+ )
+ video_v2v.upload(
+     fn=handle_video_upload_for_dims,
+     inputs=[video_v2v, height_input, width_input],
+     outputs=[height_input, width_input]
+ )
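Both upload handlers delegate to `calculate_new_dimensions`, which is defined earlier in app.py and is not shown in this hunk. The sketch below is a hypothetical stand-in, not the Space's actual implementation; it only illustrates the kind of snapping the sliders expect (multiples of 32, clamped between `MIN_DIM_SLIDER` and `MAX_IMAGE_SIZE`), and the constants are assumed values.

```python
# Hypothetical sketch only -- the real calculate_new_dimensions lives elsewhere in app.py.
MIN_DIM_SLIDER = 256       # assumed values, for illustration
MAX_IMAGE_SIZE = 1280
TARGET_PIXELS = 512 * 704  # keep roughly the default pixel budget

def calculate_new_dimensions(orig_w: int, orig_h: int) -> tuple[int, int]:
    """Return (height, width) with the source aspect ratio, snapped to multiples of 32."""
    aspect = orig_w / orig_h
    new_h = (TARGET_PIXELS / aspect) ** 0.5
    new_w = new_h * aspect

    def _snap(v: float) -> int:
        v = round(v / 32) * 32
        return int(min(max(v, MIN_DIM_SLIDER), MAX_IMAGE_SIZE))

    return _snap(new_h), _snap(new_w)

print(calculate_new_dimensions(1920, 1080))  # -> (448, 800)
```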

+ image_tab.select(
+     fn=update_task_image,
+     outputs=[mode]
+ )
+ text_tab.select(
+     fn=update_task_text,
+     outputs=[mode]
+ )
+
+ t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
+               height_input, width_input, mode,
+               duration_input, frames_to_use,
+               seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
+
+ i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
+               height_input, width_input, mode,
+               duration_input, frames_to_use,
+               seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
+
+ v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
+               height_input, width_input, mode,
+               duration_input, frames_to_use,
+               seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
+
+ t2v_button.click(fn=generate, inputs=t2v_inputs, outputs=[output_video, seed_input], api_name="text_to_video")
+ i2v_button.click(fn=generate, inputs=i2v_inputs, outputs=[output_video, seed_input], api_name="image_to_video")
+ v2v_button.click(fn=generate, inputs=v2v_inputs, outputs=[output_video, seed_input], api_name="video_to_video")
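Because the click handlers register `api_name` endpoints, the Space can also be driven programmatically. The example below is a hedged sketch: the Space id is a placeholder, the positional arguments simply follow the `t2v_inputs` order above, and the exact return types depend on the gradio/gradio_client versions.

```python
# Illustrative client call; "your-username/your-ltx-space" is a placeholder Space id.
from gradio_client import Client

client = Client("your-username/your-ltx-space")
video_path, used_seed = client.predict(
    "A majestic dragon flying over a medieval castle",                  # t2v_prompt
    "worst quality, inconsistent motion, blurry, jittery, distorted",   # negative prompt
    None, None,          # image_n_hidden, video_n_hidden (unused for text-to-video)
    512, 704,            # height, width
    "text-to-video",     # mode
    2.0, 9,              # duration (s), frames_to_use
    42, True,            # seed, randomize_seed
    1.0, True,           # guidance_scale, improve_texture
    api_name="/text_to_video",
)
print(video_path, used_seed)
```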

  if __name__ == "__main__":
+     if os.path.exists(models_dir) and os.path.isdir(models_dir):
+         print(f"Model directory: {Path(models_dir).resolve()}")
+
+     demo.queue().launch(debug=True, share=False, mcp_server=True)
inference.py
CHANGED
@@ -1,162 +1,206 @@
- import torch
- from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
- from xora.models.transformers.transformer3d import Transformer3DModel
- from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
- from xora.schedulers.rf import RectifiedFlowScheduler
- from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
- from pathlib import Path
- from transformers import T5EncoderModel, T5Tokenizer
- import safetensors.torch
- import json
  import argparse
- from xora.utils.conditioning_method import ConditioningMethod
  import os
  import numpy as np
  import cv2
  from PIL import Image
- import
- def load_vae(vae_dir):
-     vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
-     vae_config_path = vae_dir / "config.json"
-     with open(vae_config_path, "r") as f:
-         vae_config = json.load(f)
-     vae = CausalVideoAutoencoder.from_config(vae_config)
-     vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
-     vae.load_state_dict(vae_state_dict)
      if torch.cuda.is_available():
- def
-     unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
-     unet_config_path = unet_dir / "config.json"
-     transformer_config = Transformer3DModel.load_config(unet_config_path)
-     transformer = Transformer3DModel.from_config(transformer_config)
-     unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
-     transformer.load_state_dict(unet_state_dict, strict=True)
      if torch.cuda.is_available():
-     h, w, _ = frame.shape
      aspect_ratio_target = target_width / target_height
-     aspect_ratio_frame =
      if aspect_ratio_frame > aspect_ratio_target:
-         new_width = int(
      else:
-     frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-     if target_height is not None:
-         frame_resized = center_crop_and_resize(
-             frame_rgb, target_height, target_width
-         )
-     else:
-         frame_resized = frame_rgb
-     frames.append(frame_resized)
-     cap.release()
-     video_np = (np.array(frames) / 127.5) - 1.0
-     video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
-     return video_tensor
-
-
- def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
-     image = Image.open(image_path).convert("RGB")
-     image_np = np.array(image)
-     frame_resized = center_crop_and_resize(image_np, target_height, target_width)
-     frame_tensor = torch.tensor(frame_resized).permute(2, 0, 1).float()
      frame_tensor = (frame_tensor / 127.5) - 1.0
      # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
      return frame_tensor.unsqueeze(0).unsqueeze(2)

  def main():
      parser = argparse.ArgumentParser(
          description="Load models from separate directories and run the pipeline."
      )

      # Directories
-     parser.add_argument(
-         "--ckpt_dir",
-         type=str,
-         required=True,
-         help="Path to the directory containing unet, vae, and scheduler subdirectories",
-     )
-     parser.add_argument(
-         "--input_video_path",
-         type=str,
-         help="Path to the input video file (first frame used)",
-     )
-     parser.add_argument(
-         "--input_image_path", type=str, help="Path to the input image file"
-     )
      parser.add_argument(
          "--output_path",
          type=str,
          default=None,
-         help="Path to save output video, if None will save in
      )
      parser.add_argument("--seed", type=int, default="171198")

      # Pipeline parameters
-     parser.add_argument(
-         "--num_inference_steps", type=int, default=40, help="Number of inference steps"
-     )
      parser.add_argument(
          "--num_images_per_prompt",
          type=int,
@@ -164,21 +208,21 @@
          help="Number of images per prompt",
      )
      parser.add_argument(
-         "--
          type=float,
-         default=
-         help="
      )
      parser.add_argument(
          "--height",
          type=int,
-         default=
          help="Height of the output video frames. Optional if an input image provided.",
      )
      parser.add_argument(
          "--width",
          type=int,
-         default=
          help="Width of the output video frames. If None will infer from input image.",
      )
      parser.add_argument(
@@ -188,13 +232,18 @@
          help="Number of frames to generate in the output video",
      )
      parser.add_argument(
-         "--frame_rate", type=int, default=
      )
      parser.add_argument(
-         "--
-         help="
      )

      # Prompts
@@ -209,161 +258,517 @@
          default="worst quality, inconsistent motion, blurry, jittery, distorted",
          help="Negative prompt for undesired features",
      )
      parser.add_argument(
-         "--
          action="store_true",
-         help="Enable custom resolution (not in recommneded resolutions) if specified (default: False)",
      )

      else:
-     assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
-     assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
-     assert (
-         height,
-         width,
-         args.num_frames,
-     ) in RECOMMENDED_RESOLUTIONS or args.custom_resolution, f"The selected resolution + num frames combination is not supported, results would be suboptimal. Supported (h,w,f) are: {RECOMMENDED_RESOLUTIONS}. Use --custom_resolution to enable working with this resolution."
-
-     # Paths for the separate mode directories
-     ckpt_dir = Path(args.ckpt_dir)
-     unet_dir = ckpt_dir / "unet"
-     vae_dir = ckpt_dir / "vae"
-     scheduler_dir = ckpt_dir / "scheduler"
-
-     # Load models
-     vae = load_vae(vae_dir)
-     unet = load_unet(unet_dir)
-     scheduler = load_scheduler(scheduler_dir)
-     patchifier = SymmetricPatchifier(patch_size=1)
      text_encoder = T5EncoderModel.from_pretrained(
      )
-     text_encoder = text_encoder.to("cuda")
      tokenizer = T5Tokenizer.from_pretrained(
      )

      # Use submodels for the pipeline
      submodel_dict = {
-         "transformer":
          "patchifier": patchifier,
          "text_encoder": text_encoder,
          "tokenizer": tokenizer,
          "scheduler": scheduler,
          "vae": vae,
      }

-     pipeline =

      # Prepare input for the pipeline
      sample = {
-         "prompt":
          "prompt_attention_mask": None,
-         "negative_prompt":
          "negative_prompt_attention_mask": None,
-         "media_items": media_items,
      }

-     torch.manual_seed(args.seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed(args.seed)
-     generator = torch.Generator(
-         device="cuda" if torch.cuda.is_available() else "cpu"
-     ).manual_seed(args.seed)

      images = pipeline(
-         guidance_scale=args.guidance_scale,
          generator=generator,
          output_type="pt",
          callback_on_step_end=None,
-         height=
-         width=
-         num_frames=
-         frame_rate=
          **sample,
          is_video=True,
          vae_per_channel_normalize=True,
-         mixed_precision=not args.bfloat16,
      ).images

-     #

      for i in range(images.shape[0]):
          # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
          video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
          # Unnormalizing images to [0, 255] range
          video_np = (video_np * 255).astype(np.uint8)
-         fps =
          height, width = video_np.shape[1:3]
          if video_np.shape[0] == 1:
-             output_filename = (
              )
-             output_filename, video_np[0][..., ::-1]
-         )  # Save single frame as image
          else:
-             output_filename = (
              )
              )

  if __name__ == "__main__":
-     main()
  import argparse
  import os
+ import random
+ from datetime import datetime
+ from pathlib import Path
+ from diffusers.utils import logging
+ from typing import Optional, List, Union
+ import yaml
+
+ import imageio
+ import json
  import numpy as np
+ import torch
  import cv2
+ from safetensors import safe_open
  from PIL import Image
+ from transformers import (
+     T5EncoderModel,
+     T5Tokenizer,
+     AutoModelForCausalLM,
+     AutoProcessor,
+     AutoTokenizer,
+ )
+ from huggingface_hub import hf_hub_download
+
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
+     CausalVideoAutoencoder,
+ )
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
+ from ltx_video.pipelines.pipeline_ltx_video import (
+     ConditioningItem,
+     LTXVideoPipeline,
+     LTXMultiScalePipeline,
+ )
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
+ import ltx_video.pipelines.crf_compressor as crf_compressor
+
+ MAX_HEIGHT = 720
+ MAX_WIDTH = 1280
+ MAX_NUM_FRAMES = 257
+
+ logger = logging.get_logger("LTX-Video")
+
+
+ def get_total_gpu_memory():
      if torch.cuda.is_available():
+         total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+         return total_memory
+     return 0


+ def get_device():
      if torch.cuda.is_available():
+         return "cuda"
+     elif torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+
+
+ def load_image_to_tensor_with_resize_and_crop(
+     image_input: Union[str, Image.Image],
+     target_height: int = 512,
+     target_width: int = 768,
+     just_crop: bool = False,
+ ) -> torch.Tensor:
+     """Load and process an image into a tensor.
+
+     Args:
+         image_input: Either a file path (str) or a PIL Image object
+         target_height: Desired height of output tensor
+         target_width: Desired width of output tensor
+         just_crop: If True, only crop the image to the target size without resizing
+     """
+     if isinstance(image_input, str):
+         image = Image.open(image_input).convert("RGB")
+     elif isinstance(image_input, Image.Image):
+         image = image_input
+     else:
+         raise ValueError("image_input must be either a file path or a PIL Image object")

+     input_width, input_height = image.size
      aspect_ratio_target = target_width / target_height
+     aspect_ratio_frame = input_width / input_height
      if aspect_ratio_frame > aspect_ratio_target:
+         new_width = int(input_height * aspect_ratio_target)
+         new_height = input_height
+         x_start = (input_width - new_width) // 2
+         y_start = 0
      else:
+         new_width = input_width
+         new_height = int(input_width / aspect_ratio_target)
+         x_start = 0
+         y_start = (input_height - new_height) // 2
+
+     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+     if not just_crop:
+         image = image.resize((target_width, target_height))
+
+     image = np.array(image)
+     image = cv2.GaussianBlur(image, (3, 3), 0)
+     frame_tensor = torch.from_numpy(image).float()
+     frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
+     frame_tensor = frame_tensor.permute(2, 0, 1)
      frame_tensor = (frame_tensor / 127.5) - 1.0
      # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
      return frame_tensor.unsqueeze(0).unsqueeze(2)
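For reference, the 5D layout this helper returns is (batch=1, channels=3, frames=1, height, width) with values roughly in [-1, 1]. A small usage sketch follows; it uses a synthetic gray frame so it has no external file dependency, and it assumes the `ltx_video` package (for `crf_compressor`) is installed.

```python
# Usage sketch for load_image_to_tensor_with_resize_and_crop (synthetic input).
from PIL import Image

img = Image.new("RGB", (1920, 1080), color="gray")  # stand-in for a real frame
tensor = load_image_to_tensor_with_resize_and_crop(img, target_height=512, target_width=768)
print(tensor.shape)  # -> torch.Size([1, 3, 1, 512, 768]), values roughly within [-1, 1]
```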


+ def calculate_padding(
+     source_height: int, source_width: int, target_height: int, target_width: int
+ ) -> tuple[int, int, int, int]:
+
+     # Calculate total padding needed
+     pad_height = target_height - source_height
+     pad_width = target_width - source_width
+
+     # Calculate padding for each side
+     pad_top = pad_height // 2
+     pad_bottom = pad_height - pad_top  # Handles odd padding
+     pad_left = pad_width // 2
+     pad_right = pad_width - pad_left  # Handles odd padding
+
+     # Return the padding amounts in torch.nn.functional.pad order:
+     # (left, right, top, bottom)
+     padding = (pad_left, pad_right, pad_top, pad_bottom)
+     return padding
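A quick worked example of the split: padding a 500x700 frame up to 512x704 needs 12 extra rows and 4 extra columns, divided as evenly as possible between the two sides.

```python
# Worked example for calculate_padding (values follow directly from the code above).
padding = calculate_padding(source_height=500, source_width=700,
                            target_height=512, target_width=704)
print(padding)  # -> (2, 2, 6, 6): 2 px left/right, 6 px top/bottom
```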
+
+
+ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
+     # Remove non-letters and convert to lowercase
+     clean_text = "".join(
+         char.lower() for char in text if char.isalpha() or char.isspace()
+     )
+
+     # Split into words
+     words = clean_text.split()
+
+     # Build result string keeping track of length
+     result = []
+     current_length = 0
+
+     for word in words:
+         # Add word length plus 1 for underscore (except for first word)
+         new_length = current_length + len(word)
+
+         if new_length <= max_len:
+             result.append(word)
+             current_length += len(word)
+         else:
+             break
+
+     return "-".join(result)
+
+
+ # Generate output video name
+ def get_unique_filename(
+     base: str,
+     ext: str,
+     prompt: str,
+     seed: int,
+     resolution: tuple[int, int, int],
+     dir: Path,
+     endswith=None,
+     index_range=1000,
+ ) -> Path:
+     base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
+     for i in range(index_range):
+         filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
+         if not os.path.exists(filename):
+             return filename
+     raise FileExistsError(
+         f"Could not find a unique filename after {index_range} attempts."
+     )
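As an illustration of the naming scheme (not an additional feature), the default prompt and seed used elsewhere in this repository would produce an output path along these lines.

```python
# Illustrative only: shows what the helpers above generate for the default prompt/seed.
name = convert_prompt_to_filename("A majestic dragon flying over a medieval castle", max_len=30)
print(name)  # -> "a-majestic-dragon-flying-over-a"
# get_unique_filename("video_output_0", ".mp4", prompt=..., seed=171198,
#                     resolution=(704, 1216, 121), dir=Path("outputs"))
# would then return something like:
#   outputs/video_output_0_a-majestic-dragon-flying-over-a_171198_704x1216x121_0.mp4
```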
+
+
+ def seed_everething(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+     if torch.backends.mps.is_available():
+         torch.mps.manual_seed(seed)
+
+
  def main():
      parser = argparse.ArgumentParser(
          description="Load models from separate directories and run the pipeline."
      )

      # Directories
      parser.add_argument(
          "--output_path",
          type=str,
          default=None,
+         help="Path to the folder to save output video, if None will save in outputs/ directory.",
      )
      parser.add_argument("--seed", type=int, default="171198")

      # Pipeline parameters
      parser.add_argument(
          "--num_images_per_prompt",
          type=int,
          help="Number of images per prompt",
      )
      parser.add_argument(
+         "--image_cond_noise_scale",
          type=float,
+         default=0.15,
+         help="Amount of noise to add to the conditioned image",
      )
      parser.add_argument(
          "--height",
          type=int,
+         default=704,
          help="Height of the output video frames. Optional if an input image provided.",
      )
      parser.add_argument(
          "--width",
          type=int,
+         default=1216,
          help="Width of the output video frames. If None will infer from input image.",
      )
      parser.add_argument(
          help="Number of frames to generate in the output video",
      )
      parser.add_argument(
+         "--frame_rate", type=int, default=30, help="Frame rate for the output video"
      )
      parser.add_argument(
+         "--device",
+         default=None,
+         help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
+     )
+     parser.add_argument(
+         "--pipeline_config",
+         type=str,
+         default="configs/ltxv-13b-0.9.7-dev.yaml",
+         help="The path to the config file for the pipeline, which contains the parameters for the pipeline",
      )

      # Prompts
          default="worst quality, inconsistent motion, blurry, jittery, distorted",
          help="Negative prompt for undesired features",
      )
+
      parser.add_argument(
+         "--offload_to_cpu",
          action="store_true",
+         help="Offload unnecessary computations to CPU.",
      )

+     # video-to-video arguments:
+     parser.add_argument(
+         "--input_media_path",
+         type=str,
+         default=None,
+         help="Path to the input video (or image) to be modified using the video-to-video pipeline",
+     )

+     # Conditioning arguments
+     parser.add_argument(
+         "--conditioning_media_paths",
+         type=str,
+         nargs="*",
+         help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
+     )
+     parser.add_argument(
+         "--conditioning_strengths",
+         type=float,
+         nargs="*",
+         help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
+     )
+     parser.add_argument(
+         "--conditioning_start_frames",
+         type=int,
+         nargs="*",
+         help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
+     )
+
+     args = parser.parse_args()
+     logger.warning(f"Running generation with arguments: {args}")
+     infer(**vars(args))
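Since `main()` only forwards the parsed arguments to `infer()`, the pipeline can also be called from Python without the CLI. A minimal sketch, assuming you run from the repository root (so `inference.py` is importable) and that the referenced config and checkpoint are available or downloadable:

```python
# Minimal programmatic invocation; the config path is an assumption, adjust to your setup.
from inference import infer

infer(
    output_path="outputs",
    seed=171198,
    pipeline_config="configs/ltxv-13b-0.9.7-dev.yaml",
    image_cond_noise_scale=0.15,
    height=704,
    width=1216,
    num_frames=121,
    frame_rate=30,
    prompt="A majestic dragon flying over a medieval castle",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    offload_to_cpu=False,
)
```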
+
+
+ def create_ltx_video_pipeline(
+     ckpt_path: str,
+     precision: str,
+     text_encoder_model_name_or_path: str,
+     sampler: Optional[str] = None,
+     device: Optional[str] = None,
+     enhance_prompt: bool = False,
+     prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
+     prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
+ ) -> LTXVideoPipeline:
+     ckpt_path = Path(ckpt_path)
+     assert os.path.exists(
+         ckpt_path
+     ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
+
+     with safe_open(ckpt_path, framework="pt") as f:
+         metadata = f.metadata()
+         config_str = metadata.get("config")
+         configs = json.loads(config_str)
+         allowed_inference_steps = configs.get("allowed_inference_steps", None)
+
+     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+     transformer = Transformer3DModel.from_pretrained(ckpt_path)
+
+     # Use constructor if sampler is specified, otherwise use from_pretrained
+     if sampler == "from_checkpoint" or not sampler:
+         scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
      else:
+         scheduler = RectifiedFlowScheduler(
+             sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
+         )
+
      text_encoder = T5EncoderModel.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="text_encoder"
      )
+     patchifier = SymmetricPatchifier(patch_size=1)
      tokenizer = T5Tokenizer.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="tokenizer"
      )

+     transformer = transformer.to(device)
+     vae = vae.to(device)
+     text_encoder = text_encoder.to(device)
+
+     if enhance_prompt:
+         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+             torch_dtype="bfloat16",
+         )
+         prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+         )
+     else:
+         prompt_enhancer_image_caption_model = None
+         prompt_enhancer_image_caption_processor = None
+         prompt_enhancer_llm_model = None
+         prompt_enhancer_llm_tokenizer = None
+
+     vae = vae.to(torch.bfloat16)
+     if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
+         transformer = transformer.to(torch.bfloat16)
+     text_encoder = text_encoder.to(torch.bfloat16)

      # Use submodels for the pipeline
      submodel_dict = {
+         "transformer": transformer,
          "patchifier": patchifier,
          "text_encoder": text_encoder,
          "tokenizer": tokenizer,
          "scheduler": scheduler,
          "vae": vae,
+         "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+         "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+         "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+         "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
+         "allowed_inference_steps": allowed_inference_steps,
      }

+     pipeline = LTXVideoPipeline(**submodel_dict)
+     pipeline = pipeline.to(device)
+     return pipeline
+
+
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
+     latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
+     latent_upsampler.to(device)
+     latent_upsampler.eval()
+     return latent_upsampler
+
+
+ def infer(
+     output_path: Optional[str],
+     seed: int,
+     pipeline_config: str,
+     image_cond_noise_scale: float,
+     height: Optional[int],
+     width: Optional[int],
+     num_frames: int,
+     frame_rate: int,
+     prompt: str,
+     negative_prompt: str,
+     offload_to_cpu: bool,
+     input_media_path: Optional[str] = None,
+     conditioning_media_paths: Optional[List[str]] = None,
+     conditioning_strengths: Optional[List[float]] = None,
+     conditioning_start_frames: Optional[List[int]] = None,
+     device: Optional[str] = None,
+     **kwargs,
+ ):
+     # check if pipeline_config is a file
+     if not os.path.isfile(pipeline_config):
+         raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
+     with open(pipeline_config, "r") as f:
+         pipeline_config = yaml.safe_load(f)
+
+     models_dir = "MODEL_DIR"
+
+     ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
+     if not os.path.isfile(ltxv_model_name_or_path):
+         ltxv_model_path = hf_hub_download(
+             repo_id="Lightricks/LTX-Video",
+             filename=ltxv_model_name_or_path,
+             local_dir=models_dir,
+             repo_type="model",
+         )
+     else:
+         ltxv_model_path = ltxv_model_name_or_path
+
+     spatial_upscaler_model_name_or_path = pipeline_config.get(
+         "spatial_upscaler_model_path"
+     )
+     if spatial_upscaler_model_name_or_path and not os.path.isfile(
+         spatial_upscaler_model_name_or_path
+     ):
+         spatial_upscaler_model_path = hf_hub_download(
+             repo_id="Lightricks/LTX-Video",
+             filename=spatial_upscaler_model_name_or_path,
+             local_dir=models_dir,
+             repo_type="model",
+         )
+     else:
+         spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
+
+     if kwargs.get("input_image_path", None):
+         logger.warning(
+             "Please use conditioning_media_paths instead of input_image_path."
+         )
+         assert not conditioning_media_paths and not conditioning_start_frames
+         conditioning_media_paths = [kwargs["input_image_path"]]
+         conditioning_start_frames = [0]
+
+     # Validate conditioning arguments
+     if conditioning_media_paths:
+         # Use default strengths of 1.0
+         if not conditioning_strengths:
+             conditioning_strengths = [1.0] * len(conditioning_media_paths)
+         if not conditioning_start_frames:
+             raise ValueError(
+                 "If `conditioning_media_paths` is provided, "
+                 "`conditioning_start_frames` must also be provided"
+             )
+         if len(conditioning_media_paths) != len(conditioning_strengths) or len(
+             conditioning_media_paths
+         ) != len(conditioning_start_frames):
+             raise ValueError(
+                 "`conditioning_media_paths`, `conditioning_strengths`, "
+                 "and `conditioning_start_frames` must have the same length"
+             )
+         if any(s < 0 or s > 1 for s in conditioning_strengths):
+             raise ValueError("All conditioning strengths must be between 0 and 1")
+         if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
+             raise ValueError(
+                 f"All conditioning start frames must be between 0 and {num_frames-1}"
+             )
+
+     seed_everething(seed)
+     if offload_to_cpu and not torch.cuda.is_available():
+         logger.warning(
+             "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
+         )
+         offload_to_cpu = False
+     else:
+         offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30
+
+     output_dir = (
+         Path(output_path)
+         if output_path
+         else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
+     )
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
+     height_padded = ((height - 1) // 32 + 1) * 32
+     width_padded = ((width - 1) // 32 + 1) * 32
+     num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
+
+     padding = calculate_padding(height, width, height_padded, width_padded)
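A quick check of the rounding above: the defaults 704x1216x121 already satisfy both constraints, while odd sizes are rounded up and then padded back out to the requested resolution at the end of the run.

```python
# Worked example of the padding arithmetic above (pure arithmetic, no models needed).
for height, width, num_frames in [(704, 1216, 121), (500, 700, 100)]:
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
    print((height, width, num_frames), "->", (height_padded, width_padded, num_frames_padded))
# (704, 1216, 121) -> (704, 1216, 121)
# (500, 700, 100)  -> (512, 704, 105)
```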
+
+     logger.warning(
+         f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
+     )
+
+     prompt_enhancement_words_threshold = pipeline_config[
+         "prompt_enhancement_words_threshold"
+     ]
+
+     prompt_word_count = len(prompt.split())
+     enhance_prompt = (
+         prompt_enhancement_words_threshold > 0
+         and prompt_word_count < prompt_enhancement_words_threshold
+     )
+
+     if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
+         logger.info(
+             f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
+         )
+
+     precision = pipeline_config["precision"]
+     text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
+     sampler = pipeline_config["sampler"]
+     prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
+         "prompt_enhancer_image_caption_model_name_or_path"
+     ]
+     prompt_enhancer_llm_model_name_or_path = pipeline_config[
+         "prompt_enhancer_llm_model_name_or_path"
+     ]
+
+     pipeline = create_ltx_video_pipeline(
+         ckpt_path=ltxv_model_path,
+         precision=precision,
+         text_encoder_model_name_or_path=text_encoder_model_name_or_path,
+         sampler=sampler,
+         device=kwargs.get("device", get_device()),
+         enhance_prompt=enhance_prompt,
+         prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
+         prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
+     )
+
+     if pipeline_config.get("pipeline_type", None) == "multi-scale":
+         if not spatial_upscaler_model_path:
+             raise ValueError(
+                 "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
+             )
+         latent_upsampler = create_latent_upsampler(
+             spatial_upscaler_model_path, pipeline.device
+         )
+         pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
+
+     media_item = None
+     if input_media_path:
+         media_item = load_media_file(
+             media_path=input_media_path,
+             height=height,
+             width=width,
+             max_frames=num_frames_padded,
+             padding=padding,
+         )
+
+     conditioning_items = (
+         prepare_conditioning(
+             conditioning_media_paths=conditioning_media_paths,
+             conditioning_strengths=conditioning_strengths,
+             conditioning_start_frames=conditioning_start_frames,
+             height=height,
+             width=width,
+             num_frames=num_frames,
+             padding=padding,
+             pipeline=pipeline,
+         )
+         if conditioning_media_paths
+         else None
+     )
+
+     stg_mode = pipeline_config.get("stg_mode", "attention_values")
+     del pipeline_config["stg_mode"]
+     if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
+         skip_layer_strategy = SkipLayerStrategy.AttentionValues
+     elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
+         skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+     elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
+         skip_layer_strategy = SkipLayerStrategy.Residual
+     elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
+         skip_layer_strategy = SkipLayerStrategy.TransformerBlock
+     else:
+         raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")

      # Prepare input for the pipeline
      sample = {
+         "prompt": prompt,
          "prompt_attention_mask": None,
+         "negative_prompt": negative_prompt,
          "negative_prompt_attention_mask": None,
      }

+     device = device or get_device()
+     generator = torch.Generator(device=device).manual_seed(seed)

      images = pipeline(
+         **pipeline_config,
+         skip_layer_strategy=skip_layer_strategy,
          generator=generator,
          output_type="pt",
          callback_on_step_end=None,
+         height=height_padded,
+         width=width_padded,
+         num_frames=num_frames_padded,
+         frame_rate=frame_rate,
          **sample,
+         media_items=media_item,
+         conditioning_items=conditioning_items,
          is_video=True,
          vae_per_channel_normalize=True,
+         image_cond_noise_scale=image_cond_noise_scale,
+         mixed_precision=(precision == "mixed_precision"),
+         offload_to_cpu=offload_to_cpu,
+         device=device,
+         enhance_prompt=enhance_prompt,
      ).images

+     # Crop the padded images to the desired resolution and number of frames
+     (pad_left, pad_right, pad_top, pad_bottom) = padding
+     pad_bottom = -pad_bottom
+     pad_right = -pad_right
+     if pad_bottom == 0:
+         pad_bottom = images.shape[3]
+     if pad_right == 0:
+         pad_right = images.shape[4]
+     images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
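The negative indices here are ordinary Python slicing: with padding (2, 2, 6, 6) the crop becomes `images[:, :, :num_frames, 6:-6, 2:-2]`, and a zero pad on either axis falls back to the full extent of that dimension.

```python
# Slicing behaviour the crop above relies on (illustrative).
import torch

x = torch.zeros(1, 3, 9, 512, 704)          # padded output: B, C, F, H, W
pad_left, pad_right, pad_top, pad_bottom = (2, 2, 6, 6)
cropped = x[:, :, :9, pad_top:-pad_bottom, pad_left:-pad_right]
print(cropped.shape)                         # -> torch.Size([1, 3, 9, 500, 700])
```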

      for i in range(images.shape[0]):
          # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
          video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
          # Unnormalizing images to [0, 255] range
          video_np = (video_np * 255).astype(np.uint8)
+         fps = frame_rate
          height, width = video_np.shape[1:3]
+         # In case a single image is generated
          if video_np.shape[0] == 1:
+             output_filename = get_unique_filename(
+                 f"image_output_{i}",
+                 ".png",
+                 prompt=prompt,
+                 seed=seed,
+                 resolution=(height, width, num_frames),
+                 dir=output_dir,
              )
+             imageio.imwrite(output_filename, video_np[0])
          else:
+             output_filename = get_unique_filename(
+                 f"video_output_{i}",
+                 ".mp4",
+                 prompt=prompt,
+                 seed=seed,
+                 resolution=(height, width, num_frames),
+                 dir=output_dir,
              )

+             # Write video
+             with imageio.get_writer(output_filename, fps=fps) as video:
+                 for frame in video_np:
+                     video.append_data(frame)
+
+     logger.warning(f"Output saved to {output_filename}")
+
+
+ def prepare_conditioning(
+     conditioning_media_paths: List[str],
+     conditioning_strengths: List[float],
+     conditioning_start_frames: List[int],
+     height: int,
+     width: int,
+     num_frames: int,
+     padding: tuple[int, int, int, int],
+     pipeline: LTXVideoPipeline,
+ ) -> Optional[List[ConditioningItem]]:
+     """Prepare conditioning items based on input media paths and their parameters.
+
+     Args:
+         conditioning_media_paths: List of paths to conditioning media (images or videos)
+         conditioning_strengths: List of conditioning strengths for each media item
+         conditioning_start_frames: List of frame indices where each item should be applied
+         height: Height of the output frames
+         width: Width of the output frames
+         num_frames: Number of frames in the output video
+         padding: Padding to apply to the frames
+         pipeline: LTXVideoPipeline object used for condition video trimming
+
+     Returns:
+         A list of ConditioningItem objects.
+     """
+     conditioning_items = []
+     for path, strength, start_frame in zip(
+         conditioning_media_paths, conditioning_strengths, conditioning_start_frames
+     ):
+         num_input_frames = orig_num_input_frames = get_media_num_frames(path)
+         if hasattr(pipeline, "trim_conditioning_sequence") and callable(
+             getattr(pipeline, "trim_conditioning_sequence")
+         ):
+             num_input_frames = pipeline.trim_conditioning_sequence(
+                 start_frame, orig_num_input_frames, num_frames
              )
+         if num_input_frames < orig_num_input_frames:
+             logger.warning(
+                 f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
+             )
+
+         media_tensor = load_media_file(
+             media_path=path,
+             height=height,
+             width=width,
+             max_frames=num_input_frames,
+             padding=padding,
+             just_crop=True,
+         )
+         conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
+     return conditioning_items
+

+ def get_media_num_frames(media_path: str) -> int:
+     is_video = any(
+         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
+     )
+     num_frames = 1
+     if is_video:
+         reader = imageio.get_reader(media_path)
+         num_frames = reader.count_frames()
+         reader.close()
+     return num_frames
+
+
+ def load_media_file(
+     media_path: str,
+     height: int,
+     width: int,
+     max_frames: int,
+     padding: tuple[int, int, int, int],
+     just_crop: bool = False,
+ ) -> torch.Tensor:
+     is_video = any(
+         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
+     )
+     if is_video:
+         reader = imageio.get_reader(media_path)
+         num_input_frames = min(reader.count_frames(), max_frames)
+
+         # Read and preprocess the relevant frames from the video file.
+         frames = []
+         for i in range(num_input_frames):
+             frame = Image.fromarray(reader.get_data(i))
+             frame_tensor = load_image_to_tensor_with_resize_and_crop(
+                 frame, height, width, just_crop=just_crop
+             )
+             frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
+             frames.append(frame_tensor)
+         reader.close()
+
+         # Stack frames along the temporal dimension
+         media_tensor = torch.cat(frames, dim=2)
+     else:  # Input image
+         media_tensor = load_image_to_tensor_with_resize_and_crop(
+             media_path, height, width, just_crop=just_crop
+         )
+         media_tensor = torch.nn.functional.pad(media_tensor, padding)
+     return media_tensor
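Tying the helpers together: conditioning an image-to-video run on a single start frame reduces to one ConditioningItem applied at frame 0 with strength 1.0. The sketch below assumes a placeholder image path and an already-constructed `pipeline` from `create_ltx_video_pipeline`.

```python
# Sketch of conditioning on a single start image; "first_frame.png" is a placeholder
# and `pipeline` is assumed to be the LTXVideoPipeline created earlier.
items = prepare_conditioning(
    conditioning_media_paths=["first_frame.png"],
    conditioning_strengths=[1.0],
    conditioning_start_frames=[0],
    height=704,
    width=1216,
    num_frames=121,
    padding=(0, 0, 0, 0),
    pipeline=pipeline,
)
print(len(items))  # -> 1 ConditioningItem, applied at frame 0 with strength 1.0
```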


  if __name__ == "__main__":
+     main()
requirements.txt
CHANGED
@@ -1,14 +1,15 @@
- huggingface_hub<=0.25
- torch
- diffusers==0.28.2
- transformers==4.44.2
- sentencepiece>=0.1.96
  accelerate
+ transformers
+ sentencepiece
+ pillow
+ numpy
+ torchvision
+ huggingface_hub
+ spaces
  opencv-python
+ imageio
+ imageio-ffmpeg
+ einops
+ timm
+ av
+ git+https://github.com/huggingface/diffusers.git@main