tori29umai committed
Commit 4d55b31 · verified · 1 Parent(s): 6821262

Update demo_gradio.py

Files changed (1)
  1. demo_gradio.py +340 -99
demo_gradio.py CHANGED
@@ -1,8 +1,21 @@
 
 
1
  from diffusers_helper.hf_login import login
2
 
3
  import os
4
 
5
- os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
6
 
7
  import gradio as gr
8
  import torch
@@ -17,22 +30,42 @@ from PIL import Image
17
  from diffusers import AutoencoderKLHunyuanVideo
18
  from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
19
  from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
20
- from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
21
  from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
22
  from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
23
- from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
24
  from diffusers_helper.thread_utils import AsyncStream, async_run
25
  from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
26
  from transformers import SiglipImageProcessor, SiglipVisionModel
27
  from diffusers_helper.clip_vision import hf_clip_vision_encode
28
  from diffusers_helper.bucket_tools import find_nearest_bucket
29
 
30
 
31
  parser = argparse.ArgumentParser()
32
- parser.add_argument('--share', action='store_true')
33
- parser.add_argument("--server", type=str, default='0.0.0.0')
34
  parser.add_argument("--port", type=int, required=False)
35
- parser.add_argument("--inbrowser", action='store_true')
36
  args = parser.parse_args()
37
 
38
  # for win desktop probably use --server 127.0.0.1 --inbrowser
@@ -43,34 +76,56 @@ print(args)
43
  free_mem_gb = get_cuda_free_memory_gb(gpu)
44
  high_vram = free_mem_gb > 60
45
 
46
- print(f'Free VRAM {free_mem_gb} GB')
47
- print(f'High-VRAM Mode: {high_vram}')
48
-
49
- text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
50
- text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
51
- tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
52
- tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
53
- vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
54
-
55
- feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
56
- image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
57
-
58
- transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu()
59
 
60
  vae.eval()
61
  text_encoder.eval()
62
  text_encoder_2.eval()
63
  image_encoder.eval()
64
- transformer.eval()
65
 
66
  if not high_vram:
67
  vae.enable_slicing()
68
  vae.enable_tiling()
69
 
70
- transformer.high_quality_fp32_output_for_inference = True
71
- print('transformer.high_quality_fp32_output_for_inference = True')
72
-
73
- transformer.to(dtype=torch.bfloat16)
74
  vae.to(dtype=torch.float16)
75
  image_encoder.to(dtype=torch.float16)
76
  text_encoder.to(dtype=torch.float16)
@@ -80,47 +135,70 @@ vae.requires_grad_(False)
80
  text_encoder.requires_grad_(False)
81
  text_encoder_2.requires_grad_(False)
82
  image_encoder.requires_grad_(False)
83
- transformer.requires_grad_(False)
84
 
85
  if not high_vram:
86
  # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
87
- DynamicSwapInstaller.install_model(transformer, device=gpu)
88
  DynamicSwapInstaller.install_model(text_encoder, device=gpu)
89
  else:
90
  text_encoder.to(gpu)
91
  text_encoder_2.to(gpu)
92
  image_encoder.to(gpu)
93
  vae.to(gpu)
94
- transformer.to(gpu)
95
 
96
  stream = AsyncStream()
97
 
98
- outputs_folder = './outputs/'
99
  os.makedirs(outputs_folder, exist_ok=True)
100
 
101
 
102
  @torch.no_grad()
103
- def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache):
104
  total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
105
  total_latent_sections = int(max(round(total_latent_sections), 1))
106
 
107
  job_id = generate_timestamp()
108
 
109
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
110
 
111
  try:
112
  # Clean GPU
113
  if not high_vram:
114
- unload_complete_models(
115
- text_encoder, text_encoder_2, image_encoder, vae, transformer
116
- )
117
 
118
  # Text encoding
119
 
120
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
121
 
122
  if not high_vram:
123
- fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
 
124
  load_model_as_complete(text_encoder_2, target_device=gpu)
125
 
126
  llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
@@ -135,20 +213,20 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
135
 
136
  # Processing input image
137
 
138
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
139
 
140
  H, W, C = input_image.shape
141
  height, width = find_nearest_bucket(H, W, resolution=640)
142
  input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
143
 
144
- Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
145
 
146
  input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
147
  input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
148
 
149
  # VAE encoding
150
 
151
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
152
 
153
  if not high_vram:
154
  load_model_as_complete(vae, target_device=gpu)
@@ -157,7 +235,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
157
 
158
  # CLIP Vision
159
 
160
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
161
 
162
  if not high_vram:
163
  load_model_as_complete(image_encoder, target_device=gpu)
@@ -167,15 +245,62 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
167
 
168
  # Dtype
169
 
170
- llama_vec = llama_vec.to(transformer.dtype)
171
- llama_vec_n = llama_vec_n.to(transformer.dtype)
172
- clip_l_pooler = clip_l_pooler.to(transformer.dtype)
173
- clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
174
- image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
175
 
176
  # Sampling
177
 
178
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
179
 
180
  rnd = torch.Generator("cpu").manual_seed(seed)
181
  num_frames = latent_window_size * 4 - 3
@@ -197,23 +322,34 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
197
  is_last_section = latent_padding == 0
198
  latent_padding_size = latent_padding * latent_window_size
199
 
200
- if stream.input_queue.top() == 'end':
201
- stream.output_queue.push(('end', None))
202
  return
203
 
204
- print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}')
205
 
206
  indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
207
- clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
208
  clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
209
 
210
  clean_latents_pre = start_latent.to(history_latents)
211
- clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
212
  clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
213
 
214
  if not high_vram:
215
  unload_complete_models()
216
- move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
217
 
218
  if use_teacache:
219
  transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
@@ -221,26 +357,26 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
221
  transformer.initialize_teacache(enable_teacache=False)
222
 
223
  def callback(d):
224
- preview = d['denoised']
225
  preview = vae_decode_fake(preview)
226
 
227
  preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
228
- preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
229
 
230
- if stream.input_queue.top() == 'end':
231
- stream.output_queue.push(('end', None))
232
- raise KeyboardInterrupt('User ends the task.')
233
 
234
- current_step = d['i'] + 1
235
  percentage = int(100.0 * current_step / steps)
236
- hint = f'Sampling {current_step}/{steps}'
237
- desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...'
238
- stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
239
  return
240
 
241
  generated_latents = sample_hunyuan(
242
  transformer=transformer,
243
- sampler='unipc',
244
  width=width,
245
  height=height,
246
  frames=num_frames,
@@ -293,13 +429,13 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
293
  if not high_vram:
294
  unload_complete_models()
295
 
296
- output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
297
 
298
- save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
299
 
300
- print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
301
 
302
- stream.output_queue.push(('file', output_filename))
303
 
304
  if is_last_section:
305
  break
@@ -307,49 +443,86 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
307
  traceback.print_exc()
308
 
309
  if not high_vram:
310
- unload_complete_models(
311
- text_encoder, text_encoder_2, image_encoder, vae, transformer
312
- )
313
 
314
- stream.output_queue.push(('end', None))
315
  return
316
 
317
 
318
- def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache):
319
  global stream
320
- assert input_image is not None, 'No input image!'
321
 
322
- yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
323
 
324
  stream = AsyncStream()
325
 
326
- async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache)
327
 
328
  output_filename = None
329
 
330
  while True:
331
  flag, data = stream.output_queue.next()
332
 
333
- if flag == 'file':
334
  output_filename = data
335
  yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
336
 
337
- if flag == 'progress':
338
  preview, desc, html = data
339
- yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
340
 
341
- if flag == 'end':
342
- yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
343
  break
344
 
345
 
346
  def end_process():
347
- stream.input_queue.push('end')
348
 
349
 
350
  quick_prompts = [
351
- 'The girl dances gracefully, with clear movements, full of charm.',
352
- 'A character doing some simple body movements.',
353
  ]
354
  quick_prompts = [[x] for x in quick_prompts]
355
 
@@ -357,42 +530,110 @@ quick_prompts = [[x] for x in quick_prompts]
357
  css = make_progress_bar_css()
358
  block = gr.Blocks(css=css).queue()
359
  with block:
360
- gr.Markdown('# FramePack')
361
  with gr.Row():
362
  with gr.Column():
363
- input_image = gr.Image(sources='upload', type="numpy", label="Image", height=320)
364
- prompt = gr.Textbox(label="Prompt", value='')
365
- example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Quick List', samples_per_page=1000, components=[prompt])
366
- example_quick_prompts.click(lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=False)
367
 
368
  with gr.Row():
369
  start_button = gr.Button(value="Start Generation")
370
  end_button = gr.Button(value="End Generation", interactive=False)
371
 
372
  with gr.Group():
373
- use_teacache = gr.Checkbox(label='Use TeaCache', value=True, info='Faster speed, but often makes hands and fingers slightly worse.')
374
 
375
  n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=False) # Not used
376
  seed = gr.Number(label="Seed", value=31337, precision=0)
377
 
378
  total_second_length = gr.Slider(label="Total Video Length (Seconds)", minimum=1, maximum=120, value=5, step=0.1)
379
- latent_window_size = gr.Slider(label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, visible=False) # Should not change
380
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, info='Changing this value is not recommended.')
381
-
382
- cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, visible=False) # Should not change
383
- gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01, info='Changing this value is not recommended.')
384
- rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False) # Should not change
385
 
386
- gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.")
387
 
388
  with gr.Column():
389
  preview_image = gr.Image(label="Next Latents", height=200, visible=False)
390
  result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True)
391
- gr.Markdown('Note that the ending actions will be generated before the starting actions due to the inverted sampling. If the starting action is not in the video, you just need to wait, and it will be generated later.')
392
- progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
393
- progress_bar = gr.HTML('', elem_classes='no-generating-animation')
394
- ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache]
395
- start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
396
  end_button.click(fn=end_process)
397
 
398
 
 
1
+ import gc
2
+ import time
3
  from diffusers_helper.hf_login import login
4
 
5
  import os
6
 
7
+ # os.environ["HF_HOME"] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), "./hf_download")))
8
+
9
+ # we use HF_HOME in following order:
10
+ # 1. "../FramePack/hf_download" if exists.
11
+ # 2. "./hf_download"
12
+ hf_home_path_1 = os.path.abspath(
13
+ os.path.realpath(os.path.join(os.path.dirname(os.path.dirname(__file__)), "FramePack", "hf_download"))
14
+ )
15
+ hf_home_path_2 = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), "hf_download")))
16
+ hf_home = hf_home_path_1 if os.path.exists(hf_home_path_1) else hf_home_path_2
17
+ os.environ["HF_HOME"] = hf_home
18
+ print(f"Set HF_HOME env to {hf_home}")
19
 
20
  import gradio as gr
21
  import torch
 
30
  from diffusers import AutoencoderKLHunyuanVideo
31
  from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
32
  from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
33
+ from diffusers_helper.utils import (
34
+ save_bcthw_as_mp4,
35
+ crop_or_pad_yield_mask,
36
+ soft_append_bcthw,
37
+ resize_and_center_crop,
38
+ state_dict_weighted_merge,
39
+ state_dict_offset_merge,
40
+ generate_timestamp,
41
+ )
42
  from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
43
  from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
44
+ from diffusers_helper.memory import (
45
+ cpu,
46
+ gpu,
47
+ get_cuda_free_memory_gb,
48
+ move_model_to_device_with_memory_preservation,
49
+ offload_model_from_device_for_memory_preservation,
50
+ fake_diffusers_current_device,
51
+ DynamicSwapInstaller,
52
+ unload_complete_models,
53
+ load_model_as_complete,
54
+ )
55
  from diffusers_helper.thread_utils import AsyncStream, async_run
56
  from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
57
  from transformers import SiglipImageProcessor, SiglipVisionModel
58
  from diffusers_helper.clip_vision import hf_clip_vision_encode
59
  from diffusers_helper.bucket_tools import find_nearest_bucket
60
+ from utils.lora_utils import merge_lora_to_state_dict
61
+ from utils.fp8_optimization_utils import optimize_state_dict_with_fp8, apply_fp8_monkey_patch
62
 
63
 
64
  parser = argparse.ArgumentParser()
65
+ parser.add_argument("--share", action="store_true")
66
+ parser.add_argument("--server", type=str, default="0.0.0.0")
67
  parser.add_argument("--port", type=int, required=False)
68
+ parser.add_argument("--inbrowser", action="store_true")
69
  args = parser.parse_args()
70
 
71
  # for win desktop probably use --server 127.0.0.1 --inbrowser
 
76
  free_mem_gb = get_cuda_free_memory_gb(gpu)
77
  high_vram = free_mem_gb > 60
78
 
79
+ print(f"Free VRAM {free_mem_gb} GB")
80
+ print(f"High-VRAM Mode: {high_vram}")
81
+
82
+ text_encoder = LlamaModel.from_pretrained(
83
+ "hunyuanvideo-community/HunyuanVideo", subfolder="text_encoder", torch_dtype=torch.float16
84
+ ).cpu()
85
+ text_encoder_2 = CLIPTextModel.from_pretrained(
86
+ "hunyuanvideo-community/HunyuanVideo", subfolder="text_encoder_2", torch_dtype=torch.float16
87
+ ).cpu()
88
+ tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")
89
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")
90
+ vae = AutoencoderKLHunyuanVideo.from_pretrained(
91
+ "hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16
92
+ ).cpu()
93
+
94
+ feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder="feature_extractor")
95
+ image_encoder = SiglipVisionModel.from_pretrained(
96
+ "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
97
+ ).cpu()
98
+
99
+
100
+ def load_transfomer():
101
+ print("Loading transformer ...")
102
+ transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
103
+ "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
104
+ ).cpu()
105
+ transformer.eval()
106
+ transformer.high_quality_fp32_output_for_inference = True
107
+ print("transformer.high_quality_fp32_output_for_inference = True")
108
+
109
+ transformer.to(dtype=torch.bfloat16)
110
+ transformer.requires_grad_(False)
111
+ return transformer
112
+
113
+
114
+ transformer = None # load later
115
+ transformer_dtype = torch.bfloat16
116
+ previous_lora_file = None
117
+ previous_lora_multiplier = None
118
+ previous_fp8_optimization = None
119
 
120
  vae.eval()
121
  text_encoder.eval()
122
  text_encoder_2.eval()
123
  image_encoder.eval()
 
124
 
125
  if not high_vram:
126
  vae.enable_slicing()
127
  vae.enable_tiling()
128
 
129
  vae.to(dtype=torch.float16)
130
  image_encoder.to(dtype=torch.float16)
131
  text_encoder.to(dtype=torch.float16)
 
135
  text_encoder.requires_grad_(False)
136
  text_encoder_2.requires_grad_(False)
137
  image_encoder.requires_grad_(False)
 
138
 
139
  if not high_vram:
140
  # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
141
+ # DynamicSwapInstaller.install_model(transformer, device=gpu)
142
  DynamicSwapInstaller.install_model(text_encoder, device=gpu)
143
  else:
144
  text_encoder.to(gpu)
145
  text_encoder_2.to(gpu)
146
  image_encoder.to(gpu)
147
  vae.to(gpu)
148
+ # transformer.to(gpu)
149
 
150
  stream = AsyncStream()
151
 
152
+ outputs_folder = "./outputs/"
153
  os.makedirs(outputs_folder, exist_ok=True)
154
 
155
 
156
  @torch.no_grad()
157
+ def worker(
158
+ input_image,
159
+ prompt,
160
+ n_prompt,
161
+ seed,
162
+ total_second_length,
163
+ latent_window_size,
164
+ steps,
165
+ cfg,
166
+ gs,
167
+ rs,
168
+ gpu_memory_preservation,
169
+ use_teacache,
170
+ mp4_crf,
171
+ lora_file,
172
+ lora_multiplier,
173
+ fp8_optimization,
174
+ ):
175
+ global transformer, previous_lora_file, previous_lora_multiplier, previous_fp8_optimization
176
+
177
+ model_changed = transformer is None or (
178
+ lora_file != previous_lora_file
179
+ or lora_multiplier != previous_lora_multiplier
180
+ or fp8_optimization != previous_fp8_optimization
181
+ )
182
+
183
  total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
184
  total_latent_sections = int(max(round(total_latent_sections), 1))
185
 
186
  job_id = generate_timestamp()
187
 
188
+ stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Starting ..."))))
189
 
190
  try:
191
  # Clean GPU
192
  if not high_vram:
193
+ unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
194
 
195
  # Text encoding
196
 
197
+ stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Text encoding ..."))))
198
 
199
  if not high_vram:
200
+ # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
201
+ fake_diffusers_current_device(text_encoder, gpu)
202
  load_model_as_complete(text_encoder_2, target_device=gpu)
203
 
204
  llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
 
213
 
214
  # Processing input image
215
 
216
+ stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Image processing ..."))))
217
 
218
  H, W, C = input_image.shape
219
  height, width = find_nearest_bucket(H, W, resolution=640)
220
  input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
221
 
222
+ Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f"{job_id}.png"))
223
 
224
  input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
225
  input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
226
 
227
  # VAE encoding
228
 
229
+ stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "VAE encoding ..."))))
230
 
231
  if not high_vram:
232
  load_model_as_complete(vae, target_device=gpu)
 
235
 
236
  # CLIP Vision
237
 
238
+ stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "CLIP Vision encoding ..."))))
239
 
240
  if not high_vram:
241
  load_model_as_complete(image_encoder, target_device=gpu)
 
245
 
246
  # Dtype
247
 
248
+ llama_vec = llama_vec.to(transformer_dtype)
249
+ llama_vec_n = llama_vec_n.to(transformer_dtype)
250
+ clip_l_pooler = clip_l_pooler.to(transformer_dtype)
251
+ clip_l_pooler_n = clip_l_pooler_n.to(transformer_dtype)
252
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer_dtype)
253
+
254
+ # Load transformer model
255
+ if model_changed:
256
+ stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading transformer ..."))))
257
+
258
+ transformer = None
259
+ time.sleep(1.0) # wait for the previous model to be unloaded
260
+ torch.cuda.empty_cache()
261
+ gc.collect()
262
+
263
+ previous_lora_file = lora_file
264
+ previous_lora_multiplier = lora_multiplier
265
+ previous_fp8_optimization = fp8_optimization
266
+
267
+ transformer = load_transfomer() # bfloat16, on cpu
268
+
269
+ if lora_file is not None or fp8_optimization:
270
+ state_dict = transformer.state_dict()
271
+
272
+ # LoRA should be merged before fp8 optimization
273
+ if lora_file is not None:
274
+ # TODO It would be better to merge the LoRA into the state dict before creating the transformer instance.
275
+ # Use from_config() instead of from_pretrained to make the instance without loading.
276
+
277
+ print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
278
+ state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
279
+ gc.collect()
280
+
281
+ if fp8_optimization:
282
+ TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
283
+ EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8
284
+
285
+ # inplace optimization
286
+ print("Optimizing for fp8")
287
+ state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False)
288
+
289
+ # apply monkey patching
290
+ apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
291
+ gc.collect()
292
+
293
+ info = transformer.load_state_dict(state_dict, strict=True, assign=True)
294
+ print(f"LoRA and/or fp8 optimization applied: {info}")
295
+
296
+ if not high_vram:
297
+ DynamicSwapInstaller.install_model(transformer, device=gpu)
298
+ else:
299
+ transformer.to(gpu)
300
 
301
  # Sampling
302
 
303
+ stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Start sampling ..."))))
304
 
305
  rnd = torch.Generator("cpu").manual_seed(seed)
306
  num_frames = latent_window_size * 4 - 3
 
322
  is_last_section = latent_padding == 0
323
  latent_padding_size = latent_padding * latent_window_size
324
 
325
+ if stream.input_queue.top() == "end":
326
+ stream.output_queue.push(("end", None))
327
  return
328
 
329
+ print(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
330
 
331
  indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
332
+ (
333
+ clean_latent_indices_pre,
334
+ blank_indices,
335
+ latent_indices,
336
+ clean_latent_indices_post,
337
+ clean_latent_2x_indices,
338
+ clean_latent_4x_indices,
339
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
340
  clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
341
 
342
  clean_latents_pre = start_latent.to(history_latents)
343
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
344
+ [1, 2, 16], dim=2
345
+ )
346
  clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
347
 
348
  if not high_vram:
349
  unload_complete_models()
350
+ move_model_to_device_with_memory_preservation(
351
+ transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation
352
+ )
353
 
354
  if use_teacache:
355
  transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
 
357
  transformer.initialize_teacache(enable_teacache=False)
358
 
359
  def callback(d):
360
+ preview = d["denoised"]
361
  preview = vae_decode_fake(preview)
362
 
363
  preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
364
+ preview = einops.rearrange(preview, "b c t h w -> (b h) (t w) c")
365
 
366
+ if stream.input_queue.top() == "end":
367
+ stream.output_queue.push(("end", None))
368
+ raise KeyboardInterrupt("User ends the task.")
369
 
370
+ current_step = d["i"] + 1
371
  percentage = int(100.0 * current_step / steps)
372
+ hint = f"Sampling {current_step}/{steps}"
373
+ desc = f"Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ..."
374
+ stream.output_queue.push(("progress", (preview, desc, make_progress_bar_html(percentage, hint))))
375
  return
376
 
377
  generated_latents = sample_hunyuan(
378
  transformer=transformer,
379
+ sampler="unipc",
380
  width=width,
381
  height=height,
382
  frames=num_frames,
 
429
  if not high_vram:
430
  unload_complete_models()
431
 
432
+ output_filename = os.path.join(outputs_folder, f"{job_id}_{total_generated_latent_frames}.mp4")
433
 
434
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf)
435
 
436
+ print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
437
 
438
+ stream.output_queue.push(("file", output_filename))
439
 
440
  if is_last_section:
441
  break
 
443
  traceback.print_exc()
444
 
445
  if not high_vram:
446
+ unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
447
 
448
+ stream.output_queue.push(("end", None))
449
  return
450
 
451
 
452
+ def process(
453
+ input_image,
454
+ prompt,
455
+ n_prompt,
456
+ seed,
457
+ total_second_length,
458
+ latent_window_size,
459
+ steps,
460
+ cfg,
461
+ gs,
462
+ rs,
463
+ gpu_memory_preservation,
464
+ use_teacache,
465
+ mp4_crf,
466
+ lora_file,
467
+ lora_multiplier,
468
+ fp8_optimization,
469
+ ):
470
  global stream
471
+ assert input_image is not None, "No input image!"
472
 
473
+ yield None, None, "", "", gr.update(interactive=False), gr.update(interactive=True)
474
 
475
  stream = AsyncStream()
476
 
477
+ async_run(
478
+ worker,
479
+ input_image,
480
+ prompt,
481
+ n_prompt,
482
+ seed,
483
+ total_second_length,
484
+ latent_window_size,
485
+ steps,
486
+ cfg,
487
+ gs,
488
+ rs,
489
+ gpu_memory_preservation,
490
+ use_teacache,
491
+ mp4_crf,
492
+ lora_file,
493
+ lora_multiplier,
494
+ fp8_optimization,
495
+ )
496
 
497
  output_filename = None
498
 
499
  while True:
500
  flag, data = stream.output_queue.next()
501
 
502
+ if flag == "file":
503
  output_filename = data
504
  yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
505
 
506
+ if flag == "progress":
507
  preview, desc, html = data
508
+ yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(
509
+ interactive=True
510
+ )
511
 
512
+ if flag == "end":
513
+ yield output_filename, gr.update(visible=False), gr.update(), "", gr.update(interactive=True), gr.update(
514
+ interactive=False
515
+ )
516
  break
517
 
518
 
519
  def end_process():
520
+ stream.input_queue.push("end")
521
 
522
 
523
  quick_prompts = [
524
+ "The girl dances gracefully, with clear movements, full of charm.",
525
+ "A character doing some simple body movements.",
526
  ]
527
  quick_prompts = [[x] for x in quick_prompts]
528
 
 
530
  css = make_progress_bar_css()
531
  block = gr.Blocks(css=css).queue()
532
  with block:
533
+ gr.Markdown("# FramePack")
534
  with gr.Row():
535
  with gr.Column():
536
+ input_image = gr.Image(sources="upload", type="numpy", label="Image", height=320)
537
+ prompt = gr.Textbox(label="Prompt", value="")
538
+ example_quick_prompts = gr.Dataset(
539
+ samples=quick_prompts, label="Quick List", samples_per_page=1000, components=[prompt]
540
+ )
541
+ example_quick_prompts.click(
542
+ lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=False
543
+ )
544
 
545
  with gr.Row():
546
  start_button = gr.Button(value="Start Generation")
547
  end_button = gr.Button(value="End Generation", interactive=False)
548
 
549
  with gr.Group():
550
+ use_teacache = gr.Checkbox(
551
+ label="Use TeaCache", value=True, info="Faster speed, but often makes hands and fingers slightly worse."
552
+ )
553
 
554
  n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=False) # Not used
555
  seed = gr.Number(label="Seed", value=31337, precision=0)
556
 
557
  total_second_length = gr.Slider(label="Total Video Length (Seconds)", minimum=1, maximum=120, value=5, step=0.1)
558
+ latent_window_size = gr.Slider(
559
+ label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, visible=False
560
+ ) # Should not change
561
+ steps = gr.Slider(
562
+ label="Steps", minimum=1, maximum=100, value=25, step=1, info="Changing this value is not recommended."
563
+ )
564
+
565
+ cfg = gr.Slider(
566
+ label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, visible=False
567
+ ) # Should not change
568
+ gs = gr.Slider(
569
+ label="Distilled CFG Scale",
570
+ minimum=1.0,
571
+ maximum=32.0,
572
+ value=10.0,
573
+ step=0.01,
574
+ info="Changing this value is not recommended.",
575
+ )
576
+ rs = gr.Slider(
577
+ label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False
578
+ ) # Should not change
579
+
580
+ gpu_memory_preservation = gr.Slider(
581
+ label="GPU Inference Preserved Memory (GB) (larger means slower)",
582
+ minimum=6,
583
+ maximum=128,
584
+ value=6,
585
+ step=0.1,
586
+ info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.",
587
+ )
588
+
589
+ mp4_crf = gr.Slider(
590
+ label="MP4 Compression",
591
+ minimum=0,
592
+ maximum=100,
593
+ value=16,
594
+ step=1,
595
+ info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs. ",
596
+ )
597
 
598
+ with gr.Group():
599
+ lora_file = gr.File(label="LoRA File", file_count="single", type="filepath")
600
+ lora_multiplier = gr.Slider(label="LoRA Multiplier", minimum=0.0, maximum=1.0, value=0.8, step=0.1)
601
+ fp8_optimization = gr.Checkbox(label="FP8 Optimization", value=False)
602
 
603
  with gr.Column():
604
  preview_image = gr.Image(label="Next Latents", height=200, visible=False)
605
  result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True)
606
+ gr.Markdown(
607
+ "Note that the ending actions will be generated before the starting actions due to the inverted sampling. If the starting action is not in the video, you just need to wait, and it will be generated later."
608
+ )
609
+ progress_desc = gr.Markdown("", elem_classes="no-generating-animation")
610
+ progress_bar = gr.HTML("", elem_classes="no-generating-animation")
611
+
612
+ gr.HTML(
613
+ '<div style="text-align:center; margin-top:20px;">Share your results and find ideas at the <a href="https://x.com/search?q=framepack&f=live" target="_blank">FramePack Twitter (X) thread</a></div>'
614
+ )
615
+
616
+ ips = [
617
+ input_image,
618
+ prompt,
619
+ n_prompt,
620
+ seed,
621
+ total_second_length,
622
+ latent_window_size,
623
+ steps,
624
+ cfg,
625
+ gs,
626
+ rs,
627
+ gpu_memory_preservation,
628
+ use_teacache,
629
+ mp4_crf,
630
+ lora_file,
631
+ lora_multiplier,
632
+ fp8_optimization,
633
+ ]
634
+ start_button.click(
635
+ fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button]
636
+ )
637
  end_button.click(fn=end_process)
638
 
639
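
The core of this change is a lazy transformer load that is re-run only when the LoRA file, LoRA multiplier, or FP8 setting changes between runs. A minimal sketch of that path follows, assuming the repo-local helpers keep the signatures used in the diff above (merge_lora_to_state_dict, optimize_state_dict_with_fp8, apply_fp8_monkey_patch, DynamicSwapInstaller); the get_transformer() wrapper itself is illustrative and is not part of the commit.

# Minimal sketch of the lazy transformer load path shown in the diff above.
import gc
import torch

from diffusers_helper.memory import DynamicSwapInstaller, gpu
from utils.lora_utils import merge_lora_to_state_dict
from utils.fp8_optimization_utils import optimize_state_dict_with_fp8, apply_fp8_monkey_patch


def get_transformer(current, previous_settings, lora_file, lora_multiplier, fp8_optimization, high_vram):
    """Reload the packed transformer only when the LoRA / FP8 settings change."""
    settings = (lora_file, lora_multiplier, fp8_optimization)
    if current is not None and settings == previous_settings:
        return current, settings  # reuse the already-prepared model

    # Drop the old instance before loading a fresh bfloat16 copy on the CPU.
    del current
    gc.collect()
    torch.cuda.empty_cache()
    transformer = load_transfomer()  # helper defined in demo_gradio.py (spelling as in the diff)

    if lora_file is not None or fp8_optimization:
        state_dict = transformer.state_dict()
        if lora_file is not None:  # LoRA is merged before any FP8 optimization
            state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
        if fp8_optimization:
            state_dict = optimize_state_dict_with_fp8(
                state_dict, gpu, ["transformer_blocks", "single_transformer_blocks"], ["norm"], move_to_device=False
            )
            apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
        transformer.load_state_dict(state_dict, strict=True, assign=True)

    if high_vram:
        transformer.to(gpu)
    else:
        DynamicSwapInstaller.install_model(transformer, device=gpu)  # streamed CPU/GPU swap, as in the diff
    return transformer, settings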