jbilcke-hf HF staff commited on
Commit
1fd04e8
·
verified ·
1 Parent(s): 9d84818

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +103 -265
handler.py CHANGED
@@ -66,229 +66,37 @@ def print_directory_structure(startpath):
66
  logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
67
  apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
68
 
69
- logger.info("💡 Printing directory structure of ""/repository"":")
70
- print_directory_structure("/repository")
71
 
72
  @dataclass
73
  class GenerationConfig:
74
  """Configuration for video generation"""
75
- width: int = 768
76
- height: int = 512
77
- fps: int = 24
78
- duration_sec: float = 4.0
79
- num_inference_steps: int = 30
80
- guidance_scale: float = 7.5
81
- upscale_factor: float = 2.0
82
- enable_interpolation: bool = False
83
- seed: int = -1 # -1 means random seed
84
-
85
- @property
86
- def num_frames(self) -> int:
87
- """Calculate number of frames based on fps and duration"""
88
- return int(self.duration_sec * self.fps) + 1
89
-
90
- def validate_and_adjust(self) -> 'GenerationConfig':
91
- """Validate and adjust parameters to meet constraints"""
92
- # Round dimensions to nearest multiple of 32
93
- self.width = max(32, min(MAX_WIDTH, round(self.width / 32) * 32))
94
- self.height = max(32, min(MAX_HEIGHT, round(self.height / 32) * 32))
95
-
96
- # Adjust number of frames to be in format 8k + 1
97
- k = (self.num_frames - 1) // 8
98
- num_frames = min((k * 8) + 1, MAX_FRAMES)
99
- self.duration_sec = (num_frames - 1) / self.fps
100
-
101
- # Set random seed if not specified
102
- if self.seed == -1:
103
- self.seed = random.randint(0, 2**32 - 1)
104
-
105
- return self
106
-
107
- class EndpointHandler:
108
- """Handles video generation requests using LTX models and Varnish post-processing"""
109
-
110
- def __init__(self, model_path: str = ""):
111
- """Initialize the handler with LTX models and Varnish
112
-
113
- Args:
114
- model_path: Path to LTX model weights
115
- """
116
- # Enable TF32 for potential speedup on Ampere GPUs
117
- #torch.backends.cuda.matmul.allow_tf32 = True
118
-
119
- # Initialize models with bfloat16 precision
120
- self.text_to_video = LTXPipeline.from_pretrained(
121
- model_path,
122
- torch_dtype=torch.bfloat16
123
- ).to("cuda")
124
-
125
- self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
126
- model_path,
127
- torch_dtype=torch.bfloat16
128
- ).to("cuda")
129
-
130
- # Enable CPU offload for memory efficiency
131
- #self.text_to_video.enable_model_cpu_offload()
132
- #self.image_to_video.enable_model_cpu_offload()
133
-
134
- # Initialize Varnish for post-processing
135
- self.varnish = Varnish(
136
- device="cuda" if torch.cuda.is_available() else "cpu",
137
- output_format="mp4",
138
- output_codec="h264",
139
- output_quality=23,
140
- enable_mmaudio=False,
141
- #model_base_dir=os.path.abspath(os.path.join(os.getcwd(), "varnish"))
142
- model_base_dir="/repository/varnish",
143
- )
144
-
145
- async def process_frames(
146
- self,
147
- frames: torch.Tensor,
148
- config: GenerationConfig
149
- ) -> tuple[str, dict]:
150
- """Post-process generated frames using Varnish
151
-
152
- Args:
153
- frames: Generated video frames tensor
154
- config: Generation configuration
155
-
156
- Returns:
157
- Tuple of (video data URI, metadata dictionary)
158
- """
159
- try:
160
- logger.info(f"Original frames shape: {frames.shape}")
161
-
162
- # Remove batch dimension if present
163
- if len(frames.shape) == 5:
164
- frames = frames.squeeze(0) # Remove batch dimension
165
-
166
- logger.info(f"Processed frames shape: {frames.shape}")
167
-
168
- # Process video with Varnish
169
- result = await self.varnish(
170
- input_data=frames,
171
- input_fps=config.fps,
172
- output_fps=config.fps,
173
- upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
174
- enable_interpolation=config.enable_interpolation
175
- )
176
-
177
- # Convert to data URI
178
- video_uri = await result.write(
179
- output_type="data-uri",
180
- output_format="mp4",
181
- output_codec="h264",
182
- output_quality=23
183
- )
184
-
185
- # Collect metadata
186
- metadata = {
187
- "width": result.metadata.width,
188
- "height": result.metadata.height,
189
- "num_frames": result.metadata.frame_count,
190
- "fps": result.metadata.fps,
191
- "duration": result.metadata.duration,
192
- "num_inference_steps": config.num_inference_steps,
193
- "seed": config.seed,
194
- "upscale_factor": config.upscale_factor,
195
- "interpolation_enabled": config.enable_interpolation
196
- }
197
-
198
- return video_uri, metadata
199
-
200
- except Exception as e:
201
- logger.error(f"Error in process_frames: {str(e)}")
202
- raise RuntimeError(f"Failed to process frames: {str(e)}")
203
-
204
- from dataclasses import dataclass
205
- from pathlib import Path
206
- import pathlib
207
- from typing import Dict, Any, Optional, Tuple
208
- import asyncio
209
- import base64
210
- import io
211
- import pprint
212
- import logging
213
- import random
214
- import traceback
215
- import os
216
- import numpy as np
217
- import torch
218
- from diffusers import LTXPipeline, LTXImageToVideoPipeline
219
- from PIL import Image
220
-
221
- from varnish import Varnish
222
-
223
- # Configure logging
224
- logging.basicConfig(level=logging.INFO)
225
- logger = logging.getLogger(__name__)
226
-
227
- # Constraints
228
- MAX_WIDTH = 1280
229
- MAX_HEIGHT = 720
230
- MAX_FRAMES = 257
231
-
232
- # this is only a temporary solution (famous last words)
233
- def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
234
- """
235
- Recursively rename all '.wut' files to '.pth' in the given directory
236
-
237
- Args:
238
- directory (str): Path to the directory to process
239
- """
240
- # Convert the directory path to absolute path
241
- directory = os.path.abspath(directory)
242
-
243
- # Walk through directory and its subdirectories
244
- for root, _, files in os.walk(directory):
245
- for filename in files:
246
- if filename.endswith('.wut'):
247
- # Get full path of the file
248
- old_path = os.path.join(root, filename)
249
- # Create new filename by replacing the extension
250
- new_filename = filename.replace('.wut', '.pth')
251
- new_path = os.path.join(root, new_filename)
252
-
253
- try:
254
- os.rename(old_path, new_path)
255
- print(f"Renamed: {old_path} -> {new_path}")
256
- except OSError as e:
257
- print(f"Error renaming {old_path}: {e}")
258
-
259
- def print_directory_structure(startpath):
260
- """Print the directory structure starting from the given path."""
261
- for root, dirs, files in os.walk(startpath):
262
- level = root.replace(startpath, '').count(os.sep)
263
- indent = ' ' * 4 * level
264
- logger.info(f"{indent}{os.path.basename(root)}/")
265
- subindent = ' ' * 4 * (level + 1)
266
- for f in files:
267
- logger.info(f"{subindent}{f}")
268
-
269
- logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
270
- apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
271
 
