innoai committed on
Commit
0b63431
·
verified ·
1 Parent(s): fb854a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -60
app.py CHANGED
@@ -171,59 +171,37 @@ ASPECT_RATIOS = {
171
 
172
  def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
173
  """
174
- 根据不同的长宽比,生成符合 VAE 解码器缓存格式的零张量缓存。
175
- 缓存张量格式必须与 ZERO_VAE_CACHE 保持一致: [batch, time, channels, height, width]
176
  """
177
- ar_config = ASPECT_RATIOS[aspect_ratio]
178
- latent_w = ar_config["latent_w"]
179
- latent_h = ar_config["latent_h"]
180
-
181
- # 这里 time 维度初始化为 1,channels 对应各级别的通道数
182
- cache = []
183
-
184
- # 第一级特征,channels=512,下采样 8 倍
185
- cache.append(torch.zeros(
186
- 1, # batch size
187
- 1, # time frames
188
- 512, # channels
189
- latent_h // 8, # height
190
- latent_w // 8, # width
191
- device=device,
192
- dtype=dtype
193
- ))
194
- # 第二级特征,channels=512,下采样 4
195
- cache.append(torch.zeros(
196
- 1,
197
- 1,
198
- 512,
199
- latent_h // 4,
200
- latent_w // 4,
201
- device=device,
202
- dtype=dtype
203
- ))
204
- # 第三级特征,channels=256,下采样 2
205
- cache.append(torch.zeros(
206
- 1,
207
- 1,
208
- 256,
209
- latent_h // 2,
210
- latent_w // 2,
211
- device=device,
212
- dtype=dtype
213
- ))
214
- # 第四级特征,channels=128,不下采样
215
- cache.append(torch.zeros(
216
- 1,
217
- 1,
218
- 128,
219
- latent_h,
220
- latent_w,
221
- device=device,
222
- dtype=dtype
223
- ))
224
-
225
- return cache
226
-
227
 
228
  def frames_to_ts_file(frames, filepath, fps = 15):
229
  """
@@ -416,14 +394,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, aspect_ratio="16
416
 
417
  vae_cache, latents_cache = None, None
418
  if not APP_STATE["current_use_taehv"] and not args.trt:
419
- # For non-TRT and non-TAEHV, we need to handle aspect ratio properly
420
- # Use the original ZERO_VAE_CACHE as a template but adjust dimensions
421
- if aspect_ratio == "16:9":
422
- # Use default cache for 16:9
423
- vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
424
- else:
425
- # Create custom cache for 9:16
426
- vae_cache = get_vae_cache_for_aspect_ratio(aspect_ratio, gpu, torch.float16)
427
 
428
  num_blocks = 7
429
  current_start_frame = 0
 
171
 
172
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
    """
    Build a zero-initialized VAE decoder cache sized for the given aspect ratio.

    Mirrors the structure of the module-level ZERO_VAE_CACHE template. For the
    default 16:9 ratio the template tensors are reused directly; for 9:16 the
    spatial (last two) dimensions of each 5-D template tensor are swapped.

    Args:
        aspect_ratio: Either "16:9" (template's native ratio) or "9:16".
        device: Target device for the cache tensors.
        dtype: Target dtype for the cache tensors.

    Returns:
        list[torch.Tensor]: One tensor per entry in ZERO_VAE_CACHE, placed on
        `device` with `dtype`.
    """
    if aspect_ratio != "9:16":
        # 16:9 is the layout ZERO_VAE_CACHE was built for — use it as-is.
        return [c.to(device=device, dtype=dtype) for c in ZERO_VAE_CACHE]

    # 9:16: same structure as the template, but with the spatial axes
    # transposed. NOTE(review): the cache layout has been described elsewhere
    # in this file as [batch, time, channels, height, width]; under either
    # that layout or (B, C, T, H, W), H and W are the last two axes, so
    # swapping shape[-2] and shape[-1] is correct in both cases.
    cache = []
    for template in ZERO_VAE_CACHE:
        shape = list(template.shape)
        if len(shape) == 5:
            shape[-2], shape[-1] = shape[-1], shape[-2]
            cache.append(torch.zeros(shape, device=device, dtype=dtype))
        else:
            # Non-5-D entries carry no spatial axes to swap; copy unchanged.
            cache.append(template.to(device=device, dtype=dtype))
    return cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  def frames_to_ts_file(frames, filepath, fps = 15):
207
  """
 
394
 
395
  vae_cache, latents_cache = None, None
396
  if not APP_STATE["current_use_taehv"] and not args.trt:
397
+ # Create VAE cache appropriate for the aspect ratio
398
+ vae_cache = get_vae_cache_for_aspect_ratio(aspect_ratio, gpu, torch.float16)
 
 
 
 
 
 
399
 
400
  num_blocks = 7
401
  current_start_frame = 0