DeepBeepMeep committed
Commit fcbb34f
1 Parent(s): d29010d

Added RIFLEx support
README.md CHANGED
@@ -19,7 +19,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
19
 
20
  ## 🔥 Latest News!!
21
 
22
- * Mar 03, 2025: 👋 Wan2.1GP DeepBeepMeep out of this World version ! Reduced memory consumption by 2, with possiblity to generate more than 10s of video at 720p
 
23
  * Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
24
  * Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
25
 
@@ -35,7 +36,6 @@ This version has the following improvements over the original Alibaba model:
35
  - Improved gradio interface with progression bar and more options
36
 - Multiple prompts / multiple generations per prompt
37
  - Support multiple pretrained Loras with 32 GB of RAM or less
38
- - Switch easily between Hunyuan and Fast Hunyuan models and quantized / non quantized models
39
  - Much simpler installation
40
 
41
 
@@ -105,10 +105,28 @@ pip install https://github.com/deepbeepmeep/SageAttention/raw/refs/heads/main/re
105
  ## Run the application
106
 
107
  ### Run a Gradio Server on port 7860 (recommended)
 
 
108
  ```bash
109
 python gradio_server.py
110
  ```
111
112
 
113
  ### Loras support
114
 
@@ -131,7 +149,8 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
131
 
132
 
133
  ### Command line parameters for Gradio Server
134
- --profile no : default (4) : no of profile between 1 and 5\
 
135
  --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
136
  --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
137
 --lora-preset preset : name of preset file (without the extension) to preload
@@ -141,6 +160,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
141
  --open-browser : open automatically Browser when launching Gradio Server\
142
  --compile : turn on pytorch compilation\
143
 --attention mode: force attention mode among sdpa, flash, sage, sage2\
 
144
 
145
  ### Profiles (for power users only)
146
 You can choose between 5 profiles; they will try to make the most of your hardware, but have little impact for HunyuanVideo GP:
 
19
 
20
  ## 🔥 Latest News!!
21
 
22
+ * Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep brings: memory consumption reduced by 2, with the possibility to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12 GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice-looking videos longer than 5s.
23
+
24
  * Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
25
  * Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
26
 
 
36
  - Improved gradio interface with progression bar and more options
37
 - Multiple prompts / multiple generations per prompt
38
  - Support multiple pretrained Loras with 32 GB of RAM or less
 
39
  - Much simpler installation
40
 
41
 
 
105
  ## Run the application
106
 
107
  ### Run a Gradio Server on port 7860 (recommended)
108
+
109
+ To run the text-to-video generator (in Low VRAM mode):
110
  ```bash
111
  python gradio_server.py
112
+ #or
113
+ python gradio_server.py --t2v
114
+
115
+ ```
116
+
117
+ To run the image-to-video generator (in Low VRAM mode):
118
+ ```bash
119
+ python gradio_server.py --i2v
120
  ```
121
 
122
+ Within the application you can also configure which video generator is launched by default, without specifying a command line switch.
123
+
124
+ To run the application while loading the diffusion model entirely in VRAM (slightly faster, but requires 24 GB of VRAM for an 8-bit quantized 14B model):
125
+ ```bash
126
+ python gradio_server.py --profile 3
127
+ ```
128
+ Please note that the diffusion model of Wan2.1GP is extremely VRAM optimized, which greatly benefits low VRAM systems since the diffusion / denoising step is the longest part of the generation process. However, the VAE encoder (at the beginning of an image-to-video process) and the VAE decoder (at the end of any video process) are only 20% lighter and will temporarily require 22 GB of VRAM for a 720p generation and 12 GB of VRAM for a 480p generation. Therefore, if you have less VRAM than these amounts, you may experience a slowdown at the beginning and at the end of the generation process due to PyTorch VRAM offloading.
129
+
130
 
131
  ### Loras support
132
 
 
149
 
150
 
151
  ### Command line parameters for Gradio Server
152
+ --i2v : launch the image-to-video generator\
153
+ --t2v : launch the text-to-video generator\
154
  --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
155
  --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
156
 --lora-preset preset : name of preset file (without the extension) to preload
 
160
  --open-browser : open automatically Browser when launching Gradio Server\
161
  --compile : turn on pytorch compilation\
162
 --attention mode: force attention mode among sdpa, flash, sage, sage2\
163
+ --profile no : (default 4) : number of the profile to use, between 1 and 5\
164
 
165
  ### Profiles (for power users only)
166
 You can choose between 5 profiles; they will try to make the most of your hardware, but have little impact for HunyuanVideo GP:
gradio/i2v_14B_singleGPU.py CHANGED
@@ -288,7 +288,7 @@ if __name__ == '__main__':
288
  # resolution = '720P'
289
  resolution = '480P'
290
 
291
- load_model(resolution)
292
 
293
  print("Step1: Init prompt_expander...", end='', flush=True)
294
  if args.prompt_extend_method == "dashscope":
 
288
  # resolution = '720P'
289
  resolution = '480P'
290
 
291
+ load_i2v_model(resolution)
292
 
293
  print("Step1: Init prompt_expander...", end='', flush=True)
294
  if args.prompt_extend_method == "dashscope":
gradio_server.py CHANGED
@@ -18,6 +18,8 @@ from wan.utils.utils import cache_video
18
  from wan.modules.attention import get_attention_modes
19
  import torch
20
  import gc
 
 
21
 
22
  def _parse_args():
23
  parser = argparse.ArgumentParser(
@@ -752,7 +754,7 @@ def generate_video(
752
  state["in_progress"] = True
753
  state["selected"] = 0
754
 
755
- enable_riflex = RIFLEx_setting == 0 and video_length > (5* 24) or RIFLEx_setting == 1
756
  # VAE Tiling
757
  device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
758
 
@@ -810,7 +812,8 @@ def generate_video(
810
  n_prompt=negative_prompt,
811
  seed=seed,
812
  offload_model=False,
813
- callback=callback
 
814
  )
815
 
816
  else:
@@ -824,9 +827,10 @@ def generate_video(
824
  n_prompt=negative_prompt,
825
  seed=seed,
826
  offload_model=False,
827
- callback=callback
 
828
  )
829
- except:
830
  gen_in_progress = False
831
  if temp_filename!= None and os.path.isfile(temp_filename):
832
  os.remove(temp_filename)
@@ -838,7 +842,21 @@ def generate_video(
838
 
839
  gc.collect()
840
  torch.cuda.empty_cache()
841
- raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
842
 
843
 
844
  if samples != None:
@@ -949,7 +967,7 @@ def create_demo():
949
  else:
950
  gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
951
 
952
- gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B> the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
953
 
954
  if use_image2video and False:
955
  pass
@@ -959,7 +977,7 @@ def create_demo():
959
  gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
960
  gr.Markdown("- 1280 x 720 with a 14B model: 192 frames (8s): 11 GB of VRAM")
961
  gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peak (up to 12GB for 420P and 22 GB for 720P)")
962
-
963
 gr.Markdown("Please note that if you turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
964
 
965
 
@@ -1092,25 +1110,37 @@ def create_demo():
1092
 
1093
 
1094
  with gr.Row():
1095
- resolution = gr.Dropdown(
1096
- choices=[
1097
- # 720p
1098
- ("1280x720 (16:9, 720p)", "1280x720"),
1099
- ("720x1280 (9:16, 720p)", "720x1280"),
1100
- ("1024x1024 (4:3, 720p, T2V only)", "1024x024"),
1101
- # ("832x1104 (3:4, 720p)", "832x1104"),
1102
- # ("960x960 (1:1, 720p)", "960x960"),
1103
- # 480p
1104
- # ("960x544 (16:9, 480p)", "960x544"),
1105
- ("832x480 (16:9, 480p)", "832x480"),
1106
- ("480x832 (9:16, 480p)", "480x832"),
1107
- # ("832x624 (4:3, 540p)", "832x624"),
1108
- # ("624x832 (3:4, 540p)", "624x832"),
1109
- # ("720x720 (1:1, 540p)", "720x720"),
1110
- ],
1111
- value="832x480",
1112
- label="Resolution"
1113
- )
1114
 
1115
  with gr.Row():
1116
  with gr.Column():
@@ -1125,7 +1155,7 @@ def create_demo():
1125
  with gr.Row(visible= len(loras)>0):
1126
  lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
1127
  with gr.Column(scale=5):
1128
- lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=False, choices= lset_choices, value=default_lora_preset)
1129
  with gr.Column(scale=1):
1130
  # with gr.Column():
1131
  with gr.Row(height=17):
 
18
  from wan.modules.attention import get_attention_modes
19
  import torch
20
  import gc
21
+ import traceback
22
+
23
 
24
  def _parse_args():
25
  parser = argparse.ArgumentParser(
 
754
  state["in_progress"] = True
755
  state["selected"] = 0
756
 
757
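+ # RIFLEx auto mode (setting 0) only triggers for videos longer than 5s at 16 fps (more than 80 frames); setting 1 forces it on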
+ enable_RIFLEx = (RIFLEx_setting == 0 and video_length > (5 * 16)) or RIFLEx_setting == 1
758
  # VAE Tiling
759
  device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
760
 
 
812
  n_prompt=negative_prompt,
813
  seed=seed,
814
  offload_model=False,
815
+ callback=callback,
816
+ enable_RIFLEx = enable_RIFLEx
817
  )
818
 
819
  else:
 
827
  n_prompt=negative_prompt,
828
  seed=seed,
829
  offload_model=False,
830
+ callback=callback,
831
+ enable_RIFLEx = enable_RIFLEx
832
  )
833
+ except Exception as e:
834
  gen_in_progress = False
835
  if temp_filename!= None and os.path.isfile(temp_filename):
836
  os.remove(temp_filename)
 
842
 
843
  gc.collect()
844
  torch.cuda.empty_cache()
845
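+ # heuristic: treat the failure as a VRAM crash when the error message or the recent stack frames mention memory / CUDA related keywords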
+ s = str(e)
846
+ keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"]
847
+ VRAM_crash= False
848
+ if any( keyword in s for keyword in keyword_list):
849
+ VRAM_crash = True
850
+ else:
851
+ stack = traceback.extract_stack(f=None, limit=5)
852
+ for frame in stack:
853
+ if any( keyword in frame.name for keyword in keyword_list):
854
+ VRAM_crash = True
855
+ break
856
+ if VRAM_crash:
857
+ raise gr.Error("The generation of the video has encountered an error: it is likely that you have insufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
858
+ else:
859
+ raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
860
 
861
 
862
  if samples != None:
 
967
  else:
968
  gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
969
 
970
+ gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B>, the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
971
 
972
  if use_image2video and False:
973
  pass
 
977
  gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
978
  gr.Markdown("- 1280 x 720 with a 14B model: 192 frames (8s): 11 GB of VRAM")
979
  gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peak (up to 12GB for 420P and 22 GB for 720P)")
980
+ gr.Markdown("It is not recommended to generate a video longer than 8s, even if there is still some VRAM left, as some artifacts may appear")
981
 gr.Markdown("Please note that if you turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
982
 
983
 
 
1110
 
1111
 
1112
  with gr.Row():
1113
+ if use_image2video:
1114
+ resolution = gr.Dropdown(
1115
+ choices=[
1116
+ # 720p
1117
+ ("720p", "1280x720"),
1118
+ ("480p", "832x480"),
1119
+ ],
1120
+ value="832x480",
1121
+ label="Resolution (the video will have the same height / width ratio as the original image)"
1122
+ )
1123
+
1124
+ else:
1125
+ resolution = gr.Dropdown(
1126
+ choices=[
1127
+ # 720p
1128
+ ("1280x720 (16:9, 720p)", "1280x720"),
1129
+ ("720x1280 (9:16, 720p)", "720x1280"),
1130
+ ("1024x1024 (1:1, 720p)", "1024x1024"),
1131
+ # ("832x1104 (3:4, 720p)", "832x1104"),
1132
+ # ("960x960 (1:1, 720p)", "960x960"),
1133
+ # 480p
1134
+ # ("960x544 (16:9, 480p)", "960x544"),
1135
+ ("832x480 (16:9, 480p)", "832x480"),
1136
+ ("480x832 (9:16, 480p)", "480x832"),
1137
+ # ("832x624 (4:3, 540p)", "832x624"),
1138
+ # ("624x832 (3:4, 540p)", "624x832"),
1139
+ # ("720x720 (1:1, 540p)", "720x720"),
1140
+ ],
1141
+ value="832x480",
1142
+ label="Resolution"
1143
+ )
1144
 
1145
  with gr.Row():
1146
  with gr.Column():
 
1155
  with gr.Row(visible= len(loras)>0):
1156
  lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
1157
  with gr.Column(scale=5):
1158
+ lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
1159
  with gr.Column(scale=1):
1160
  # with gr.Column():
1161
  with gr.Row(height=17):
wan/image2video.py CHANGED
@@ -143,7 +143,9 @@ class WanI2V:
143
  n_prompt="",
144
  seed=-1,
145
  offload_model=True,
146
- callback = None
 
 
147
  ):
148
  r"""
149
  Generates video frames from input image and text prompt using diffusion process.
@@ -262,104 +264,107 @@ class WanI2V:
262
  no_sync = getattr(self.model, 'no_sync', noop_no_sync)
263
 
264
  # evaluation mode
265
- with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
266
-
267
- if sample_solver == 'unipc':
268
- sample_scheduler = FlowUniPCMultistepScheduler(
269
- num_train_timesteps=self.num_train_timesteps,
270
- shift=1,
271
- use_dynamic_shifting=False)
272
- sample_scheduler.set_timesteps(
273
- sampling_steps, device=self.device, shift=shift)
274
- timesteps = sample_scheduler.timesteps
275
- elif sample_solver == 'dpm++':
276
- sample_scheduler = FlowDPMSolverMultistepScheduler(
277
- num_train_timesteps=self.num_train_timesteps,
278
- shift=1,
279
- use_dynamic_shifting=False)
280
- sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
281
- timesteps, _ = retrieve_timesteps(
282
- sample_scheduler,
283
- device=self.device,
284
- sigmas=sampling_sigmas)
285
- else:
286
- raise NotImplementedError("Unsupported solver.")
287
-
288
- # sample videos
289
- latent = noise
290
-
291
- arg_c = {
292
- 'context': [context[0]],
293
- 'clip_fea': clip_context,
294
- 'seq_len': max_seq_len,
295
- 'y': [y],
296
- 'pipeline' : self
297
- }
298
-
299
- arg_null = {
300
- 'context': context_null,
301
- 'clip_fea': clip_context,
302
- 'seq_len': max_seq_len,
303
- 'y': [y],
304
- 'pipeline' : self
305
- }
306
 
307
- if offload_model:
308
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
- # self.model.to(self.device)
311
- if callback != None:
312
- callback(-1, None)
313
-
314
- self._interrupt = False
315
- for i, t in enumerate(tqdm(timesteps)):
316
- latent_model_input = [latent.to(self.device)]
317
- timestep = [t]
318
-
319
- timestep = torch.stack(timestep).to(self.device)
320
-
321
- noise_pred_cond = self.model(
322
- latent_model_input, t=timestep, **arg_c)[0]
323
- if self._interrupt:
324
- return None
325
- if offload_model:
326
- torch.cuda.empty_cache()
327
- noise_pred_uncond = self.model(
328
- latent_model_input, t=timestep, **arg_null)[0]
329
- if self._interrupt:
330
- return None
331
- del latent_model_input
332
- if offload_model:
333
- torch.cuda.empty_cache()
334
- noise_pred = noise_pred_uncond + guide_scale * (
335
- noise_pred_cond - noise_pred_uncond)
336
- del noise_pred_uncond
337
-
338
- latent = latent.to(
339
- torch.device('cpu') if offload_model else self.device)
340
-
341
- temp_x0 = sample_scheduler.step(
342
- noise_pred.unsqueeze(0),
343
- t,
344
- latent.unsqueeze(0),
345
- return_dict=False,
346
- generator=seed_g)[0]
347
- latent = temp_x0.squeeze(0)
348
- del temp_x0
349
- del timestep
350
-
351
- if callback is not None:
352
- callback(i, latent)
353
-
354
-
355
- x0 = [latent.to(self.device)]
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  if offload_model:
358
- self.model.cpu()
359
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
- if self.rank == 0:
362
- videos = self.vae.decode(x0)
363
 
364
  del noise, latent
365
  del sample_scheduler
 
143
  n_prompt="",
144
  seed=-1,
145
  offload_model=True,
146
+ callback = None,
147
+ enable_RIFLEx = False
148
+
149
  ):
150
  r"""
151
  Generates video frames from input image and text prompt using diffusion process.
 
264
  no_sync = getattr(self.model, 'no_sync', noop_no_sync)
265
 
266
  # evaluation mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ if sample_solver == 'unipc':
269
+ sample_scheduler = FlowUniPCMultistepScheduler(
270
+ num_train_timesteps=self.num_train_timesteps,
271
+ shift=1,
272
+ use_dynamic_shifting=False)
273
+ sample_scheduler.set_timesteps(
274
+ sampling_steps, device=self.device, shift=shift)
275
+ timesteps = sample_scheduler.timesteps
276
+ elif sample_solver == 'dpm++':
277
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
278
+ num_train_timesteps=self.num_train_timesteps,
279
+ shift=1,
280
+ use_dynamic_shifting=False)
281
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
282
+ timesteps, _ = retrieve_timesteps(
283
+ sample_scheduler,
284
+ device=self.device,
285
+ sigmas=sampling_sigmas)
286
+ else:
287
+ raise NotImplementedError("Unsupported solver.")
288
+
289
+ # sample videos
290
+ latent = noise
291
+
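+ # latent frame count follows the VAE's 4x temporal compression ((frame_num - 1)/4 + 1); RIFLEx_k=4 is the intrinsic frequency index stretched when RIFLEx is enabled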
292
+ freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
293
+
294
+ arg_c = {
295
+ 'context': [context[0]],
296
+ 'clip_fea': clip_context,
297
+ 'seq_len': max_seq_len,
298
+ 'y': [y],
299
+ 'freqs' : freqs,
300
+ 'pipeline' : self
301
+ }
302
+
303
+ arg_null = {
304
+ 'context': context_null,
305
+ 'clip_fea': clip_context,
306
+ 'seq_len': max_seq_len,
307
+ 'y': [y],
308
+ 'freqs' : freqs,
309
+ 'pipeline' : self
310
+ }
311
+
312
+ if offload_model:
313
+ torch.cuda.empty_cache()
314
+
315
+ # self.model.to(self.device)
316
+ if callback != None:
317
+ callback(-1, None)
318
 
319
+ self._interrupt = False
320
+ for i, t in enumerate(tqdm(timesteps)):
321
+ latent_model_input = [latent.to(self.device)]
322
+ timestep = [t]
323
 
324
+ timestep = torch.stack(timestep).to(self.device)
325
+
326
+ noise_pred_cond = self.model(
327
+ latent_model_input, t=timestep, **arg_c)[0]
328
+ if self._interrupt:
329
+ return None
330
+ if offload_model:
331
+ torch.cuda.empty_cache()
332
+ noise_pred_uncond = self.model(
333
+ latent_model_input, t=timestep, **arg_null)[0]
334
+ if self._interrupt:
335
+ return None
336
+ del latent_model_input
337
  if offload_model:
 
338
  torch.cuda.empty_cache()
339
+ noise_pred = noise_pred_uncond + guide_scale * (
340
+ noise_pred_cond - noise_pred_uncond)
341
+ del noise_pred_uncond
342
+
343
+ latent = latent.to(
344
+ torch.device('cpu') if offload_model else self.device)
345
+
346
+ temp_x0 = sample_scheduler.step(
347
+ noise_pred.unsqueeze(0),
348
+ t,
349
+ latent.unsqueeze(0),
350
+ return_dict=False,
351
+ generator=seed_g)[0]
352
+ latent = temp_x0.squeeze(0)
353
+ del temp_x0
354
+ del timestep
355
+
356
+ if callback is not None:
357
+ callback(i, latent)
358
+
359
+
360
+ x0 = [latent.to(self.device)]
361
+
362
+ if offload_model:
363
+ self.model.cpu()
364
+ torch.cuda.empty_cache()
365
 
366
+ if self.rank == 0:
367
+ videos = self.vae.decode(x0)
368
 
369
  del noise, latent
370
  del sample_scheduler
wan/modules/model.py CHANGED
@@ -6,6 +6,8 @@ import torch.cuda.amp as amp
6
  import torch.nn as nn
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
8
  from diffusers.models.modeling_utils import ModelMixin
 
 
9
 
10
  from .attention import pay_attention
11
 
@@ -25,7 +27,49 @@ def sinusoidal_embedding_1d(dim, position):
25
  return x
26
 
27
 
28
- # @amp.autocast(enabled=False)
29
  def rope_params(max_seq_len, dim, theta=10000):
30
  assert dim % 2 == 0
31
  freqs = torch.outer(
@@ -588,14 +632,6 @@ class WanModel(ModelMixin, ConfigMixin):
588
  self.head = Head(dim, out_dim, patch_size, eps)
589
 
590
  # buffers (don't use register_buffer otherwise dtype will be changed in to())
591
- assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
592
- d = dim // num_heads
593
- self.freqs = torch.cat([
594
- rope_params(1024, d - 4 * (d // 6)),
595
- rope_params(1024, 2 * (d // 6)),
596
- rope_params(1024, 2 * (d // 6))
597
- ],
598
- dim=1)
599
 
600
  if model_type == 'i2v':
601
  self.img_emb = MLPProj(1280, dim)
@@ -603,6 +639,29 @@ class WanModel(ModelMixin, ConfigMixin):
603
  # initialize weights
604
  self.init_weights()
605
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
  def forward(
607
  self,
608
  x,
@@ -611,6 +670,7 @@ class WanModel(ModelMixin, ConfigMixin):
611
  seq_len,
612
  clip_fea=None,
613
  y=None,
 
614
  pipeline = None,
615
  ):
616
  r"""
@@ -638,8 +698,8 @@ class WanModel(ModelMixin, ConfigMixin):
638
  assert clip_fea is not None and y is not None
639
  # params
640
  device = self.patch_embedding.weight.device
641
- if self.freqs.device != device:
642
- self.freqs = self.freqs.to(device)
643
 
644
  if y is not None:
645
  x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -683,7 +743,7 @@ class WanModel(ModelMixin, ConfigMixin):
683
  e=e0,
684
  seq_lens=seq_lens,
685
  grid_sizes=grid_sizes,
686
- freqs=self.freqs,
687
  context=context,
688
  context_lens=context_lens)
689
 
 
6
  import torch.nn as nn
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
8
  from diffusers.models.modeling_utils import ModelMixin
9
+ import numpy as np
10
+ from typing import Union,Optional
11
 
12
  from .attention import pay_attention
13
 
 
27
  return x
28
 
29
 
30
+
31
+
32
+ def identify_k( b: float, d: int, N: int):
33
+ """
34
+ This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.
35
+
36
+ Args:
37
+ b (`float`): The base frequency for RoPE.
38
+ d (`int`): Dimension of the frequency tensor
39
+ N (`int`): the first observed repetition frame in latent space
40
+ Returns:
41
+ k (`int`): the index of intrinsic frequency component
42
+ N_k (`int`): the period of intrinsic frequency component in latent space
43
+ Example:
44
+ In HunyuanVideo, b=256 and d=16, the repetition occurs approximately 8s (N=48 in latent space).
45
+ k, N_k = identify_k(b=256, d=16, N=48)
46
+ In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
47
+ """
48
+
49
+ # Compute the period of each frequency in RoPE according to Eq.(4)
50
+ periods = []
51
+ for j in range(1, d // 2 + 1):
52
+ theta_j = 1.0 / (b ** (2 * (j - 1) / d))
53
+ N_j = round(2 * torch.pi / theta_j)
54
+ periods.append(N_j)
55
+
56
+ # Identify the intrinsic frequency whose period is closest to N (see Eq.(7))
57
+ diffs = [abs(N_j - N) for N_j in periods]
58
+ k = diffs.index(min(diffs)) + 1
59
+ N_k = periods[k-1]
60
+ return k, N_k
61
+
62
+ def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
63
+ assert dim % 2 == 0
64
+ exponents = torch.arange(0, dim, 2, dtype=torch.float64).div(dim)
65
+ inv_theta_pow = 1.0 / torch.pow(theta, exponents)
66
+
67
+ inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
68
+
69
+ freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
70
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
71
+ return freqs
72
+
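For readers who want to see the RIFLEx pieces end to end, here is a minimal sketch (not part of the commit) that reproduces the docstring example and assembles a RIFLEx-adjusted RoPE table. It assumes the helpers are importable from `wan.modules.model` (as the commented-out experiment in `text2video.py` suggests) and a per-head dimension d = 128, consistent with the 44 / 42 / 42 channel split used by `get_rope_freqs` below:

```python
import torch

from wan.modules.model import identify_k, rope_params, rope_params_riflex

# Docstring example: with b=256 and d=16, the repetition observed at N=48
# latent frames is driven by the 4th frequency component, whose period is 50.
k, N_k = identify_k(b=256, d=16, N=48)
print(k, N_k)  # 4 50

# Build the full RoPE table for a hypothetical 33-latent-frame generation.
d = 128                      # assumed per-head dimension of the Wan transformer
t_dim = d - 4 * (d // 6)     # 44 temporal channels
s_dim = 2 * (d // 6)         # 42 channels per spatial axis
temporal = rope_params_riflex(1024, dim=t_dim, L_test=33, k=4)  # stretch component k=4
spatial = rope_params(1024, s_dim)                              # spatial axes keep the original RoPE
freqs = torch.cat([temporal, spatial, spatial], dim=1)
print(freqs.shape)           # (1024, d // 2) complex frequencies
```

This mirrors what `get_rope_freqs` produces when called with `RIFLEx_k=4`, with `L_test` set to the number of latent frames of the requested video.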
73
  def rope_params(max_seq_len, dim, theta=10000):
74
  assert dim % 2 == 0
75
  freqs = torch.outer(
 
632
  self.head = Head(dim, out_dim, patch_size, eps)
633
 
634
 # buffers (don't use register_buffer otherwise dtype will be changed in to())
635
 
636
  if model_type == 'i2v':
637
  self.img_emb = MLPProj(1280, dim)
 
639
  # initialize weights
640
  self.init_weights()
641
 
642
+
643
+ # self.freqs = torch.cat([
644
+ # rope_params(1024, d - 4 * (d // 6)), #44
645
+ # rope_params(1024, 2 * (d // 6)), #42
646
+ # rope_params(1024, 2 * (d // 6)) #42
647
+ # ],dim=1)
648
+
649
+
650
+ def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None):
651
+ dim = self.dim
652
+ num_heads = self.num_heads
653
+ d = dim // num_heads
654
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
655
+
656
+
657
+ freqs = torch.cat([
658
+ rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ), #44
659
+ rope_params(1024, 2 * (d // 6)), #42
660
+ rope_params(1024, 2 * (d // 6)) #42
661
+ ],dim=1)
662
+
663
+ return freqs
664
+
665
  def forward(
666
  self,
667
  x,
 
670
  seq_len,
671
  clip_fea=None,
672
  y=None,
673
+ freqs = None,
674
  pipeline = None,
675
  ):
676
  r"""
 
698
  assert clip_fea is not None and y is not None
699
  # params
700
  device = self.patch_embedding.weight.device
701
+ if freqs.device != device:
702
+ freqs = freqs.to(device)
703
 
704
  if y is not None:
705
  x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 
743
  e=e0,
744
  seq_lens=seq_lens,
745
  grid_sizes=grid_sizes,
746
+ freqs=freqs,
747
  context=context,
748
  context_lens=context_lens)
749
 
wan/text2video.py CHANGED
@@ -128,7 +128,8 @@ class WanT2V:
128
  n_prompt="",
129
  seed=-1,
130
  offload_model=True,
131
- callback = None
 
132
  ):
133
  r"""
134
  Generates video frames from text prompt using diffusion process.
@@ -209,77 +210,84 @@ class WanT2V:
209
  no_sync = getattr(self.model, 'no_sync', noop_no_sync)
210
 
211
  # evaluation mode
212
- with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
213
-
214
- if sample_solver == 'unipc':
215
- sample_scheduler = FlowUniPCMultistepScheduler(
216
- num_train_timesteps=self.num_train_timesteps,
217
- shift=1,
218
- use_dynamic_shifting=False)
219
- sample_scheduler.set_timesteps(
220
- sampling_steps, device=self.device, shift=shift)
221
- timesteps = sample_scheduler.timesteps
222
- elif sample_solver == 'dpm++':
223
- sample_scheduler = FlowDPMSolverMultistepScheduler(
224
- num_train_timesteps=self.num_train_timesteps,
225
- shift=1,
226
- use_dynamic_shifting=False)
227
- sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
228
- timesteps, _ = retrieve_timesteps(
229
- sample_scheduler,
230
- device=self.device,
231
- sigmas=sampling_sigmas)
232
- else:
233
- raise NotImplementedError("Unsupported solver.")
234
-
235
- # sample videos
236
- latents = noise
237
-
238
- arg_c = {'context': context, 'seq_len': seq_len, 'pipeline': self}
239
- arg_null = {'context': context_null, 'seq_len': seq_len, 'pipeline': self}
240
-
241
- if callback != None:
242
- callback(-1, None)
243
- self._interrupt = False
244
- for i, t in enumerate(tqdm(timesteps)):
245
- latent_model_input = latents
246
- timestep = [t]
247
-
248
- timestep = torch.stack(timestep)
249
-
250
- # self.model.to(self.device)
251
- noise_pred_cond = self.model(
252
- latent_model_input, t=timestep, **arg_c)[0]
253
- if self._interrupt:
254
- return None
255
- noise_pred_uncond = self.model(
256
- latent_model_input, t=timestep, **arg_null)[0]
257
- if self._interrupt:
258
- return None
259
-
260
- del latent_model_input
261
- noise_pred = noise_pred_uncond + guide_scale * (
262
- noise_pred_cond - noise_pred_uncond)
263
- del noise_pred_uncond
264
-
265
- temp_x0 = sample_scheduler.step(
266
- noise_pred.unsqueeze(0),
267
- t,
268
- latents[0].unsqueeze(0),
269
- return_dict=False,
270
- generator=seed_g)[0]
271
- latents = [temp_x0.squeeze(0)]
272
- del temp_x0
273
-
274
- if callback is not None:
275
- callback(i, latents)
276
-
277
- x0 = latents
278
- if offload_model:
279
- self.model.cpu()
280
- torch.cuda.empty_cache()
281
- if self.rank == 0:
282
- videos = self.vae.decode(x0)
283
 
284
 
285
  del noise, latents
 
128
  n_prompt="",
129
  seed=-1,
130
  offload_model=True,
131
+ callback = None,
132
+ enable_RIFLEx = None
133
  ):
134
  r"""
135
  Generates video frames from text prompt using diffusion process.
 
210
  no_sync = getattr(self.model, 'no_sync', noop_no_sync)
211
 
212
  # evaluation mode
213
+
214
+ if sample_solver == 'unipc':
215
+ sample_scheduler = FlowUniPCMultistepScheduler(
216
+ num_train_timesteps=self.num_train_timesteps,
217
+ shift=1,
218
+ use_dynamic_shifting=False)
219
+ sample_scheduler.set_timesteps(
220
+ sampling_steps, device=self.device, shift=shift)
221
+ timesteps = sample_scheduler.timesteps
222
+ elif sample_solver == 'dpm++':
223
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
224
+ num_train_timesteps=self.num_train_timesteps,
225
+ shift=1,
226
+ use_dynamic_shifting=False)
227
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
228
+ timesteps, _ = retrieve_timesteps(
229
+ sample_scheduler,
230
+ device=self.device,
231
+ sigmas=sampling_sigmas)
232
+ else:
233
+ raise NotImplementedError("Unsupported solver.")
234
+
235
+ # sample videos
236
+ latents = noise
237
+
238
+ # from .modules.model import identify_k
239
+ # for nf in range(20, 50):
240
+ # k, N_k = identify_k(10000, 44, 26)
241
+ # print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
242
+
243
+ freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
244
+
245
+ arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
246
+ arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
247
+
248
+
249
+ if callback != None:
250
+ callback(-1, None)
251
+ self._interrupt = False
252
+ for i, t in enumerate(tqdm(timesteps)):
253
+ latent_model_input = latents
254
+ timestep = [t]
255
+
256
+ timestep = torch.stack(timestep)
257
+
258
+ # self.model.to(self.device)
259
+ noise_pred_cond = self.model(
260
+ latent_model_input, t=timestep, **arg_c)[0]
261
+ if self._interrupt:
262
+ return None
263
+ noise_pred_uncond = self.model(
264
+ latent_model_input, t=timestep, **arg_null)[0]
265
+ if self._interrupt:
266
+ return None
267
+
268
+ del latent_model_input
269
+ noise_pred = noise_pred_uncond + guide_scale * (
270
+ noise_pred_cond - noise_pred_uncond)
271
+ del noise_pred_uncond
272
+
273
+ temp_x0 = sample_scheduler.step(
274
+ noise_pred.unsqueeze(0),
275
+ t,
276
+ latents[0].unsqueeze(0),
277
+ return_dict=False,
278
+ generator=seed_g)[0]
279
+ latents = [temp_x0.squeeze(0)]
280
+ del temp_x0
281
+
282
+ if callback is not None:
283
+ callback(i, latents)
284
+
285
+ x0 = latents
286
+ if offload_model:
287
+ self.model.cpu()
288
+ torch.cuda.empty_cache()
289
+ if self.rank == 0:
290
+ videos = self.vae.decode(x0)
291
 
292
 
293
  del noise, latents