272
- logger.info("💡 Printing directory structure of ""/repository"":")
273
- print_directory_structure("/repository")
 
274
 
275
- @dataclass
276
- class GenerationConfig:
277
- """Configuration for video generation"""
278
  width: int = 768
279
  height: int = 512
280
- fps: int = 24
281
- duration_sec: float = 4.0
282
- num_inference_steps: int = 30
 
 
283
  guidance_scale: float = 7.5
284
- upscale_factor: float = 2.0
285
- enable_interpolation: bool = False
 
286
  seed: int = -1 # -1 means random seed
287
 
288
- @property
289
- def num_frames(self) -> int:
290
- """Calculate number of frames based on fps and duration"""
291
- return int(self.duration_sec * self.fps) + 1
 
 
292
 
293
  def validate_and_adjust(self) -> 'GenerationConfig':
294
  """Validate and adjust parameters to meet constraints"""
@@ -299,7 +107,7 @@ class GenerationConfig:
299
  # Adjust number of frames to be in format 8k + 1
300
  k = (self.num_frames - 1) // 8
301
  num_frames = min((k * 8) + 1, MAX_FRAMES)
302
- self.duration_sec = (num_frames - 1) / self.fps
303
 
304
  # Set random seed if not specified
305
  if self.seed == -1:
