Commit fcbb34f by DeepBeepMeep
Parent(s): d29010d

Added RIFLEx support
Changed files:
- README.md +23 -3
- gradio/i2v_14B_singleGPU.py +1 -1
- gradio_server.py +57 -27
- wan/image2video.py +98 -93
- wan/modules/model.py +72 -12
- wan/text2video.py +80 -72
README.md CHANGED

@@ -19,7 +19,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 
 ## 🔥 Latest News!!
 
-* Mar 03, 2025: 👋 Wan2.1GP DeepBeepMeep
+* Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep brings: memory consumption reduced by a factor of 2, with the possibility to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12 GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice-looking videos longer than 5s.
+
 * Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
 * Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
 

@@ -35,7 +36,6 @@ This version has the following improvements over the original Alibaba model:
 - Improved gradio interface with progress bar and more options
 - Multiple prompts / multiple generations per prompt
 - Support for multiple pretrained Loras with 32 GB of RAM or less
-- Switch easily between Hunyuan and Fast Hunyuan models and quantized / non quantized models
 - Much simpler installation
 
 

@@ -105,10 +105,28 @@ pip install https://github.com/deepbeepmeep/SageAttention/raw/refs/heads/main/re
 ## Run the application
 
 ### Run a Gradio Server on port 7860 (recommended)
+
+To run the text to video generator (in Low VRAM mode):
 ```bash
 python gradio_server.py
+#or
+python gradio_server.py --t2v
+
+```
+
+To run the image to video generator (in Low VRAM mode):
+```bash
+python gradio_server.py --i2v
 ```
 
+Within the application you can configure which video generator will be launched without specifying a command line switch.
+
+To run the application while loading the diffusion model entirely in VRAM (slightly faster, but requires 24 GB of VRAM for an 8-bit quantized 14B model):
+```bash
+python gradio_server.py --profile 3
+```
+Please note that the diffusion model of Wan2.1GP is extremely VRAM optimized; this greatly benefits low VRAM systems, since the diffusion / denoising step is the longest part of the generation process. However, the VAE encoder (at the beginning of an image to video process) and the VAE decoder (at the end of any video process) are only 20% lighter and will temporarily require 22 GB of VRAM for a 720p generation and 12 GB of VRAM for a 480p generation. Therefore, if you have less than these amounts, you may experience a slowdown at the beginning and at the end of the generation process due to PyTorch VRAM offloading.
+
 
 ### Loras support
 

@@ -131,7 +149,8 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
 
 
 ### Command line parameters for Gradio Server
- --
+--i2v : launch the image to video generator\
+--t2v : launch the text to video generator\
 --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
 --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
 --lora-preset preset : name of preset file (without the extension) to preload

@@ -141,6 +160,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
 --open-browser : open the browser automatically when launching the Gradio Server\
 --compile : turn on PyTorch compilation\
 --attention mode : force the attention mode among sdpa, flash, sage, sage2\
+--profile no : number of the profile to use, between 1 and 5 (default 4)\
 
 ### Profiles (for power users only)
 You can choose between 5 profiles; these will try to get the most out of your hardware, but they have little impact for HunyuanVideo GP:
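For illustration, the new --i2v / --t2v and --profile switches can be combined with the existing ones on a single command line; the Lora directory and preset names below are placeholders:

```bash
python gradio_server.py --i2v --profile 3 --attention sage --compile --lora-dir loras --lora-preset my_preset --open-browser
```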
gradio/i2v_14B_singleGPU.py CHANGED

@@ -288,7 +288,7 @@ if __name__ == '__main__':
     # resolution = '720P'
     resolution = '480P'
 
-    …
+    load_i2v_model(resolution)
 
     print("Step1: Init prompt_expander...", end='', flush=True)
     if args.prompt_extend_method == "dashscope":
gradio_server.py CHANGED

@@ -18,6 +18,8 @@ from wan.utils.utils import cache_video
 from wan.modules.attention import get_attention_modes
 import torch
 import gc
+import traceback
+
 
 def _parse_args():
     parser = argparse.ArgumentParser(

@@ -752,7 +754,7 @@ def generate_video(
     state["in_progress"] = True
     state["selected"] = 0
 
-    …
+    enable_RIFLEx = RIFLEx_setting == 0 and video_length > (5 * 16) or RIFLEx_setting == 1
     # VAE Tiling
     device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
 
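Since Python's `and` binds tighter than `or`, the enable_RIFLEx line added above reads as "auto-enable RIFLEx when the requested length exceeds 5 s at 16 fps, or force it on when the setting is 1". A minimal sketch of that logic; the meaning of the setting values (0 = auto, 1 = always on, anything else = off) is inferred from the expression rather than documented in this commit:

```python
# Mirrors: enable_RIFLEx = RIFLEx_setting == 0 and video_length > (5 * 16) or RIFLEx_setting == 1
def riflex_enabled(riflex_setting: int, video_length: int) -> bool:
    # Assumed semantics: 0 = auto (only beyond 5 s at 16 fps), 1 = always on, other values = off.
    return (riflex_setting == 0 and video_length > 5 * 16) or riflex_setting == 1

assert riflex_enabled(0, 81)        # auto: 81 frames at 16 fps is just over 5 s
assert not riflex_enabled(0, 80)    # auto: 5 s or shorter, RIFLEx stays off
assert riflex_enabled(1, 48)        # forced on regardless of length
```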
@@ -810,7 +812,8 @@ def generate_video(
                 n_prompt=negative_prompt,
                 seed=seed,
                 offload_model=False,
-                callback=callback
+                callback=callback,
+                enable_RIFLEx = enable_RIFLEx
             )
 
         else:

@@ -824,9 +827,10 @@ def generate_video(
                 n_prompt=negative_prompt,
                 seed=seed,
                 offload_model=False,
-                callback=callback
+                callback=callback,
+                enable_RIFLEx = enable_RIFLEx
             )
-    except:
+    except Exception as e:
         gen_in_progress = False
         if temp_filename != None and os.path.isfile(temp_filename):
             os.remove(temp_filename)

@@ -838,7 +842,21 @@ def generate_video(
 
         gc.collect()
         torch.cuda.empty_cache()
-        …
+        s = str(e)
+        keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"]
+        VRAM_crash = False
+        if any(keyword in s for keyword in keyword_list):
+            VRAM_crash = True
+        else:
+            stack = traceback.extract_stack(f=None, limit=5)
+            for frame in stack:
+                if any(keyword in frame.name for keyword in keyword_list):
+                    VRAM_crash = True
+                    break
+        if VRAM_crash:
+            raise gr.Error("The generation of the video has encountered an error: it is likely that you have insufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
+        else:
+            raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
 
 
     if samples != None:
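The new exception handler above classifies a failure as a VRAM crash by scanning the exception message and the last few stack frames for the listed keywords. As a quick, illustrative check, a typical PyTorch out-of-memory message is already caught by the message scan alone (the message text below is a made-up but representative example):

```python
keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"]
oom_message = "CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 23.65 GiB total capacity)"
# Matches via "memory" and "allocat"; "cuda" does not match because the message says "CUDA".
assert any(keyword in oom_message for keyword in keyword_list)
```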
@@ -949,7 +967,7 @@ def create_demo():
     else:
         gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
 
-    gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B…
+    gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B>, the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
 
     if use_image2video and False:
         pass

@@ -959,7 +977,7 @@ def create_demo():
     gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s): 5 GB of VRAM")
     gr.Markdown("- 1280 x 720 with a 14B model: 192 frames (8s): 11 GB of VRAM")
     gr.Markdown("Note that the VAE stages (encoding / decoding at image2video, or just the decoding at text2video) will create a temporary VRAM peak (up to 12 GB for 480p and 22 GB for 720p)")
-    …
+    gr.Markdown("It is not recommended to generate a video longer than 8s even if there is still some VRAM left, as some artifacts may appear")
     gr.Markdown("Please note that if you turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
 
 

@@ -1092,25 +1110,37 @@ def create_demo():
 
 
     with gr.Row():
-        … (19 lines removed; content not preserved in this capture)
+        if use_image2video:
+            resolution = gr.Dropdown(
+                choices=[
+                    # 720p
+                    ("720p", "1280x720"),
+                    ("480p", "832x480"),
+                ],
+                value="832x480",
+                label="Resolution (the video will have the same height / width ratio as the original image)"
+            )
+
+        else:
+            resolution = gr.Dropdown(
+                choices=[
+                    # 720p
+                    ("1280x720 (16:9, 720p)", "1280x720"),
+                    ("720x1280 (9:16, 720p)", "720x1280"),
+                    ("1024x1024 (4:3, 720p)", "1024x1024"),
+                    # ("832x1104 (3:4, 720p)", "832x1104"),
+                    # ("960x960 (1:1, 720p)", "960x960"),
+                    # 480p
+                    # ("960x544 (16:9, 480p)", "960x544"),
+                    ("832x480 (16:9, 480p)", "832x480"),
+                    ("480x832 (9:16, 480p)", "480x832"),
+                    # ("832x624 (4:3, 540p)", "832x624"),
+                    # ("624x832 (3:4, 540p)", "624x832"),
+                    # ("720x720 (1:1, 540p)", "720x720"),
+                ],
+                value="832x480",
+                label="Resolution"
+            )
 
     with gr.Row():
         with gr.Column():

@@ -1125,7 +1155,7 @@ def create_demo():
     with gr.Row(visible=len(loras) > 0):
         lset_choices = [(preset, preset) for preset in loras_presets] + [(new_preset_msg, "")]
         with gr.Column(scale=5):
-            lset_name = gr.Dropdown(show_label=False, allow_custom_value=True, scale=5, filterable=…
+            lset_name = gr.Dropdown(show_label=False, allow_custom_value=True, scale=5, filterable=True, choices=lset_choices, value=default_lora_preset)
         with gr.Column(scale=1):
             # with gr.Column():
             with gr.Row(height=17):
wan/image2video.py CHANGED

@@ -143,7 +143,9 @@ class WanI2V:
         n_prompt="",
         seed=-1,
         offload_model=True,
-        callback = None
+        callback = None,
+        enable_RIFLEx = False
+
     ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.

@@ -262,104 +264,107 @@ class WanI2V:
         no_sync = getattr(self.model, 'no_sync', noop_no_sync)
 
         # evaluation mode
-        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
-            … (the previous scheduler setup and denoising loop, indented under this context manager, is removed here and re-added below in dedented form with the RIFLEx, offload and callback changes)
+
+        if sample_solver == 'unipc':
+            sample_scheduler = FlowUniPCMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sample_scheduler.set_timesteps(
+                sampling_steps, device=self.device, shift=shift)
+            timesteps = sample_scheduler.timesteps
+        elif sample_solver == 'dpm++':
+            sample_scheduler = FlowDPMSolverMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+            timesteps, _ = retrieve_timesteps(
+                sample_scheduler,
+                device=self.device,
+                sigmas=sampling_sigmas)
+        else:
+            raise NotImplementedError("Unsupported solver.")
+
+        # sample videos
+        latent = noise
+
+        freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
+
+        arg_c = {
+            'context': [context[0]],
+            'clip_fea': clip_context,
+            'seq_len': max_seq_len,
+            'y': [y],
+            'freqs' : freqs,
+            'pipeline' : self
+        }
+
+        arg_null = {
+            'context': context_null,
+            'clip_fea': clip_context,
+            'seq_len': max_seq_len,
+            'y': [y],
+            'freqs' : freqs,
+            'pipeline' : self
+        }
+
+        if offload_model:
+            torch.cuda.empty_cache()
+
+        # self.model.to(self.device)
+        if callback != None:
+            callback(-1, None)
+
+        self._interrupt = False
+        for i, t in enumerate(tqdm(timesteps)):
+            latent_model_input = [latent.to(self.device)]
+            timestep = [t]
+
+            timestep = torch.stack(timestep).to(self.device)
+
+            noise_pred_cond = self.model(
+                latent_model_input, t=timestep, **arg_c)[0]
+            if self._interrupt:
+                return None
+            if offload_model:
+                torch.cuda.empty_cache()
+            noise_pred_uncond = self.model(
+                latent_model_input, t=timestep, **arg_null)[0]
+            if self._interrupt:
+                return None
+            del latent_model_input
+            if offload_model:
+                torch.cuda.empty_cache()
+            noise_pred = noise_pred_uncond + guide_scale * (
+                noise_pred_cond - noise_pred_uncond)
+            del noise_pred_uncond
+
+            latent = latent.to(
+                torch.device('cpu') if offload_model else self.device)
+
+            temp_x0 = sample_scheduler.step(
+                noise_pred.unsqueeze(0),
+                t,
+                latent.unsqueeze(0),
+                return_dict=False,
+                generator=seed_g)[0]
+            latent = temp_x0.squeeze(0)
+            del temp_x0
+            del timestep
+
+            if callback is not None:
+                callback(i, latent)
+
+
+        x0 = [latent.to(self.device)]
+
+        if offload_model:
+            self.model.cpu()
+            torch.cuda.empty_cache()
+
+        if self.rank == 0:
+            videos = self.vae.decode(x0)
 
         del noise, latent
         del sample_scheduler
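Both pipelines now report progress through the optional callback argument: it is called once as callback(-1, None) just before the denoising loop starts, then as callback(i, latent) (callback(i, latents) in the text-to-video path) after each step. A minimal compatible callback, with the step count supplied by the caller; the helper name is illustrative:

```python
def make_progress_callback(total_steps: int):
    """Progress reporter matching the callback protocol used by WanI2V / WanT2V generate()."""
    def callback(step_idx, latents):
        if step_idx < 0:
            print(f"Denoising started ({total_steps} steps)")
        else:
            print(f"Step {step_idx + 1}/{total_steps} done")
    return callback

# e.g. samples = wan_model.generate(..., callback=make_progress_callback(sampling_steps), enable_RIFLEx=True)
```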
wan/modules/model.py CHANGED

@@ -6,6 +6,8 @@ import torch.cuda.amp as amp
 import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
+import numpy as np
+from typing import Union, Optional
 
 from .attention import pay_attention
 

@@ -25,7 +27,49 @@ def sinusoidal_embedding_1d(dim, position):
     return x
 
 
-…
+
+
+def identify_k(b: float, d: int, N: int):
+    """
+    This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.
+
+    Args:
+        b (`float`): The base frequency for RoPE.
+        d (`int`): Dimension of the frequency tensor.
+        N (`int`): The first observed repetition frame in latent space.
+    Returns:
+        k (`int`): The index of the intrinsic frequency component.
+        N_k (`int`): The period of the intrinsic frequency component in latent space.
+    Example:
+        In HunyuanVideo, b=256 and d=16; the repetition occurs at approximately 8s (N=48 in latent space).
+        k, N_k = identify_k(b=256, d=16, N=48)
+        In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
+    """
+
+    # Compute the period of each frequency in RoPE according to Eq.(4)
+    periods = []
+    for j in range(1, d // 2 + 1):
+        theta_j = 1.0 / (b ** (2 * (j - 1) / d))
+        N_j = round(2 * torch.pi / theta_j)
+        periods.append(N_j)
+
+    # Identify the intrinsic frequency whose period is closest to N (see Eq.(7))
+    diffs = [abs(N_j - N) for N_j in periods]
+    k = diffs.index(min(diffs)) + 1
+    N_k = periods[k - 1]
+    return k, N_k
+
+
+def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
+    assert dim % 2 == 0
+    exponents = torch.arange(0, dim, 2, dtype=torch.float64).div(dim)
+    inv_theta_pow = 1.0 / torch.pow(theta, exponents)
+
+    inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
+
+    freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
+    freqs = torch.polar(torch.ones_like(freqs), freqs)
+    return freqs
+
 def rope_params(max_seq_len, dim, theta=10000):
     assert dim % 2 == 0
     freqs = torch.outer(
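A short, illustrative use of the two helpers added above, with Wan-like numbers taken from this commit: the temporal RoPE chunk has size 44 (the "#44" comments further down), theta defaults to 10000, and the commented-out probe in wan/text2video.py calls identify_k(10000, 44, 26). This assumes the functions are importable from wan.modules.model:

```python
from wan.modules.model import identify_k, rope_params_riflex

# Locate the intrinsic temporal frequency for a repetition first observed at N=26 latent frames;
# with these numbers k should come out as 4, the RIFLEx_k value the pipelines pass.
k, N_k = identify_k(b=10000, d=44, N=26)
print(f"intrinsic frequency index k={k}, period N_k={N_k} latent frames")

# Build RIFLEx-adjusted temporal RoPE frequencies for a ~10 s clip:
# 161 frames -> (161 - 1) // 4 + 1 = 41 latent frames.
freqs_t = rope_params_riflex(1024, dim=44, L_test=41, k=4)
print(freqs_t.shape)  # complex-valued tensor of shape (1024, 22)
```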
@@ -588,14 +632,6 @@ class WanModel(ModelMixin, ConfigMixin):
         self.head = Head(dim, out_dim, patch_size, eps)
 
         # buffers (don't use register_buffer otherwise dtype will be changed in to())
-        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
-        d = dim // num_heads
-        self.freqs = torch.cat([
-            rope_params(1024, d - 4 * (d // 6)),
-            rope_params(1024, 2 * (d // 6)),
-            rope_params(1024, 2 * (d // 6))
-        ],
-        dim=1)
 
         if model_type == 'i2v':
             self.img_emb = MLPProj(1280, dim)

@@ -603,6 +639,29 @@ class WanModel(ModelMixin, ConfigMixin):
         # initialize weights
         self.init_weights()
 
+
+        # self.freqs = torch.cat([
+        #     rope_params(1024, d - 4 * (d // 6)), #44
+        #     rope_params(1024, 2 * (d // 6)), #42
+        #     rope_params(1024, 2 * (d // 6)) #42
+        # ], dim=1)
+
+
+    def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None):
+        dim = self.dim
+        num_heads = self.num_heads
+        d = dim // num_heads
+        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
+
+        freqs = torch.cat([
+            rope_params_riflex(1024, dim=d - 4 * (d // 6), L_test=nb_latent_frames, k=RIFLEx_k),  # 44
+            rope_params(1024, 2 * (d // 6)),  # 42
+            rope_params(1024, 2 * (d // 6))   # 42
+        ], dim=1)
+
+        return freqs
+
     def forward(
         self,
         x,
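get_rope_freqs replaces the old precomputed self.freqs buffer: the pipelines now rebuild the frequency table per generation, so the temporal chunk can be RIFLEx-adjusted to the actual number of latent frames, and the result is threaded into forward() through the new freqs argument (see the forward changes below). A sketch of the call made by the updated pipelines; model here stands for a loaded WanModel:

```python
def build_rope_freqs(model, frame_num: int, enable_RIFLEx: bool):
    """Mirror of the call made by WanT2V.generate / WanI2V.generate in this commit."""
    nb_latent_frames = int((frame_num - 1) / 4 + 1)   # e.g. 81 frames -> 21, 161 frames -> 41
    return model.get_rope_freqs(
        nb_latent_frames=nb_latent_frames,
        RIFLEx_k=4 if enable_RIFLEx else None,
    )

# The result is then passed to every WanModel.forward call via the argument dicts, e.g.
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
```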
@@ -611,6 +670,7 @@ class WanModel(ModelMixin, ConfigMixin):
         seq_len,
         clip_fea=None,
         y=None,
+        freqs = None,
         pipeline = None,
     ):
         r"""

@@ -638,8 +698,8 @@ class WanModel(ModelMixin, ConfigMixin):
         assert clip_fea is not None and y is not None
         # params
         device = self.patch_embedding.weight.device
-        if …
-            …
+        if freqs.device != device:
+            freqs = freqs.to(device)
 
         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

@@ -683,7 +743,7 @@ class WanModel(ModelMixin, ConfigMixin):
             e=e0,
             seq_lens=seq_lens,
             grid_sizes=grid_sizes,
-            freqs=…
+            freqs=freqs,
             context=context,
             context_lens=context_lens)
 
wan/text2video.py CHANGED

@@ -128,7 +128,8 @@ class WanT2V:
         n_prompt="",
         seed=-1,
         offload_model=True,
-        callback = None
+        callback = None,
+        enable_RIFLEx = None
     ):
         r"""
         Generates video frames from text prompt using diffusion process.

@@ -209,77 +210,84 @@
         no_sync = getattr(self.model, 'no_sync', noop_no_sync)
 
         # evaluation mode
-        … (about 70 removed lines: the previous scheduler setup and denoising loop; their content is not preserved in this capture)
+
+        if sample_solver == 'unipc':
+            sample_scheduler = FlowUniPCMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sample_scheduler.set_timesteps(
+                sampling_steps, device=self.device, shift=shift)
+            timesteps = sample_scheduler.timesteps
+        elif sample_solver == 'dpm++':
+            sample_scheduler = FlowDPMSolverMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+            timesteps, _ = retrieve_timesteps(
+                sample_scheduler,
+                device=self.device,
+                sigmas=sampling_sigmas)
+        else:
+            raise NotImplementedError("Unsupported solver.")
+
+        # sample videos
+        latents = noise
+
+        # from .modules.model import identify_k
+        # for nf in range(20, 50):
+        #     k, N_k = identify_k(10000, 44, 26)
+        #     print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
+
+        freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
+
+        arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
+        arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
+
+
+        if callback != None:
+            callback(-1, None)
+        self._interrupt = False
+        for i, t in enumerate(tqdm(timesteps)):
+            latent_model_input = latents
+            timestep = [t]
+
+            timestep = torch.stack(timestep)
+
+            # self.model.to(self.device)
+            noise_pred_cond = self.model(
+                latent_model_input, t=timestep, **arg_c)[0]
+            if self._interrupt:
+                return None
+            noise_pred_uncond = self.model(
+                latent_model_input, t=timestep, **arg_null)[0]
+            if self._interrupt:
+                return None
+
+            del latent_model_input
+            noise_pred = noise_pred_uncond + guide_scale * (
+                noise_pred_cond - noise_pred_uncond)
+            del noise_pred_uncond
+
+            temp_x0 = sample_scheduler.step(
+                noise_pred.unsqueeze(0),
+                t,
+                latents[0].unsqueeze(0),
+                return_dict=False,
+                generator=seed_g)[0]
+            latents = [temp_x0.squeeze(0)]
+            del temp_x0
+
+            if callback is not None:
+                callback(i, latents)
+
+        x0 = latents
+        if offload_model:
+            self.model.cpu()
+            torch.cuda.empty_cache()
+        if self.rank == 0:
+            videos = self.vae.decode(x0)
 
 
         del noise, latents