Update handler.py
Browse files- handler.py +103 -265
handler.py
CHANGED
@@ -66,229 +66,37 @@ def print_directory_structure(startpath):
|
|
66 |
logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
|
67 |
apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
|
68 |
|
69 |
-
logger.info("💡 Printing directory structure of ""/repository"":")
|
70 |
-
print_directory_structure("/repository")
|
71 |
|
72 |
@dataclass
|
73 |
class GenerationConfig:
|
74 |
"""Configuration for video generation"""
|
75 |
-
width: int = 768
|
76 |
-
height: int = 512
|
77 |
-
fps: int = 24
|
78 |
-
duration_sec: float = 4.0
|
79 |
-
num_inference_steps: int = 30
|
80 |
-
guidance_scale: float = 7.5
|
81 |
-
upscale_factor: float = 2.0
|
82 |
-
enable_interpolation: bool = False
|
83 |
-
seed: int = -1 # -1 means random seed
|
84 |
-
|
85 |
-
@property
|
86 |
-
def num_frames(self) -> int:
|
87 |
-
"""Calculate number of frames based on fps and duration"""
|
88 |
-
return int(self.duration_sec * self.fps) + 1
|
89 |
-
|
90 |
-
def validate_and_adjust(self) -> 'GenerationConfig':
|
91 |
-
"""Validate and adjust parameters to meet constraints"""
|
92 |
-
# Round dimensions to nearest multiple of 32
|
93 |
-
self.width = max(32, min(MAX_WIDTH, round(self.width / 32) * 32))
|
94 |
-
self.height = max(32, min(MAX_HEIGHT, round(self.height / 32) * 32))
|
95 |
-
|
96 |
-
# Adjust number of frames to be in format 8k + 1
|
97 |
-
k = (self.num_frames - 1) // 8
|
98 |
-
num_frames = min((k * 8) + 1, MAX_FRAMES)
|
99 |
-
self.duration_sec = (num_frames - 1) / self.fps
|
100 |
-
|
101 |
-
# Set random seed if not specified
|
102 |
-
if self.seed == -1:
|
103 |
-
self.seed = random.randint(0, 2**32 - 1)
|
104 |
-
|
105 |
-
return self
|
106 |
-
|
107 |
-
class EndpointHandler:
|
108 |
-
"""Handles video generation requests using LTX models and Varnish post-processing"""
|
109 |
-
|
110 |
-
def __init__(self, model_path: str = ""):
|
111 |
-
"""Initialize the handler with LTX models and Varnish
|
112 |
-
|
113 |
-
Args:
|
114 |
-
model_path: Path to LTX model weights
|
115 |
-
"""
|
116 |
-
# Enable TF32 for potential speedup on Ampere GPUs
|
117 |
-
#torch.backends.cuda.matmul.allow_tf32 = True
|
118 |
-
|
119 |
-
# Initialize models with bfloat16 precision
|
120 |
-
self.text_to_video = LTXPipeline.from_pretrained(
|
121 |
-
model_path,
|
122 |
-
torch_dtype=torch.bfloat16
|
123 |
-
).to("cuda")
|
124 |
-
|
125 |
-
self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
|
126 |
-
model_path,
|
127 |
-
torch_dtype=torch.bfloat16
|
128 |
-
).to("cuda")
|
129 |
-
|
130 |
-
# Enable CPU offload for memory efficiency
|
131 |
-
#self.text_to_video.enable_model_cpu_offload()
|
132 |
-
#self.image_to_video.enable_model_cpu_offload()
|
133 |
-
|
134 |
-
# Initialize Varnish for post-processing
|
135 |
-
self.varnish = Varnish(
|
136 |
-
device="cuda" if torch.cuda.is_available() else "cpu",
|
137 |
-
output_format="mp4",
|
138 |
-
output_codec="h264",
|
139 |
-
output_quality=23,
|
140 |
-
enable_mmaudio=False,
|
141 |
-
#model_base_dir=os.path.abspath(os.path.join(os.getcwd(), "varnish"))
|
142 |
-
model_base_dir="/repository/varnish",
|
143 |
-
)
|
144 |
-
|
145 |
-
async def process_frames(
|
146 |
-
self,
|
147 |
-
frames: torch.Tensor,
|
148 |
-
config: GenerationConfig
|
149 |
-
) -> tuple[str, dict]:
|
150 |
-
"""Post-process generated frames using Varnish
|
151 |
-
|
152 |
-
Args:
|
153 |
-
frames: Generated video frames tensor
|
154 |
-
config: Generation configuration
|
155 |
-
|
156 |
-
Returns:
|
157 |
-
Tuple of (video data URI, metadata dictionary)
|
158 |
-
"""
|
159 |
-
try:
|
160 |
-
logger.info(f"Original frames shape: {frames.shape}")
|
161 |
-
|
162 |
-
# Remove batch dimension if present
|
163 |
-
if len(frames.shape) == 5:
|
164 |
-
frames = frames.squeeze(0) # Remove batch dimension
|
165 |
-
|
166 |
-
logger.info(f"Processed frames shape: {frames.shape}")
|
167 |
-
|
168 |
-
# Process video with Varnish
|
169 |
-
result = await self.varnish(
|
170 |
-
input_data=frames,
|
171 |
-
input_fps=config.fps,
|
172 |
-
output_fps=config.fps,
|
173 |
-
upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
|
174 |
-
enable_interpolation=config.enable_interpolation
|
175 |
-
)
|
176 |
-
|
177 |
-
# Convert to data URI
|
178 |
-
video_uri = await result.write(
|
179 |
-
output_type="data-uri",
|
180 |
-
output_format="mp4",
|
181 |
-
output_codec="h264",
|
182 |
-
output_quality=23
|
183 |
-
)
|
184 |
-
|
185 |
-
# Collect metadata
|
186 |
-
metadata = {
|
187 |
-
"width": result.metadata.width,
|
188 |
-
"height": result.metadata.height,
|
189 |
-
"num_frames": result.metadata.frame_count,
|
190 |
-
"fps": result.metadata.fps,
|
191 |
-
"duration": result.metadata.duration,
|
192 |
-
"num_inference_steps": config.num_inference_steps,
|
193 |
-
"seed": config.seed,
|
194 |
-
"upscale_factor": config.upscale_factor,
|
195 |
-
"interpolation_enabled": config.enable_interpolation
|
196 |
-
}
|
197 |
-
|
198 |
-
return video_uri, metadata
|
199 |
-
|
200 |
-
except Exception as e:
|
201 |
-
logger.error(f"Error in process_frames: {str(e)}")
|
202 |
-
raise RuntimeError(f"Failed to process frames: {str(e)}")
|
203 |
-
|
204 |
-
from dataclasses import dataclass
|
205 |
-
from pathlib import Path
|
206 |
-
import pathlib
|
207 |
-
from typing import Dict, Any, Optional, Tuple
|
208 |
-
import asyncio
|
209 |
-
import base64
|
210 |
-
import io
|
211 |
-
import pprint
|
212 |
-
import logging
|
213 |
-
import random
|
214 |
-
import traceback
|
215 |
-
import os
|
216 |
-
import numpy as np
|
217 |
-
import torch
|
218 |
-
from diffusers import LTXPipeline, LTXImageToVideoPipeline
|
219 |
-
from PIL import Image
|
220 |
-
|
221 |
-
from varnish import Varnish
|
222 |
-
|
223 |
-
# Configure logging
|
224 |
-
logging.basicConfig(level=logging.INFO)
|
225 |
-
logger = logging.getLogger(__name__)
|
226 |
-
|
227 |
-
# Constraints
|
228 |
-
MAX_WIDTH = 1280
|
229 |
-
MAX_HEIGHT = 720
|
230 |
-
MAX_FRAMES = 257
|
231 |
-
|
232 |
-
# this is only a temporary solution (famous last words)
|
233 |
-
def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
|
234 |
-
"""
|
235 |
-
Recursively rename all '.wut' files to '.pth' in the given directory
|
236 |
-
|
237 |
-
Args:
|
238 |
-
directory (str): Path to the directory to process
|
239 |
-
"""
|
240 |
-
# Convert the directory path to absolute path
|
241 |
-
directory = os.path.abspath(directory)
|
242 |
-
|
243 |
-
# Walk through directory and its subdirectories
|
244 |
-
for root, _, files in os.walk(directory):
|
245 |
-
for filename in files:
|
246 |
-
if filename.endswith('.wut'):
|
247 |
-
# Get full path of the file
|
248 |
-
old_path = os.path.join(root, filename)
|
249 |
-
# Create new filename by replacing the extension
|
250 |
-
new_filename = filename.replace('.wut', '.pth')
|
251 |
-
new_path = os.path.join(root, new_filename)
|
252 |
-
|
253 |
-
try:
|
254 |
-
os.rename(old_path, new_path)
|
255 |
-
print(f"Renamed: {old_path} -> {new_path}")
|
256 |
-
except OSError as e:
|
257 |
-
print(f"Error renaming {old_path}: {e}")
|
258 |
-
|
259 |
-
def print_directory_structure(startpath):
|
260 |
-
"""Print the directory structure starting from the given path."""
|
261 |
-
for root, dirs, files in os.walk(startpath):
|
262 |
-
level = root.replace(startpath, '').count(os.sep)
|
263 |
-
indent = ' ' * 4 * level
|
264 |
-
logger.info(f"{indent}{os.path.basename(root)}/")
|
265 |
-
subindent = ' ' * 4 * (level + 1)
|
266 |
-
for f in files:
|
267 |
-
logger.info(f"{subindent}{f}")
|
268 |
-
|
269 |
-
logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
|
270 |
-
apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
|
271 |
|
272 |
-
|
273 |
-
|
|
|
274 |
|
275 |
-
|
276 |
-
class GenerationConfig:
|
277 |
-
"""Configuration for video generation"""
|
278 |
width: int = 768
|
279 |
height: int = 512
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
283 |
guidance_scale: float = 7.5
|
284 |
-
|
285 |
-
|
|
|
286 |
seed: int = -1 # -1 means random seed
|
287 |
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
|
|
|
|
292 |
|
293 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
294 |
"""Validate and adjust parameters to meet constraints"""
|
@@ -299,7 +107,7 @@ class GenerationConfig:
|
|
299 |
# Adjust number of frames to be in format 8k + 1
|
300 |
k = (self.num_frames - 1) // 8
|
301 |
num_frames = min((k * 8) + 1, MAX_FRAMES)
|
302 |
-
|
303 |
|
304 |
# Set random seed if not specified
|
305 |
if self.seed == -1:
|
@@ -339,9 +147,8 @@ class EndpointHandler:
|
|
339 |
device="cuda" if torch.cuda.is_available() else "cpu",
|
340 |
output_format="mp4",
|
341 |
output_codec="h264",
|
342 |
-
output_quality=
|
343 |
enable_mmaudio=False,
|
344 |
-
#model_base_dir=os.path.abspath(os.path.join(os.getcwd(), "varnish"))
|
345 |
model_base_dir="/repository/varnish",
|
346 |
)
|
347 |
|
@@ -367,22 +174,22 @@ class EndpointHandler:
|
|
367 |
frames = frames.squeeze(0) # Remove batch dimension
|
368 |
|
369 |
logger.info(f"Processed frames shape: {frames.shape}")
|
370 |
-
|
371 |
# Process video with Varnish
|
372 |
result = await self.varnish(
|
373 |
input_data=frames,
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
enable_interpolation=config.enable_interpolation
|
378 |
)
|
379 |
|
380 |
# Convert to data URI
|
381 |
video_uri = await result.write(
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
|
|
386 |
)
|
387 |
|
388 |
# Collect metadata
|
@@ -392,10 +199,7 @@ class EndpointHandler:
|
|
392 |
"num_frames": result.metadata.frame_count,
|
393 |
"fps": result.metadata.fps,
|
394 |
"duration": result.metadata.duration,
|
395 |
-
"num_inference_steps": config.num_inference_steps,
|
396 |
"seed": config.seed,
|
397 |
-
"upscale_factor": config.upscale_factor,
|
398 |
-
"interpolation_enabled": config.enable_interpolation
|
399 |
}
|
400 |
|
401 |
return video_uri, metadata
|
@@ -404,45 +208,72 @@ class EndpointHandler:
|
|
404 |
logger.error(f"Error in process_frames: {str(e)}")
|
405 |
raise RuntimeError(f"Failed to process frames: {str(e)}")
|
406 |
|
|
|
407 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
408 |
"""Process incoming requests for video generation
|
409 |
|
410 |
Args:
|
411 |
data: Request data containing:
|
412 |
-
- inputs (
|
413 |
-
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
|
|
|
|
|
|
423 |
Returns:
|
424 |
Dictionary containing:
|
425 |
- video: Base64 encoded MP4 data URI
|
426 |
- content-type: MIME type
|
427 |
- metadata: Generation metadata
|
428 |
"""
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
|
434 |
# Create and validate configuration
|
435 |
config = GenerationConfig(
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
).validate_and_adjust()
|
|
|
|
|
|
|
446 |
|
447 |
try:
|
448 |
with torch.no_grad():
|
@@ -451,28 +282,35 @@ class EndpointHandler:
|
|
451 |
np.random.seed(config.seed)
|
452 |
generator = torch.manual_seed(config.seed)
|
453 |
|
454 |
-
# Prepare generation parameters
|
455 |
generation_kwargs = {
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
"output_type": "pt",
|
463 |
"generator": generator
|
464 |
}
|
465 |
-
|
466 |
-
logger.info(f"Parameters:")
|
467 |
pprint.pprint(generation_kwargs)
|
468 |
|
469 |
# Check if image-to-video generation is requested
|
470 |
-
|
471 |
-
if image_data:
|
472 |
# Process base64 image
|
473 |
-
if
|
474 |
-
|
475 |
-
image_bytes = base64.b64decode(
|
476 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
477 |
generation_kwargs["image"] = image
|
478 |
frames = self.image_to_video(**generation_kwargs).frames
|
|
|
66 |
logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
|
67 |
apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
|
68 |
|
69 |
+
#logger.info("💡 Printing directory structure of ""/repository"":")
|
70 |
+
#print_directory_structure("/repository")
|
71 |
|
72 |
@dataclass
|
73 |
class GenerationConfig:
|
74 |
"""Configuration for video generation"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
+
# general content settings
|
77 |
+
prompt: str = ""
|
78 |
+
negative_prompt: str = "worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres",
|
79 |
|
80 |
+
# video model settings (will be used during generation of the initial raw video clip)
|
|
|
|
|
81 |
width: int = 768
|
82 |
height: int = 512
|
83 |
+
|
84 |
+
# users may tend to always set this to the max, to get as much usable content as possible (which is MAX_FRAMES, i.e. 257).
|
85 |
+
# The value must be a multiple of 8, plus 1 frame.
|
86 |
+
num_frames: int = 129
|
87 |
+
|
88 |
guidance_scale: float = 7.5
|
89 |
+
num_inference_steps: int = 50
|
90 |
+
|
91 |
+
# reproducible generation settings
|
92 |
seed: int = -1 # -1 means random seed
|
93 |
|
94 |
+
# varnish settings (will be used for post-processing after the raw video clip has been generated)
|
95 |
+
fps: int = 24 # FPS of the final video (only applied at the very end, when converting to mp4)
|
96 |
+
double_num_frames: bool = True # if True, the number of frames will be multiplied by 2 using RIFE
|
97 |
+
super_resolution: bool = True # if True, the resolution will be multiplied by 2 using Real_ESRGAN
|
98 |
+
|
99 |
+
grain_amount: float = 0.0
|
100 |
|
101 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
102 |
"""Validate and adjust parameters to meet constraints"""
|
|
|
107 |
# Adjust number of frames to be in format 8k + 1
|
108 |
k = (self.num_frames - 1) // 8
|
109 |
num_frames = min((k * 8) + 1, MAX_FRAMES)
|
110 |
+
|
111 |
|
112 |
# Set random seed if not specified
|
113 |
if self.seed == -1:
|
|
|
147 |
device="cuda" if torch.cuda.is_available() else "cpu",
|
148 |
output_format="mp4",
|
149 |
output_codec="h264",
|
150 |
+
output_quality=17,
|
151 |
enable_mmaudio=False,
|
|
|
152 |
model_base_dir="/repository/varnish",
|
153 |
)
|
154 |
|
|
|
174 |
frames = frames.squeeze(0) # Remove batch dimension
|
175 |
|
176 |
logger.info(f"Processed frames shape: {frames.shape}")
|
177 |
+
|
178 |
# Process video with Varnish
|
179 |
result = await self.varnish(
|
180 |
input_data=frames,
|
181 |
+
double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
|
182 |
+
super_resolution=config.grain_amount_config, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
|
183 |
+
grain_amount_config.grain_amount,
|
|
|
184 |
)
|
185 |
|
186 |
# Convert to data URI
|
187 |
video_uri = await result.write(
|
188 |
+
type="data-uri",
|
189 |
+
format="mp4",
|
190 |
+
codec="h264",
|
191 |
+
fps=config.fps,
|
192 |
+
quality=23
|
193 |
)
|
194 |
|
195 |
# Collect metadata
|
|
|
199 |
"num_frames": result.metadata.frame_count,
|
200 |
"fps": result.metadata.fps,
|
201 |
"duration": result.metadata.duration,
|
|
|
202 |
"seed": config.seed,
|
|
|
|
|
203 |
}
|
204 |
|
205 |
return video_uri, metadata
|
|
|
208 |
logger.error(f"Error in process_frames: {str(e)}")
|
209 |
raise RuntimeError(f"Failed to process frames: {str(e)}")
|
210 |
|
211 |
+
|
212 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
213 |
"""Process incoming requests for video generation
|
214 |
|
215 |
Args:
|
216 |
data: Request data containing:
|
217 |
+
- inputs (dict): Dictionary containing the input: a "prompt" (text field, must not be empty) and optionally an "image" (base64-encoded input image, which requests image-to-video generation)
|
218 |
+
- parameters (dict):
|
219 |
+
- prompt (required, string): list of concepts to keep in the video.
|
220 |
+
- negative_prompt (optional, string): list of concepts to ignore in the video.
|
221 |
+
- width (optional, int, default to 768): width, or horizontal size in pixels.
|
222 |
+
- height (optional, int, default to 512): height, or vertical size in pixels.
|
223 |
+
- num_frames (optional, int, default to 129): the number of frames must be a multiple of 8, plus 1 frame.
|
224 |
+
- guidance_scale (optional, float, default to 7.5): Guidance scale
|
225 |
+
- num_inference_steps (optional, int, default to 50): number of inference steps
|
226 |
+
- seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
|
227 |
+
- fps (optional, int, default to 24): FPS of the final video
|
228 |
+
- double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
|
229 |
+
- super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
|
230 |
+
- grain_amount (optional, float): amount of film grain to add to the output video
|
231 |
Returns:
|
232 |
Dictionary containing:
|
233 |
- video: Base64 encoded MP4 data URI
|
234 |
- content-type: MIME type
|
235 |
- metadata: Generation metadata
|
236 |
"""
|
237 |
+
inputs = data.get("inputs", dict())
|
238 |
+
|
239 |
+
input_prompt = inputs.get("prompt", "")
|
240 |
+
input_image = inputs.get("image")
|
241 |
+
|
242 |
+
params = data.get("parameters", dict())
|
243 |
+
|
244 |
+
if not input_prompt:
|
245 |
+
raise ValueError("The prompt should not be empty")
|
246 |
+
|
247 |
+
logger.info(f"Prompt: {input_prompt}")
|
248 |
+
|
249 |
+
logger.info(f"Raw parameters:")
|
250 |
+
pprint.pprint(params)
|
251 |
|
252 |
# Create and validate configuration
|
253 |
config = GenerationConfig(
|
254 |
+
# general content settings
|
255 |
+
prompt: input_prompt,
|
256 |
+
negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
|
257 |
+
|
258 |
+
# video model settings (will be used during generation of the initial raw video clip)
|
259 |
+
width=params.get("width", GenerationConfig.width),
|
260 |
+
height=params.get("height", GenerationConfig.height),
|
261 |
+
num_frames=params.get"num_frames", GenerationConfig.num_frames),
|
262 |
+
guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
|
263 |
+
num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
|
264 |
+
|
265 |
+
# reproducible generation settings
|
266 |
+
seed=params.get("seed", GenerationConfig.seed)
|
267 |
+
|
268 |
+
# varnish settings (will be used for post-processing after the raw video clip has been generated)
|
269 |
+
fps=params.get("fps", GenerationConfig.fps), # FPS of the final video (only applied at the very end, when converting to mp4)
|
270 |
+
double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
|
271 |
+
super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
|
272 |
+
grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
|
273 |
).validate_and_adjust()
|
274 |
+
|
275 |
+
logger.info(f"Global request settings:")
|
276 |
+
pprint.pprint(config)
|
277 |
|
278 |
try:
|
279 |
with torch.no_grad():
|
|
|
282 |
np.random.seed(config.seed)
|
283 |
generator = torch.manual_seed(config.seed)
|
284 |
|
285 |
+
# Prepare generation parameters for the video model (we omit params that are destined to Varnish)
|
286 |
generation_kwargs = {
|
287 |
+
# general content settings
|
288 |
+
prompt: config.prompt,
|
289 |
+
negative_prompt=config.negative_prompt,
|
290 |
+
|
291 |
+
# video model settings (will be used during generation of the initial raw video clip)
|
292 |
+
width=params.config.width,
|
293 |
+
height=config.height,
|
294 |
+
num_frames=config.num_frames,
|
295 |
+
guidance_scale=config.guidance_scale,
|
296 |
+
num_inference_steps=config.num_inference_steps,
|
297 |
+
|
298 |
+
# reproducible generation settings
|
299 |
+
seed=config.seed,
|
300 |
+
|
301 |
+
# constants
|
302 |
"output_type": "pt",
|
303 |
"generator": generator
|
304 |
}
|
305 |
+
logger.info(f"Video model generation settings:")
|
|
|
306 |
pprint.pprint(generation_kwargs)
|
307 |
|
308 |
# Check if image-to-video generation is requested
|
309 |
+
if input_image:
|
|
|
310 |
# Process base64 image
|
311 |
+
if input_image.startswith('data:'):
|
312 |
+
input_image = image_data.split(',', 1)[1]
|
313 |
+
image_bytes = base64.b64decode(input_image)
|
314 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
315 |
generation_kwargs["image"] = image
|
316 |
frames = self.image_to_video(**generation_kwargs).frames
|