@@ -339,9 +147,8 @@ class EndpointHandler:
339
  device="cuda" if torch.cuda.is_available() else "cpu",
340
  output_format="mp4",
341
  output_codec="h264",
342
- output_quality=23,
343
  enable_mmaudio=False,
344
- #model_base_dir=os.path.abspath(os.path.join(os.getcwd(), "varnish"))
345
  model_base_dir="/repository/varnish",
346
  )
347
 
@@ -367,22 +174,22 @@ class EndpointHandler:
367
  frames = frames.squeeze(0) # Remove batch dimension
368
 
369
  logger.info(f"Processed frames shape: {frames.shape}")
370
-
371
  # Process video with Varnish
372
  result = await self.varnish(
373
  input_data=frames,
374
- input_fps=config.fps,
375
- output_fps=config.fps,
376
- upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
377
- enable_interpolation=config.enable_interpolation
378
  )
379
 
380
  # Convert to data URI
381
  video_uri = await result.write(
382
- output_type="data-uri",
383
- output_format="mp4",
384
- output_codec="h264",
385
- output_quality=23
 
386
  )
387
 
388
  # Collect metadata
@@ -392,10 +199,7 @@ class EndpointHandler:
392
  "num_frames": result.metadata.frame_count,
393
  "fps": result.metadata.fps,
394
  "duration": result.metadata.duration,
395
- "num_inference_steps": config.num_inference_steps,
396
  "seed": config.seed,
397
- "upscale_factor": config.upscale_factor,
398
- "interpolation_enabled": config.enable_interpolation
399
  }
400
 
401
  return video_uri, metadata
@@ -404,45 +208,72 @@ class EndpointHandler:
404
  logger.error(f"Error in process_frames: {str(e)}")
405
  raise RuntimeError(f"Failed to process frames: {str(e)}")
406
 
 
407
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
408
  """Process incoming requests for video generation
409
 
410
  Args:
411
  data: Request data containing:
412
- - inputs (str): Text prompt or image
413
- - width (optional): Video width
414
- - height (optional): Video height
415
- - fps (optional): Frames per second
416
- - duration_sec (optional): Video duration
417
- - num_inference_steps (optional): Inference steps
418
- - guidance_scale (optional): Guidance scale
419
- - upscale_factor (optional): Upscaling factor
420
- - enable_interpolation (optional): Enable frame interpolation
421
- - seed (optional): Random seed
422
-
 
 
 
423
  Returns:
424
  Dictionary containing:
425
  - video: Base64 encoded MP4 data URI
426
  - content-type: MIME type
427
  - metadata: Generation metadata
428
  """
429
- # Extract prompt
430
- prompt = data.get("inputs")
431
- if not prompt:
432
- raise ValueError("No prompt provided in the 'inputs' field")
 
 
 
 
 
 
 
 
 
 
433
 
434
  # Create and validate configuration
435
  config = GenerationConfig(
436
- width=data.get("width", GenerationConfig.width),
437
- height=data.get("height", GenerationConfig.height),
438
- fps=data.get("fps", GenerationConfig.fps),
439
- duration_sec=data.get("duration_sec", GenerationConfig.duration_sec),
440
- num_inference_steps=data.get("num_inference_steps", GenerationConfig.num_inference_steps),
441
- guidance_scale=data.get("guidance_scale", GenerationConfig.guidance_scale),
442
- upscale_factor=data.get("upscale_factor", GenerationConfig.upscale_factor),
443
- enable_interpolation=data.get("enable_interpolation", GenerationConfig.enable_interpolation),
444
- seed=data.get("seed", GenerationConfig.seed)
 
 
 
 
 
 
 
 
 
 
445
  ).validate_and_adjust()
 
 
 
446
 
447
  try:
448
  with torch.no_grad():
@@ -451,28 +282,35 @@ class EndpointHandler:
451
  np.random.seed(config.seed)
452
  generator = torch.manual_seed(config.seed)
453
 
454
- # Prepare generation parameters
455
  generation_kwargs = {
456
- "prompt": prompt,
457
- "height": config.height,
458
- "width": config.width,
459
- "num_frames": config.num_frames,
460
- "guidance_scale": config.guidance_scale,
461
- "num_inference_steps": config.num_inference_steps,
 
 
 
 
 
 
 
 
 
462
  "output_type": "pt",
463
  "generator": generator
464
  }
465
-
466
- logger.info(f"Parameters:")
467
  pprint.pprint(generation_kwargs)
468
 
469
  # Check if image-to-video generation is requested
470
- image_data = data.get("image")
471
- if image_data:
472
  # Process base64 image
473
- if image_data.startswith('data:'):
474
- image_data = image_data.split(',', 1)[1]
475
- image_bytes = base64.b64decode(image_data)
476
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
477
  generation_kwargs["image"] = image
478
  frames = self.image_to_video(**generation_kwargs).frames
 
66
  logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
67
  apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
68
 
69
+ #logger.info("💡 Printing directory structure of ""/repository"":")
70
+ #print_directory_structure("/repository")
71
 
72
  @dataclass
73
  class GenerationConfig:
74
  """Configuration for video generation"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ # general content settings
77
+ prompt: str = ""
78
+ negative_prompt: str = "worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres",
79
 
80
+ # video model settings (will be used during generation of the initial raw video clip)
 
 
81
  width: int = 768
82
  height: int = 512
83
+
84
+ # users may tend to always set this to the max, to get as much useable content as possible (which is MAX_FRAMES ie. 257).
85
+ # The value must be a multiple of 8, plus 1 frame.
86
+ num_frames: int = 129
87
+
88
  guidance_scale: float = 7.5
89
+ num_inference_steps: int = 50
90
+
91
+ # reproducible generation settings
92
  seed: int = -1 # -1 means random seed
93
 
94
+ # varnish settings (will be used for post-processing after the raw video clip has been generated
95
+ fps: int = 24 # FPS of the final video (only applied at the the very end, when converting to mp4)
96
+ double_num_frames: bool = True # if True, the number of frames will be multiplied by 2 using RIFE
97
+ super_resolution: bool = True # if True, the resolution will be multiplied by 2 using Real_ESRGAN
98
+
99
+ grain_amount: float = 0.0
100
 
101
  def validate_and_adjust(self) -> 'GenerationConfig':
102
  """Validate and adjust parameters to meet constraints"""
 
107
  # Adjust number of frames to be in format 8k + 1
108
  k = (self.num_frames - 1) // 8
109
  num_frames = min((k * 8) + 1, MAX_FRAMES)
110
+
111
 
112
  # Set random seed if not specified
113
  if self.seed == -1:
 
147
  device="cuda" if torch.cuda.is_available() else "cpu",
148
  output_format="mp4",
149
  output_codec="h264",
150
+ output_quality=17,
151
  enable_mmaudio=False,
 
152
  model_base_dir="/repository/varnish",
153
  )
154
 
 
174
  frames = frames.squeeze(0) # Remove batch dimension
175
 
176
  logger.info(f"Processed frames shape: {frames.shape}")
177
+
178
  # Process video with Varnish
179
  result = await self.varnish(
180
  input_data=frames,
181
+ double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
182
+ super_resolution=config.grain_amount_config, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
183
+ grain_amount_config.grain_amount,
 
184
  )
185
 
186
  # Convert to data URI
187
  video_uri = await result.write(
188
+ type="data-uri",
189
+ format="mp4",
190
+ codec="h264",
191
+ fps=config.fps,
192
+ quality=23
193
  )
194
 
195
  # Collect metadata
 
199
  "num_frames": result.metadata.frame_count,
200
  "fps": result.metadata.fps,
201
  "duration": result.metadata.duration,
 
202
  "seed": config.seed,
 
 
203
  }
204
 
205
  return video_uri, metadata
 
208
  logger.error(f"Error in process_frames: {str(e)}")
209
  raise RuntimeError(f"Failed to process frames: {str(e)}")
210
 
211
+
212
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
213
  """Process incoming requests for video generation
214
 
215
  Args:
216
  data: Request data containing:
217
+ - inputs (dict): Dictionary containing input, which can be either "prompt" (text field) or "image" (input image)
218
+ - parameters (dict):
219
+ - prompt (required, string): list of concepts to keep in the video.
220
+ - negative_prompt (optional, string): list of concepts to ignore in the video.
221
+ - width (optional, int, default to 768): width, or horizontal size in pixels.
222
+ - height (optional, int, default to 512): height, or vertical size in pixels.
223
+ - num_frames (optional, int, default to 129): the numer of frames must be a multiple of 8, plus 1 frame.
224
+ - guidance_scale (optional, float, default to 7.5): Guidance scale
225
+ - num_inference_steps (optional, int, default to 50): number of inference steps
226
+ - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
227
+ - fps (optional, int, default to 24): FPS of the final video
228
+ - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
229
+ - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
230
+ - grain_amount (optional, float): amount of film grain to add to the output video
231
  Returns:
232
  Dictionary containing:
233
  - video: Base64 encoded MP4 data URI
234
  - content-type: MIME type
235
  - metadata: Generation metadata
236
  """
237
+ inputs = data.get("inputs", dict())
238
+
239
+ input_prompt = inputs.get("prompt", "")
240
+ input_image = inputs.get("image")
241
+
242
+ params = data.get("parameters", dict())
243
+
244
+ if not input_prompt:
245
+ raise ValueError("The prompt should not be empty")
246
+
247
+ logger.info(f"Prompt: {input_prompt}")
248
+
249
+ logger.info(f"Raw parameters:")
250
+ pprint.pprint(params)
251
 
252
  # Create and validate configuration
253
  config = GenerationConfig(
254
+ # general content settings
255
+ prompt: input_prompt,
256
+ negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
257
+
258
+ # video model settings (will be used during generation of the initial raw video clip)
259
+ width=params.get("width", GenerationConfig.width),
260
+ height=params.get("height", GenerationConfig.height),
261
+ num_frames=params.get"num_frames", GenerationConfig.num_frames),
262
+ guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
263
+ num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
264
+
265
+ # reproducible generation settings
266
+ seed=params.get("seed", GenerationConfig.seed)
267
+
268
+ # varnish settings (will be used for post-processing after the raw video clip has been generated)
269
+ fps=params.get("fps", GenerationConfig.fps), # FPS of the final video (only applied at the the very end, when converting to mp4)
270
+ double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
271
+ super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
272
+ grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
273
  ).validate_and_adjust()
274
+
275
+ logger.info(f"Global request settings:")
276
+ pprint.pprint(config)
277
 
278
  try:
279
  with torch.no_grad():
 
282
  np.random.seed(config.seed)
283
  generator = torch.manual_seed(config.seed)
284
 
285
+ # Prepare generation parameters for the video model (we omit params that are destined to Varnish)
286
  generation_kwargs = {
287
+ # general content settings
288
+ prompt: config.prompt,
289
+ negative_prompt=config.negative_prompt,
290
+
291
+ # video model settings (will be used during generation of the initial raw video clip)
292
+ width=params.config.width,
293
+ height=config.height,
294
+ num_frames=config.num_frames,
295
+ guidance_scale=config.guidance_scale,
296
+ num_inference_steps=config.num_inference_steps,
297
+
298
+ # reproducible generation settings
299
+ seed=config.seed,
300
+
301
+ # constants
302
  "output_type": "pt",
303
  "generator": generator
304
  }
305
+ logger.info(f"Video model generation settings:")
 
306
  pprint.pprint(generation_kwargs)
307
 
308
  # Check if image-to-video generation is requested
309
+ if input_image:
 
310
  # Process base64 image
311
+ if input_image.startswith('data:'):
312
+ input_image = image_data.split(',', 1)[1]
313
+ image_bytes = base64.b64decode(input_image)
314
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
315
  generation_kwargs["image"] = image
316
  frames = self.image_to_video(**generation_kwargs).frames