wedyanessam committed
Commit 206dfd7 · verified · 1 parent: 5b038c7

Upload 40 files

Files changed (41)
  1. .gitattributes +4 -0
  2. FantasyTalking/.pre-commit-config.yaml +23 -0
  3. FantasyTalking/LICENSE +201 -0
  4. FantasyTalking/README.md +95 -0
  5. FantasyTalking/README_zh.md +94 -0
  6. FantasyTalking/app.py +314 -0
  7. FantasyTalking/assets/audios/woman.wav +3 -0
  8. FantasyTalking/assets/images/fig0_1_0.png +3 -0
  9. FantasyTalking/assets/images/woman.png +3 -0
  10. FantasyTalking/assets/overview.png +3 -0
  11. FantasyTalking/diffsynth/__init__.py +5 -0
  12. FantasyTalking/diffsynth/configs/__init__.py +0 -0
  13. FantasyTalking/diffsynth/configs/model_config.py +1577 -0
  14. FantasyTalking/diffsynth/data/__init__.py +1 -0
  15. FantasyTalking/diffsynth/data/video.py +188 -0
  16. FantasyTalking/diffsynth/models/__init__.py +1 -0
  17. FantasyTalking/diffsynth/models/downloader.py +124 -0
  18. FantasyTalking/diffsynth/models/model_manager.py +582 -0
  19. FantasyTalking/diffsynth/models/utils.py +217 -0
  20. FantasyTalking/diffsynth/models/wan_video_dit.py +998 -0
  21. FantasyTalking/diffsynth/models/wan_video_image_encoder.py +960 -0
  22. FantasyTalking/diffsynth/models/wan_video_text_encoder.py +289 -0
  23. FantasyTalking/diffsynth/models/wan_video_vae.py +948 -0
  24. FantasyTalking/diffsynth/pipelines/__init__.py +1 -0
  25. FantasyTalking/diffsynth/pipelines/base.py +173 -0
  26. FantasyTalking/diffsynth/pipelines/wan_video.py +389 -0
  27. FantasyTalking/diffsynth/prompters/__init__.py +1 -0
  28. FantasyTalking/diffsynth/prompters/base_prompter.py +69 -0
  29. FantasyTalking/diffsynth/prompters/wan_prompter.py +114 -0
  30. FantasyTalking/diffsynth/schedulers/__init__.py +3 -0
  31. FantasyTalking/diffsynth/schedulers/continuous_ode.py +61 -0
  32. FantasyTalking/diffsynth/schedulers/ddim.py +138 -0
  33. FantasyTalking/diffsynth/schedulers/flow_match.py +97 -0
  34. FantasyTalking/diffsynth/vram_management/__init__.py +1 -0
  35. FantasyTalking/diffsynth/vram_management/layers.py +177 -0
  36. FantasyTalking/infer.py +236 -0
  37. FantasyTalking/infer.sh +11 -0
  38. FantasyTalking/infer_24G.sh +12 -0
  39. FantasyTalking/model.py +228 -0
  40. FantasyTalking/requirements.txt +14 -0
  41. FantasyTalking/utils.py +52 -0
.gitattributes CHANGED
@@ -35,3 +35,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  TTS/IMG_6935.wav filter=lfs diff=lfs merge=lfs -text
37
  TTS_X/IMG_6935.wav filter=lfs diff=lfs merge=lfs -text
38
+ FantasyTalking/assets/audios/woman.wav filter=lfs diff=lfs merge=lfs -text
39
+ FantasyTalking/assets/images/fig0_1_0.png filter=lfs diff=lfs merge=lfs -text
40
+ FantasyTalking/assets/images/woman.png filter=lfs diff=lfs merge=lfs -text
41
+ FantasyTalking/assets/overview.png filter=lfs diff=lfs merge=lfs -text
FantasyTalking/.pre-commit-config.yaml ADDED
@@ -0,0 +1,23 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.4.0
4
+ hooks:
5
+ - id: check-added-large-files
6
+ - id: check-yaml
7
+ - id: end-of-file-fixer
8
+ - id: trailing-whitespace
9
+
10
+ - repo: https://github.com/psf/black
11
+ rev: 23.10.0
12
+ hooks:
13
+ - id: black
14
+
15
+ - repo: https://github.com/pycqa/flake8
16
+ rev: 6.1.0
17
+ hooks:
18
+ - id: flake8
19
+
20
+ - repo: https://github.com/pycqa/isort
21
+ rev: 5.12.0
22
+ hooks:
23
+ - id: isort
FantasyTalking/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright AMAP, Alibaba Group
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
FantasyTalking/README.md ADDED
@@ -0,0 +1,95 @@
1
+ [中文阅读](./README_zh.md)
2
+ # FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis
3
+
4
+ [![Home Page](https://img.shields.io/badge/Project-FantasyTalking-blue.svg)](https://fantasy-amap.github.io/fantasy-talking/)
5
+ [![arXiv](https://img.shields.io/badge/Arxiv-2504.04842-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2504.04842)
6
+ [![hf_paper](https://img.shields.io/badge/🤗-FantasyTalking-red.svg)](https://huggingface.co/acvlab/FantasyTalking)
7
+
8
+ ## 🔥 Latest News!!
9
+ * April 29, 2025: Our work has been merged into [ComfyUI-Wan](https://github.com/kijai/ComfyUI-WanVideoWrapper)! Thanks to [kijai](https://github.com/kijai) for the update 👏!
10
+ * April 28, 2025: We released the inference code and model weights for the audio condition.
11
+
12
+
13
+ ## Quickstart
14
+ ### 🛠️Installation
15
+
16
+ Clone the repo:
17
+
18
+ ```
19
+ git clone https://github.com/Fantasy-AMAP/fantasy-talking.git
20
+ cd fantasy-talking
21
+ ```
22
+
23
+ Install dependencies:
24
+ ```
25
+ # Ensure torch >= 2.0.0
26
+ pip install -r requirements.txt
27
+ # Optionally install flash_attn to accelerate attention computation
28
+ pip install flash_attn
29
+ ```
30
+
31
+ ### 🧱Model Download
32
+ | Models | Download Link | Notes |
33
+ | --------------|-------------------------------------------------------------------------------|-------------------------------|
34
+ | Wan2.1-I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Base model
35
+ | Wav2Vec | 🤗 [Huggingface](https://huggingface.co/facebook/wav2vec2-base-960h) 🤖 [ModelScope](https://modelscope.cn/models/AI-ModelScope/wav2vec2-base-960h) | Audio encoder
36
+ | FantasyTalking model | 🤗 [Huggingface](https://huggingface.co/acvlab/FantasyTalking/) 🤖 [ModelScope](https://www.modelscope.cn/models/amap_cvlab/FantasyTalking/) | Our audio condition weights
37
+
38
+ Download models using huggingface-cli:
39
+ ``` sh
40
+ pip install "huggingface_hub[cli]"
41
+ huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P --local-dir ./models/Wan2.1-I2V-14B-720P
42
+ huggingface-cli download facebook/wav2vec2-base-960h --local-dir ./models/wav2vec2-base-960h
43
+ huggingface-cli download acvlab/FantasyTalking fantasytalking_model.ckpt --local-dir ./models
44
+ ```
45
+
46
+ Download models using modelscope-cli:
47
+ ``` sh
48
+ pip install modelscope
49
+ modelscope download Wan-AI/Wan2.1-I2V-14B-720P --local_dir ./models/Wan2.1-I2V-14B-720P
50
+ modelscope download AI-ModelScope/wav2vec2-base-960h --local_dir ./models/wav2vec2-base-960h
51
+ modelscope download amap_cvlab/FantasyTalking fantasytalking_model.ckpt --local_dir ./models
52
+ ```
53
+
54
+ ### 🔑 Inference
55
+ ``` sh
56
+ python infer.py --image_path ./assets/images/woman.png --audio_path ./assets/audios/woman.wav
57
+ ```
58
+ You can control the character's behavior through the prompt. **The recommended range for the prompt and audio CFG scales is 3-7. Increase the audio CFG scale for more consistent lip-sync.**
59
+ ``` sh
60
+ python infer.py --image_path ./assets/images/woman.png --audio_path ./assets/audios/woman.wav --prompt "The person is speaking enthusiastically, with their hands continuously waving." --prompt_cfg_scale 5.0 --audio_cfg_scale 5.0
61
+ ```
62
+
63
+ The table below details speed and memory usage, measured on a single A100 (512x512, 81 frames).
64
+
65
+ |`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|
66
+ |-|-|-|-|
67
+ |torch.bfloat16|None (unlimited)|15.5 s/it|40 GB|
68
+ |torch.bfloat16|7*10**9 (7B)|32.8 s/it|20 GB|
69
+ |torch.bfloat16|0|42.6 s/it|5 GB|
70
+
71
+ ### Gradio Demo
72
+ We host an [online demo](https://huggingface.co/spaces/acvlab/FantasyTalking) on Hugging Face.
73
+ To run the Gradio demo locally:
74
+ ``` sh
75
+ pip install gradio spaces
76
+ python app.py
77
+ ```
78
+
79
+ ## 🧩 Community Works
80
+ We ❤️ contributions from the open-source community! If your work has improved FantasyTalking, please let us know,
81
+ or e-mail us directly at [[email protected]](mailto://[email protected]). We are happy to reference your project for everyone's convenience.
82
+
83
+ ## 🔗Citation
84
+ If you find this repository useful, please consider giving it a star ⭐ and a citation:
85
+ ```
86
+ @article{wang2025fantasytalking,
87
+ title={FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis},
88
+ author={Wang, Mengchao and Wang, Qiang and Jiang, Fan and Fan, Yaqi and Zhang, Yunpeng and Qi, Yonggang and Zhao, Kun and Xu, Mu},
89
+ journal={arXiv preprint arXiv:2504.04842},
90
+ year={2025}
91
+ }
92
+ ```
93
+
94
+ ## Acknowledgments
95
+ Thanks to [Wan2.1](https://github.com/Wan-Video/Wan2.1), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), and [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) for open-sourcing their models and code, which provided valuable references and support for this project. Their contributions to the open-source community are truly appreciated.
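
The VRAM table in the README above is keyed by `num_persistent_param_in_dit`. Below is a minimal, hedged sketch of how that knob is typically applied when building the pipeline, assuming the DiffSynth-style `ModelManager`/`WanVideoPipeline` API vendored under `FantasyTalking/diffsynth/`; the exact loading code lives in `infer.py`, which is not reproduced here, and the model file paths are placeholders.

```python
# Hedged sketch (not the repo's infer.py): trading speed for VRAM via
# num_persistent_param_in_dit, using the DiffSynth-style API vendored in this repo.
import torch
from diffsynth import ModelManager, WanVideoPipeline  # assumed to be exported

model_manager = ModelManager(device="cpu")  # stage weights on CPU first
model_manager.load_models(
    ["./models/Wan2.1-I2V-14B-720P"],  # hypothetical path to the downloaded weights
    torch_dtype=torch.bfloat16,
)
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")

# None      -> all DiT parameters stay on GPU (~40 GB, fastest)
# 7 * 10**9 -> roughly 7B parameters kept resident (~20 GB)
# 0         -> full offloading (~5 GB, slowest)
pipe.enable_vram_management(num_persistent_param_in_dit=7 * 10**9)
```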
FantasyTalking/README_zh.md ADDED
@@ -0,0 +1,94 @@
1
+ [中文阅读](./README_zh.md)
2
+ # FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis
3
+
4
+ [![Home Page](https://img.shields.io/badge/Project-FantasyTalking-blue.svg)](https://fantasy-amap.github.io/fantasy-talking/)
5
+ [![arXiv](https://img.shields.io/badge/Arxiv-2504.04842-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2504.04842)
6
+ [![hf_paper](https://img.shields.io/badge/🤗-FantasyTalking-red.svg)](https://huggingface.co/acvlab/FantasyTalking)
7
+
8
+ ## 🔥 Latest News!!
9
+ * 2025年4月29日: 我们的工作被加入到[ComfyUI-Wan](https://github.com/kijai/ComfyUI-WanVideoWrapper) ! 感谢 [kijai](https://github.com/kijai) 更新 👏!
10
+ * 2025年4月28日: 开源了音频条件下的推理代码和模型权重。
11
+
12
+
13
+ ## Quickstart
14
+ ### 🛠️ Installation and Dependencies
15
+
16
+ First, clone the git repo:
17
+
18
+ ```
19
+ git clone https://github.com/Fantasy-AMAP/fantasy-talking.git
20
+ cd fantasy-talking
21
+ ```
22
+
23
+ Install dependencies:
24
+ ```
25
+ # Ensure torch >= 2.0.0
26
+ pip install -r requirements.txt
27
+ # Optionally install flash_attn to accelerate attention computation
28
+ pip install flash_attn
29
+ ```
30
+
31
+ ### 🧱 Model Download
32
+ | Models | Download Link | Notes |
33
+ | --------------|-------------------------------------------------------------------------------|-------------------------------|
34
+ | Wan2.1-I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Base model
35
+ | Wav2Vec | 🤗 [Huggingface](https://huggingface.co/facebook/wav2vec2-base-960h) 🤖 [ModelScope](https://modelscope.cn/models/AI-ModelScope/wav2vec2-base-960h) | Audio encoder
36
+ | FantasyTalking model | 🤗 [Huggingface](https://huggingface.co/acvlab/FantasyTalking/) 🤖 [ModelScope](https://www.modelscope.cn/models/amap_cvlab/FantasyTalking/) | Our audio condition weights
37
+
38
+ Download models using huggingface-cli:
39
+ ``` sh
40
+ pip install "huggingface_hub[cli]"
41
+ huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P --local-dir ./models/Wan2.1-I2V-14B-720P
42
+ huggingface-cli download facebook/wav2vec2-base-960h --local-dir ./models/wav2vec2-base-960h
43
+ huggingface-cli download acvlab/FantasyTalking fantasytalking_model.ckpt --local-dir ./models
44
+ ```
45
+
46
+ Download models using modelscope-cli:
47
+ ``` sh
48
+ pip install modelscope
49
+ modelscope download Wan-AI/Wan2.1-I2V-14B-720P --local_dir ./models/Wan2.1-I2V-14B-720P
50
+ modelscope download AI-ModelScope/wav2vec2-base-960h --local_dir ./models/wav2vec2-base-960h
51
+ modelscope download amap_cvlab/FantasyTalking fantasytalking_model.ckpt --local_dir ./models
52
+ ```
53
+
54
+ ### 🔑 Inference
55
+ ``` sh
56
+ python infer.py --image_path ./assets/images/woman.png --audio_path ./assets/audios/woman.wav
57
+ ```
58
+ You can control the character's behavior through the prompt. **The recommended range for the prompt and audio CFG scales is 3-7. Increase the audio CFG scale for more consistent lip-sync.**
59
+ ``` sh
60
+ python infer.py --image_path ./assets/images/woman.png --audio_path ./assets/audios/woman.wav --prompt "The person is speaking enthusiastically, with their hands continuously waving." --prompt_cfg_scale 5.0 --audio_cfg_scale 5.0
61
+ ```
62
+
63
+ The table below details speed and memory usage, measured on a single A100 (512x512, 81 frames).
64
+ |`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|
65
+ |-|-|-|-|
66
+ |torch.bfloat16|None (unlimited)|15.5 s/it|40 GB|
67
+ |torch.bfloat16|7*10**9 (7B)|32.8 s/it|20 GB|
68
+ |torch.bfloat16|0|42.6 s/it|5 GB|
69
+
70
+ ### Gradio Demo
71
+ We host an [online demo](https://huggingface.co/spaces/acvlab/FantasyTalking) on Hugging Face.
72
+
73
+ To run the Gradio demo locally:
74
+ ``` sh
75
+ pip install gradio spaces
76
+ python app.py
77
+ ```
78
+
79
+ ## 🧩 Community Works
80
+ We ❤️ contributions from the open-source community! If your work has improved FantasyTalking, please let us know.
81
+
82
+ ## 🔗Citation
83
+ If you find this repository useful, please consider giving it a star ⭐ and a citation:
84
+ ```
85
+ @article{wang2025fantasytalking,
86
+ title={FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis},
87
+ author={Wang, Mengchao and Wang, Qiang and Jiang, Fan and Fan, Yaqi and Zhang, Yunpeng and Qi, Yonggang and Zhao, Kun and Xu, Mu},
88
+ journal={arXiv preprint arXiv:2504.04842},
89
+ year={2025}
90
+ }
91
+ ```
92
+
93
+ ## Acknowledgments
94
+ Thanks to [Wan2.1](https://github.com/Wan-Video/Wan2.1), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), and [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) for open-sourcing their models and code, which provided valuable references and support for this project. Their contributions to the open-source community are truly appreciated.
FantasyTalking/app.py ADDED
@@ -0,0 +1,314 @@
1
+ # Copyright Alibaba Inc. All Rights Reserved.
2
+
3
+ import argparse
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+
7
+ import gradio as gr
8
+ import librosa
9
+
10
+ from infer import load_models, main
11
+
12
+ pipe, fantasytalking, wav2vec_processor, wav2vec = None, None, None, None
13
+
14
+
15
+ def generate_video(
16
+ image_path,
17
+ audio_path,
18
+ prompt,
19
+ prompt_cfg_scale,
20
+ audio_cfg_scale,
21
+ audio_weight,
22
+ image_size,
23
+ max_num_frames,
24
+ inference_steps,
25
+ seed,
26
+ ):
27
+ # Create the temp directory if it doesn't exist
28
+ output_dir = Path("./output")
29
+ output_dir.mkdir(parents=True, exist_ok=True)
30
+
31
+ # Convert paths to absolute Path objects and normalize them
32
+ print(image_path)
33
+ image_path = Path(image_path).absolute().as_posix()
34
+ audio_path = Path(audio_path).absolute().as_posix()
35
+
36
+ # Parse the arguments
37
+
38
+ args = create_args(
39
+ image_path=image_path,
40
+ audio_path=audio_path,
41
+ prompt=prompt,
42
+ output_dir=str(output_dir),
43
+ audio_weight=audio_weight,
44
+ prompt_cfg_scale=prompt_cfg_scale,
45
+ audio_cfg_scale=audio_cfg_scale,
46
+ image_size=image_size,
47
+ max_num_frames=max_num_frames,
48
+ inference_steps=inference_steps,
49
+ seed=seed,
50
+ )
51
+
52
+ try:
53
+ global pipe, fantasytalking, wav2vec_processor, wav2vec
54
+ if pipe is None:
55
+ pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
56
+ output_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
57
+ return output_path # Ensure the output path is returned
58
+ except Exception as e:
59
+ print(f"Error during processing: {str(e)}")
60
+ raise gr.Error(f"Error during processing: {str(e)}")
61
+
62
+
63
+ def create_args(
64
+ image_path: str,
65
+ audio_path: str,
66
+ prompt: str,
67
+ output_dir: str,
68
+ audio_weight: float,
69
+ prompt_cfg_scale: float,
70
+ audio_cfg_scale: float,
71
+ image_size: int,
72
+ max_num_frames: int,
73
+ inference_steps: int,
74
+ seed: int,
75
+ ) -> argparse.Namespace:
76
+ parser = argparse.ArgumentParser()
77
+ parser.add_argument(
78
+ "--wan_model_dir",
79
+ type=str,
80
+ default="./models/Wan2.1-I2V-14B-720P",
81
+ required=False,
82
+ help="The dir of the Wan I2V 14B model.",
83
+ )
84
+ parser.add_argument(
85
+ "--fantasytalking_model_path",
86
+ type=str,
87
+ default="./models/fantasytalking_model.ckpt",
88
+ required=False,
89
+ help="The .ckpt path of fantasytalking model.",
90
+ )
91
+ parser.add_argument(
92
+ "--wav2vec_model_dir",
93
+ type=str,
94
+ default="./models/wav2vec2-base-960h",
95
+ required=False,
96
+ help="The dir of wav2vec model.",
97
+ )
98
+ parser.add_argument(
99
+ "--image_path",
100
+ type=str,
101
+ default="./assets/images/woman.png",
102
+ required=False,
103
+ help="The path of the image.",
104
+ )
105
+ parser.add_argument(
106
+ "--audio_path",
107
+ type=str,
108
+ default="./assets/audios/woman.wav",
109
+ required=False,
110
+ help="The path of the audio.",
111
+ )
112
+ parser.add_argument(
113
+ "--prompt",
114
+ type=str,
115
+ default="A woman is talking.",
116
+ required=False,
117
+ help="prompt.",
118
+ )
119
+ parser.add_argument(
120
+ "--output_dir",
121
+ type=str,
122
+ default="./output",
123
+ help="Dir to save the video.",
124
+ )
125
+ parser.add_argument(
126
+ "--image_size",
127
+ type=int,
128
+ default=512,
129
+ help="The image will be resized proportionally to this size.",
130
+ )
131
+ parser.add_argument(
132
+ "--audio_scale",
133
+ type=float,
134
+ default=1.0,
135
+ help="Image width.",
136
+ )
137
+ parser.add_argument(
138
+ "--prompt_cfg_scale",
139
+ type=float,
140
+ default=5.0,
141
+ required=False,
142
+ help="prompt cfg scale",
143
+ )
144
+ parser.add_argument(
145
+ "--audio_cfg_scale",
146
+ type=float,
147
+ default=5.0,
148
+ required=False,
149
+ help="audio cfg scale",
150
+ )
151
+ parser.add_argument(
152
+ "--max_num_frames",
153
+ type=int,
154
+ default=81,
155
+ required=False,
156
+ help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
157
+ )
158
+ parser.add_argument(
159
+ "--inference_steps",
160
+ type=int,
161
+ default=20,
162
+ required=False,
163
+ )
164
+ parser.add_argument(
165
+ "--fps",
166
+ type=int,
167
+ default=23,
168
+ required=False,
169
+ )
170
+ parser.add_argument(
171
+ "--num_persistent_param_in_dit",
172
+ type=int,
173
+ default=None,
174
+ required=False,
175
+ help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required",
176
+ )
177
+ parser.add_argument(
178
+ "--seed",
179
+ type=int,
180
+ default=1111,
181
+ required=False,
182
+ )
183
+ args = parser.parse_args(
184
+ [
185
+ "--image_path",
186
+ image_path,
187
+ "--audio_path",
188
+ audio_path,
189
+ "--prompt",
190
+ prompt,
191
+ "--output_dir",
192
+ output_dir,
193
+ "--image_size",
194
+ str(image_size),
195
+ "--audio_scale",
196
+ str(audio_weight),
197
+ "--prompt_cfg_scale",
198
+ str(prompt_cfg_scale),
199
+ "--audio_cfg_scale",
200
+ str(audio_cfg_scale),
201
+ "--max_num_frames",
202
+ str(max_num_frames),
203
+ "--inference_steps",
204
+ str(inference_steps),
205
+ "--seed",
206
+ str(seed),
207
+ ]
208
+ )
209
+ print(args)
210
+ return args
211
+
212
+
213
+ # Create Gradio interface
214
+ with gr.Blocks(title="FantasyTalking Video Generation") as demo:
215
+ gr.Markdown(
216
+ """
217
+ # FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis
218
+
219
+ <div align="center">
220
+ <strong> Mengchao Wang1* Qiang Wang1* Fan Jiang1†
221
+ Yaqi Fan2 Yunpeng Zhang1,2 YongGang Qi2‡
222
+ Kun Zhao1, Mu Xu1 </strong>
223
+ </div>
224
+
225
+ <div align="center">
226
+ <strong>1AMAP, Alibaba Group  2Beijing University of Posts and Telecommunications</strong>
227
+ </div>
228
+
229
+ <div style="display:flex;justify-content:center;column-gap:4px;">
230
+ <a href="https://github.com/Fantasy-AMAP/fantasy-talking">
231
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
232
+ </a>
233
+ <a href="https://arxiv.org/abs/2504.04842">
234
+ <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
235
+ </a>
236
+ </div>
237
+ """
238
+ )
239
+
240
+ with gr.Row():
241
+ with gr.Column():
242
+ image_input = gr.Image(label="Input Image", type="filepath")
243
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
244
+ prompt_input = gr.Text(label="Input Prompt")
245
+ with gr.Row():
246
+ prompt_cfg_scale = gr.Slider(
247
+ minimum=1.0,
248
+ maximum=9.0,
249
+ value=5.0,
250
+ step=0.5,
251
+ label="Prompt CFG Scale",
252
+ )
253
+ audio_cfg_scale = gr.Slider(
254
+ minimum=1.0,
255
+ maximum=9.0,
256
+ value=5.0,
257
+ step=0.5,
258
+ label="Audio CFG Scale",
259
+ )
260
+ audio_weight = gr.Slider(
261
+ minimum=0.1,
262
+ maximum=3.0,
263
+ value=1.0,
264
+ step=0.1,
265
+ label="Audio Weight",
266
+ )
267
+ with gr.Row():
268
+ image_size = gr.Number(
269
+ value=512, label="Width/Height Maxsize", precision=0
270
+ )
271
+ max_num_frames = gr.Number(
272
+ value=81, label="The Maximum Frames", precision=0
273
+ )
274
+ inference_steps = gr.Slider(
275
+ minimum=1, maximum=50, value=20, step=1, label="Inference Steps"
276
+ )
277
+
278
+ with gr.Row():
279
+ seed = gr.Number(value=1247, label="Random Seed", precision=0)
280
+
281
+ process_btn = gr.Button("Generate Video")
282
+
283
+ with gr.Column():
284
+ video_output = gr.Video(label="Output Video")
285
+
286
+ gr.Examples(
287
+ examples=[
288
+ [
289
+ "assets/images/woman.png",
290
+ "assets/audios/woman.wav",
291
+ ],
292
+ ],
293
+ inputs=[image_input, audio_input],
294
+ )
295
+
296
+ process_btn.click(
297
+ fn=generate_video,
298
+ inputs=[
299
+ image_input,
300
+ audio_input,
301
+ prompt_input,
302
+ prompt_cfg_scale,
303
+ audio_cfg_scale,
304
+ audio_weight,
305
+ image_size,
306
+ max_num_frames,
307
+ inference_steps,
308
+ seed,
309
+ ],
310
+ outputs=video_output,
311
+ )
312
+
313
+ if __name__ == "__main__":
314
+ demo.launch(inbrowser=True, share=True)
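
app.py builds an `argparse.Namespace` by feeding a constructed argv list into the same parser that `infer.py` uses, so the Gradio callback and the CLI share one configuration path. The following is a minimal, self-contained sketch of that pattern with hypothetical flag names, not the repo's exact parser:

```python
# Minimal sketch of the create_args() pattern used in app.py: reuse one
# argparse parser for both CLI and programmatic (UI callback) invocation.
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--audio_path", type=str, required=True)
    parser.add_argument("--prompt", type=str, default="A woman is talking.")
    parser.add_argument("--audio_cfg_scale", type=float, default=5.0)
    return parser


def args_from_ui(image_path: str, audio_path: str, prompt: str, audio_cfg: float):
    # Same trick as create_args(): serialize UI values into an argv list.
    argv = [
        "--image_path", image_path,
        "--audio_path", audio_path,
        "--prompt", prompt,
        "--audio_cfg_scale", str(audio_cfg),
    ]
    return build_parser().parse_args(argv)


if __name__ == "__main__":
    args = args_from_ui("img.png", "voice.wav", "Speaking calmly.", 5.0)
    print(args)
```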
FantasyTalking/assets/audios/woman.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e08584293621824d039c264132d90b654bede740f67d9384979544e3e2abfacc
3
+ size 1765454
FantasyTalking/assets/images/fig0_1_0.png ADDED

Git LFS Details

  • SHA256: b7eb9cfe91e7be5d175afa8d6c464b0e64638813271062ee0429370bf757a555
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
FantasyTalking/assets/images/woman.png ADDED

Git LFS Details

  • SHA256: add373b3b48fa76ac760f60da302bcf402bfbb77eccecae6b87b861f7d0825de
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
FantasyTalking/assets/overview.png ADDED

Git LFS Details

  • SHA256: b7eb9cfe91e7be5d175afa8d6c464b0e64638813271062ee0429370bf757a555
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
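
The .wav file above is stored as a Git LFS pointer (version, oid, size), and the image entries show the same metadata in their "Git LFS Details" blocks. A small illustrative sketch of reading such a pointer, for anyone inspecting the repo without `git lfs` installed:

```python
# Parse a Git LFS pointer file of the form shown above:
#   version https://git-lfs.github.com/spec/v1
#   oid sha256:<hex digest>
#   size <bytes>
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    algo, _, digest = fields["oid"].partition(":")
    return {"version": fields["version"], "hash_algo": algo,
            "digest": digest, "size_bytes": int(fields["size"])}


pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:e08584293621824d039c264132d90b654bede740f67d9384979544e3e2abfacc
size 1765454"""
print(parse_lfs_pointer(pointer))  # the ~1.7 MB woman.wav referenced by its sha256
```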
FantasyTalking/diffsynth/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .data import *
2
+ from .models import *
3
+ from .pipelines import *
4
+ from .prompters import *
5
+ from .schedulers import *
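
This package `__init__.py` re-exports everything from the data, models, pipelines, prompters, and schedulers subpackages, so downstream code such as `infer.py` can import from the top-level `diffsynth` namespace. A quick way to see what actually gets exported (a sketch; the concrete names depend on each subpackage's `__all__`):

```python
# Inspect what the aggregated diffsynth namespace exposes after the star imports.
import diffsynth

public_names = sorted(n for n in dir(diffsynth) if not n.startswith("_"))
print(public_names)  # e.g. model, pipeline, and scheduler classes, depending on __all__
```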
FantasyTalking/diffsynth/configs/__init__.py ADDED
File without changes
FantasyTalking/diffsynth/configs/model_config.py ADDED
@@ -0,0 +1,1577 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.wan_video_dit import WanModel
4
+ from ..models.wan_video_image_encoder import WanImageEncoder
5
+ from ..models.wan_video_text_encoder import WanTextEncoder
6
+ from ..models.wan_video_vae import WanVideoVAE
7
+
8
+ model_loader_configs = [
9
+ # These configs are provided for detecting model type automatically.
10
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
11
+ (
12
+ None,
13
+ "9269f8db9040a9d860eaca435be61814",
14
+ ["wan_video_dit"],
15
+ [WanModel],
16
+ "civitai",
17
+ ),
18
+ (
19
+ None,
20
+ "aafcfd9672c3a2456dc46e1cb6e52c70",
21
+ ["wan_video_dit"],
22
+ [WanModel],
23
+ "civitai",
24
+ ),
25
+ (
26
+ None,
27
+ "6bfcfb3b342cb286ce886889d519a77e",
28
+ ["wan_video_dit"],
29
+ [WanModel],
30
+ "civitai",
31
+ ),
32
+ (
33
+ None,
34
+ "9c8818c2cbea55eca56c7b447df170da",
35
+ ["wan_video_text_encoder"],
36
+ [WanTextEncoder],
37
+ "civitai",
38
+ ),
39
+ (
40
+ None,
41
+ "5941c53e207d62f20f9025686193c40b",
42
+ ["wan_video_image_encoder"],
43
+ [WanImageEncoder],
44
+ "civitai",
45
+ ),
46
+ (
47
+ None,
48
+ "1378ea763357eea97acdef78e65d6d96",
49
+ ["wan_video_vae"],
50
+ [WanVideoVAE],
51
+ "civitai",
52
+ ),
53
+ (
54
+ None,
55
+ "ccc42284ea13e1ad04693284c7a09be6",
56
+ ["wan_video_vae"],
57
+ [WanVideoVAE],
58
+ "civitai",
59
+ ),
60
+ ]
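
Per the comment at the top of `model_loader_configs`, each entry pairs a 32-character hash of a checkpoint's state-dict keys (with or without shapes) with the model classes to instantiate, so the loader can detect the model type from the weights alone. A hedged sketch of that matching step follows; the real logic lives in `model_manager.py` (not shown in this excerpt), and the exact key serialization used to compute the hashes is an assumption here.

```python
# Hedged sketch of how a state dict might be matched against model_loader_configs:
# hash the sorted parameter names and compare against the stored key hashes.
import hashlib


def hash_state_dict_keys(state_dict: dict, with_shape: bool = True) -> str:
    # Assumed serialization: "name:shape" pairs joined by commas, then MD5.
    keys = []
    for name, tensor in sorted(state_dict.items()):
        keys.append(f"{name}:{tuple(tensor.shape)}" if with_shape else name)
    return hashlib.md5(",".join(keys).encode("utf-8")).hexdigest()


def detect_model(state_dict: dict, configs) -> list:
    digest = hash_state_dict_keys(state_dict)
    matches = []
    for keys_hash, keys_hash_with_shape, names, classes, resource in configs:
        if digest in (keys_hash, keys_hash_with_shape):
            matches.append((names, classes, resource))
    return matches
```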
61
+ huggingface_model_loader_configs = [
62
+ # These configs are provided for detecting model type automatically.
63
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
64
+ (
65
+ "ChatGLMModel",
66
+ "diffsynth.models.kolors_text_encoder",
67
+ "kolors_text_encoder",
68
+ None,
69
+ ),
70
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
71
+ (
72
+ "BloomForCausalLM",
73
+ "transformers.models.bloom.modeling_bloom",
74
+ "beautiful_prompt",
75
+ None,
76
+ ),
77
+ (
78
+ "Qwen2ForCausalLM",
79
+ "transformers.models.qwen2.modeling_qwen2",
80
+ "qwen_prompt",
81
+ None,
82
+ ),
83
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
84
+ (
85
+ "T5EncoderModel",
86
+ "diffsynth.models.flux_text_encoder",
87
+ "flux_text_encoder_2",
88
+ "FluxTextEncoder2",
89
+ ),
90
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
91
+ (
92
+ "SiglipModel",
93
+ "transformers.models.siglip.modeling_siglip",
94
+ "siglip_vision_model",
95
+ "SiglipVisionModel",
96
+ ),
97
+ (
98
+ "LlamaForCausalLM",
99
+ "diffsynth.models.hunyuan_video_text_encoder",
100
+ "hunyuan_video_text_encoder_2",
101
+ "HunyuanVideoLLMEncoder",
102
+ ),
103
+ (
104
+ "Step1Model",
105
+ "diffsynth.models.stepvideo_text_encoder",
106
+ "stepvideo_text_encoder_2",
107
+ "STEP1TextEncoder",
108
+ ),
109
+ ]
110
+ patch_model_loader_configs = [
111
+ # These configs are provided for detecting model type automatically.
112
+ # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
113
+ # ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
114
+ ]
115
+
116
+ preset_models_on_huggingface = {
117
+ "HunyuanDiT": [
118
+ (
119
+ "Tencent-Hunyuan/HunyuanDiT",
120
+ "t2i/clip_text_encoder/pytorch_model.bin",
121
+ "models/HunyuanDiT/t2i/clip_text_encoder",
122
+ ),
123
+ (
124
+ "Tencent-Hunyuan/HunyuanDiT",
125
+ "t2i/mt5/pytorch_model.bin",
126
+ "models/HunyuanDiT/t2i/mt5",
127
+ ),
128
+ (
129
+ "Tencent-Hunyuan/HunyuanDiT",
130
+ "t2i/model/pytorch_model_ema.pt",
131
+ "models/HunyuanDiT/t2i/model",
132
+ ),
133
+ (
134
+ "Tencent-Hunyuan/HunyuanDiT",
135
+ "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin",
136
+ "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix",
137
+ ),
138
+ ],
139
+ "stable-video-diffusion-img2vid-xt": [
140
+ (
141
+ "stabilityai/stable-video-diffusion-img2vid-xt",
142
+ "svd_xt.safetensors",
143
+ "models/stable_video_diffusion",
144
+ ),
145
+ ],
146
+ "ExVideo-SVD-128f-v1": [
147
+ (
148
+ "ECNU-CILab/ExVideo-SVD-128f-v1",
149
+ "model.fp16.safetensors",
150
+ "models/stable_video_diffusion",
151
+ ),
152
+ ],
153
+ # Stable Diffusion
154
+ "StableDiffusion_v15": [
155
+ (
156
+ "benjamin-paine/stable-diffusion-v1-5",
157
+ "v1-5-pruned-emaonly.safetensors",
158
+ "models/stable_diffusion",
159
+ ),
160
+ ],
161
+ "DreamShaper_8": [
162
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
163
+ ],
164
+ # Textual Inversion
165
+ "TextualInversion_VeryBadImageNegative_v1.3": [
166
+ (
167
+ "gemasai/verybadimagenegative_v1.3",
168
+ "verybadimagenegative_v1.3.pt",
169
+ "models/textual_inversion",
170
+ ),
171
+ ],
172
+ # Stable Diffusion XL
173
+ "StableDiffusionXL_v1": [
174
+ (
175
+ "stabilityai/stable-diffusion-xl-base-1.0",
176
+ "sd_xl_base_1.0.safetensors",
177
+ "models/stable_diffusion_xl",
178
+ ),
179
+ ],
180
+ "BluePencilXL_v200": [
181
+ (
182
+ "frankjoshua/bluePencilXL_v200",
183
+ "bluePencilXL_v200.safetensors",
184
+ "models/stable_diffusion_xl",
185
+ ),
186
+ ],
187
+ "StableDiffusionXL_Turbo": [
188
+ (
189
+ "stabilityai/sdxl-turbo",
190
+ "sd_xl_turbo_1.0_fp16.safetensors",
191
+ "models/stable_diffusion_xl_turbo",
192
+ ),
193
+ ],
194
+ # Stable Diffusion 3
195
+ "StableDiffusion3": [
196
+ (
197
+ "stabilityai/stable-diffusion-3-medium",
198
+ "sd3_medium_incl_clips_t5xxlfp16.safetensors",
199
+ "models/stable_diffusion_3",
200
+ ),
201
+ ],
202
+ "StableDiffusion3_without_T5": [
203
+ (
204
+ "stabilityai/stable-diffusion-3-medium",
205
+ "sd3_medium_incl_clips.safetensors",
206
+ "models/stable_diffusion_3",
207
+ ),
208
+ ],
209
+ # ControlNet
210
+ "ControlNet_v11f1p_sd15_depth": [
211
+ (
212
+ "lllyasviel/ControlNet-v1-1",
213
+ "control_v11f1p_sd15_depth.pth",
214
+ "models/ControlNet",
215
+ ),
216
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
217
+ ],
218
+ "ControlNet_v11p_sd15_softedge": [
219
+ (
220
+ "lllyasviel/ControlNet-v1-1",
221
+ "control_v11p_sd15_softedge.pth",
222
+ "models/ControlNet",
223
+ ),
224
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators"),
225
+ ],
226
+ "ControlNet_v11f1e_sd15_tile": [
227
+ (
228
+ "lllyasviel/ControlNet-v1-1",
229
+ "control_v11f1e_sd15_tile.pth",
230
+ "models/ControlNet",
231
+ )
232
+ ],
233
+ "ControlNet_v11p_sd15_lineart": [
234
+ (
235
+ "lllyasviel/ControlNet-v1-1",
236
+ "control_v11p_sd15_lineart.pth",
237
+ "models/ControlNet",
238
+ ),
239
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
240
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators"),
241
+ ],
242
+ "ControlNet_union_sdxl_promax": [
243
+ (
244
+ "xinsir/controlnet-union-sdxl-1.0",
245
+ "diffusion_pytorch_model_promax.safetensors",
246
+ "models/ControlNet/controlnet_union",
247
+ ),
248
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
249
+ ],
250
+ # AnimateDiff
251
+ "AnimateDiff_v2": [
252
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
253
+ ],
254
+ "AnimateDiff_xl_beta": [
255
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
256
+ ],
257
+ # Qwen Prompt
258
+ "QwenPrompt": [
259
+ (
260
+ "Qwen/Qwen2-1.5B-Instruct",
261
+ "config.json",
262
+ "models/QwenPrompt/qwen2-1.5b-instruct",
263
+ ),
264
+ (
265
+ "Qwen/Qwen2-1.5B-Instruct",
266
+ "generation_config.json",
267
+ "models/QwenPrompt/qwen2-1.5b-instruct",
268
+ ),
269
+ (
270
+ "Qwen/Qwen2-1.5B-Instruct",
271
+ "model.safetensors",
272
+ "models/QwenPrompt/qwen2-1.5b-instruct",
273
+ ),
274
+ (
275
+ "Qwen/Qwen2-1.5B-Instruct",
276
+ "special_tokens_map.json",
277
+ "models/QwenPrompt/qwen2-1.5b-instruct",
278
+ ),
279
+ (
280
+ "Qwen/Qwen2-1.5B-Instruct",
281
+ "tokenizer.json",
282
+ "models/QwenPrompt/qwen2-1.5b-instruct",
283
+ ),
284
+ (
285
+ "Qwen/Qwen2-1.5B-Instruct",
286
+ "tokenizer_config.json",
287
+ "models/QwenPrompt/qwen2-1.5b-instruct",
288
+ ),
289
+ (
290
+ "Qwen/Qwen2-1.5B-Instruct",
291
+ "merges.txt",
292
+ "models/QwenPrompt/qwen2-1.5b-instruct",
293
+ ),
294
+ (
295
+ "Qwen/Qwen2-1.5B-Instruct",
296
+ "vocab.json",
297
+ "models/QwenPrompt/qwen2-1.5b-instruct",
298
+ ),
299
+ ],
300
+ # Beautiful Prompt
301
+ "BeautifulPrompt": [
302
+ (
303
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
304
+ "config.json",
305
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
306
+ ),
307
+ (
308
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
309
+ "generation_config.json",
310
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
311
+ ),
312
+ (
313
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
314
+ "model.safetensors",
315
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
316
+ ),
317
+ (
318
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
319
+ "special_tokens_map.json",
320
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
321
+ ),
322
+ (
323
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
324
+ "tokenizer.json",
325
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
326
+ ),
327
+ (
328
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
329
+ "tokenizer_config.json",
330
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
331
+ ),
332
+ ],
333
+ # Omost prompt
334
+ "OmostPrompt": [
335
+ (
336
+ "lllyasviel/omost-llama-3-8b-4bits",
337
+ "model-00001-of-00002.safetensors",
338
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
339
+ ),
340
+ (
341
+ "lllyasviel/omost-llama-3-8b-4bits",
342
+ "model-00002-of-00002.safetensors",
343
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
344
+ ),
345
+ (
346
+ "lllyasviel/omost-llama-3-8b-4bits",
347
+ "tokenizer.json",
348
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
349
+ ),
350
+ (
351
+ "lllyasviel/omost-llama-3-8b-4bits",
352
+ "tokenizer_config.json",
353
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
354
+ ),
355
+ (
356
+ "lllyasviel/omost-llama-3-8b-4bits",
357
+ "config.json",
358
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
359
+ ),
360
+ (
361
+ "lllyasviel/omost-llama-3-8b-4bits",
362
+ "generation_config.json",
363
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
364
+ ),
365
+ (
366
+ "lllyasviel/omost-llama-3-8b-4bits",
367
+ "model.safetensors.index.json",
368
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
369
+ ),
370
+ (
371
+ "lllyasviel/omost-llama-3-8b-4bits",
372
+ "special_tokens_map.json",
373
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
374
+ ),
375
+ ],
376
+ # Translator
377
+ "opus-mt-zh-en": [
378
+ (
379
+ "Helsinki-NLP/opus-mt-zh-en",
380
+ "config.json",
381
+ "models/translator/opus-mt-zh-en",
382
+ ),
383
+ (
384
+ "Helsinki-NLP/opus-mt-zh-en",
385
+ "generation_config.json",
386
+ "models/translator/opus-mt-zh-en",
387
+ ),
388
+ (
389
+ "Helsinki-NLP/opus-mt-zh-en",
390
+ "metadata.json",
391
+ "models/translator/opus-mt-zh-en",
392
+ ),
393
+ (
394
+ "Helsinki-NLP/opus-mt-zh-en",
395
+ "pytorch_model.bin",
396
+ "models/translator/opus-mt-zh-en",
397
+ ),
398
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
399
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
400
+ (
401
+ "Helsinki-NLP/opus-mt-zh-en",
402
+ "tokenizer_config.json",
403
+ "models/translator/opus-mt-zh-en",
404
+ ),
405
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
406
+ ],
407
+ # IP-Adapter
408
+ "IP-Adapter-SD": [
409
+ (
410
+ "h94/IP-Adapter",
411
+ "models/image_encoder/model.safetensors",
412
+ "models/IpAdapter/stable_diffusion/image_encoder",
413
+ ),
414
+ (
415
+ "h94/IP-Adapter",
416
+ "models/ip-adapter_sd15.bin",
417
+ "models/IpAdapter/stable_diffusion",
418
+ ),
419
+ ],
420
+ "IP-Adapter-SDXL": [
421
+ (
422
+ "h94/IP-Adapter",
423
+ "sdxl_models/image_encoder/model.safetensors",
424
+ "models/IpAdapter/stable_diffusion_xl/image_encoder",
425
+ ),
426
+ (
427
+ "h94/IP-Adapter",
428
+ "sdxl_models/ip-adapter_sdxl.bin",
429
+ "models/IpAdapter/stable_diffusion_xl",
430
+ ),
431
+ ],
432
+ "SDXL-vae-fp16-fix": [
433
+ (
434
+ "madebyollin/sdxl-vae-fp16-fix",
435
+ "diffusion_pytorch_model.safetensors",
436
+ "models/sdxl-vae-fp16-fix",
437
+ )
438
+ ],
439
+ # Kolors
440
+ "Kolors": [
441
+ (
442
+ "Kwai-Kolors/Kolors",
443
+ "text_encoder/config.json",
444
+ "models/kolors/Kolors/text_encoder",
445
+ ),
446
+ (
447
+ "Kwai-Kolors/Kolors",
448
+ "text_encoder/pytorch_model.bin.index.json",
449
+ "models/kolors/Kolors/text_encoder",
450
+ ),
451
+ (
452
+ "Kwai-Kolors/Kolors",
453
+ "text_encoder/pytorch_model-00001-of-00007.bin",
454
+ "models/kolors/Kolors/text_encoder",
455
+ ),
456
+ (
457
+ "Kwai-Kolors/Kolors",
458
+ "text_encoder/pytorch_model-00002-of-00007.bin",
459
+ "models/kolors/Kolors/text_encoder",
460
+ ),
461
+ (
462
+ "Kwai-Kolors/Kolors",
463
+ "text_encoder/pytorch_model-00003-of-00007.bin",
464
+ "models/kolors/Kolors/text_encoder",
465
+ ),
466
+ (
467
+ "Kwai-Kolors/Kolors",
468
+ "text_encoder/pytorch_model-00004-of-00007.bin",
469
+ "models/kolors/Kolors/text_encoder",
470
+ ),
471
+ (
472
+ "Kwai-Kolors/Kolors",
473
+ "text_encoder/pytorch_model-00005-of-00007.bin",
474
+ "models/kolors/Kolors/text_encoder",
475
+ ),
476
+ (
477
+ "Kwai-Kolors/Kolors",
478
+ "text_encoder/pytorch_model-00006-of-00007.bin",
479
+ "models/kolors/Kolors/text_encoder",
480
+ ),
481
+ (
482
+ "Kwai-Kolors/Kolors",
483
+ "text_encoder/pytorch_model-00007-of-00007.bin",
484
+ "models/kolors/Kolors/text_encoder",
485
+ ),
486
+ (
487
+ "Kwai-Kolors/Kolors",
488
+ "unet/diffusion_pytorch_model.safetensors",
489
+ "models/kolors/Kolors/unet",
490
+ ),
491
+ (
492
+ "Kwai-Kolors/Kolors",
493
+ "vae/diffusion_pytorch_model.safetensors",
494
+ "models/kolors/Kolors/vae",
495
+ ),
496
+ ],
497
+ # FLUX
498
+ "FLUX.1-dev": [
499
+ (
500
+ "black-forest-labs/FLUX.1-dev",
501
+ "text_encoder/model.safetensors",
502
+ "models/FLUX/FLUX.1-dev/text_encoder",
503
+ ),
504
+ (
505
+ "black-forest-labs/FLUX.1-dev",
506
+ "text_encoder_2/config.json",
507
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
508
+ ),
509
+ (
510
+ "black-forest-labs/FLUX.1-dev",
511
+ "text_encoder_2/model-00001-of-00002.safetensors",
512
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
513
+ ),
514
+ (
515
+ "black-forest-labs/FLUX.1-dev",
516
+ "text_encoder_2/model-00002-of-00002.safetensors",
517
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
518
+ ),
519
+ (
520
+ "black-forest-labs/FLUX.1-dev",
521
+ "text_encoder_2/model.safetensors.index.json",
522
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
523
+ ),
524
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
525
+ (
526
+ "black-forest-labs/FLUX.1-dev",
527
+ "flux1-dev.safetensors",
528
+ "models/FLUX/FLUX.1-dev",
529
+ ),
530
+ ],
531
+ "InstantX/FLUX.1-dev-IP-Adapter": {
532
+ "file_list": [
533
+ (
534
+ "InstantX/FLUX.1-dev-IP-Adapter",
535
+ "ip-adapter.bin",
536
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter",
537
+ ),
538
+ (
539
+ "google/siglip-so400m-patch14-384",
540
+ "model.safetensors",
541
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
542
+ ),
543
+ (
544
+ "google/siglip-so400m-patch14-384",
545
+ "config.json",
546
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
547
+ ),
548
+ ],
549
+ "load_path": [
550
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
551
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
552
+ ],
553
+ },
554
+ # RIFE
555
+ "RIFE": [
556
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
557
+ ],
558
+ # CogVideo
559
+ "CogVideoX-5B": [
560
+ (
561
+ "THUDM/CogVideoX-5b",
562
+ "text_encoder/config.json",
563
+ "models/CogVideo/CogVideoX-5b/text_encoder",
564
+ ),
565
+ (
566
+ "THUDM/CogVideoX-5b",
567
+ "text_encoder/model.safetensors.index.json",
568
+ "models/CogVideo/CogVideoX-5b/text_encoder",
569
+ ),
570
+ (
571
+ "THUDM/CogVideoX-5b",
572
+ "text_encoder/model-00001-of-00002.safetensors",
573
+ "models/CogVideo/CogVideoX-5b/text_encoder",
574
+ ),
575
+ (
576
+ "THUDM/CogVideoX-5b",
577
+ "text_encoder/model-00002-of-00002.safetensors",
578
+ "models/CogVideo/CogVideoX-5b/text_encoder",
579
+ ),
580
+ (
581
+ "THUDM/CogVideoX-5b",
582
+ "transformer/config.json",
583
+ "models/CogVideo/CogVideoX-5b/transformer",
584
+ ),
585
+ (
586
+ "THUDM/CogVideoX-5b",
587
+ "transformer/diffusion_pytorch_model.safetensors.index.json",
588
+ "models/CogVideo/CogVideoX-5b/transformer",
589
+ ),
590
+ (
591
+ "THUDM/CogVideoX-5b",
592
+ "transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
593
+ "models/CogVideo/CogVideoX-5b/transformer",
594
+ ),
595
+ (
596
+ "THUDM/CogVideoX-5b",
597
+ "transformer/diffusion_pytorch_model-00002-of-00002.safetensors",
598
+ "models/CogVideo/CogVideoX-5b/transformer",
599
+ ),
600
+ (
601
+ "THUDM/CogVideoX-5b",
602
+ "vae/diffusion_pytorch_model.safetensors",
603
+ "models/CogVideo/CogVideoX-5b/vae",
604
+ ),
605
+ ],
606
+ # Stable Diffusion 3.5
607
+ "StableDiffusion3.5-large": [
608
+ (
609
+ "stabilityai/stable-diffusion-3.5-large",
610
+ "sd3.5_large.safetensors",
611
+ "models/stable_diffusion_3",
612
+ ),
613
+ (
614
+ "stabilityai/stable-diffusion-3.5-large",
615
+ "text_encoders/clip_l.safetensors",
616
+ "models/stable_diffusion_3/text_encoders",
617
+ ),
618
+ (
619
+ "stabilityai/stable-diffusion-3.5-large",
620
+ "text_encoders/clip_g.safetensors",
621
+ "models/stable_diffusion_3/text_encoders",
622
+ ),
623
+ (
624
+ "stabilityai/stable-diffusion-3.5-large",
625
+ "text_encoders/t5xxl_fp16.safetensors",
626
+ "models/stable_diffusion_3/text_encoders",
627
+ ),
628
+ ],
629
+ }
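
Each list-shaped preset above maps a name to `(repo_id, filename, local_dir)` triples. A hedged sketch of how a downloader might consume one such preset with `huggingface_hub`; the repo's actual logic is in `diffsynth/models/downloader.py`, which is not shown in this excerpt.

```python
# Hedged sketch: fetch every (repo_id, filename, local_dir) triple of a preset
# with huggingface_hub. The real downloader.py may differ.
from huggingface_hub import hf_hub_download


def download_preset(preset: list[tuple[str, str, str]]) -> None:
    for repo_id, filename, local_dir in preset:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)


# Example with a single illustrative entry:
download_preset(
    [("Wan-AI/Wan2.1-I2V-14B-720P", "config.json", "models/Wan2.1-I2V-14B-720P")]
)
```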
630
+ preset_models_on_modelscope = {
631
+ # Hunyuan DiT
632
+ "HunyuanDiT": [
633
+ (
634
+ "modelscope/HunyuanDiT",
635
+ "t2i/clip_text_encoder/pytorch_model.bin",
636
+ "models/HunyuanDiT/t2i/clip_text_encoder",
637
+ ),
638
+ (
639
+ "modelscope/HunyuanDiT",
640
+ "t2i/mt5/pytorch_model.bin",
641
+ "models/HunyuanDiT/t2i/mt5",
642
+ ),
643
+ (
644
+ "modelscope/HunyuanDiT",
645
+ "t2i/model/pytorch_model_ema.pt",
646
+ "models/HunyuanDiT/t2i/model",
647
+ ),
648
+ (
649
+ "modelscope/HunyuanDiT",
650
+ "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin",
651
+ "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix",
652
+ ),
653
+ ],
654
+ # Stable Video Diffusion
655
+ "stable-video-diffusion-img2vid-xt": [
656
+ (
657
+ "AI-ModelScope/stable-video-diffusion-img2vid-xt",
658
+ "svd_xt.safetensors",
659
+ "models/stable_video_diffusion",
660
+ ),
661
+ ],
662
+ # ExVideo
663
+ "ExVideo-SVD-128f-v1": [
664
+ (
665
+ "ECNU-CILab/ExVideo-SVD-128f-v1",
666
+ "model.fp16.safetensors",
667
+ "models/stable_video_diffusion",
668
+ ),
669
+ ],
670
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
671
+ (
672
+ "ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1",
673
+ "ExVideo-CogVideoX-LoRA-129f-v1.safetensors",
674
+ "models/lora",
675
+ ),
676
+ ],
677
+ # Stable Diffusion
678
+ "StableDiffusion_v15": [
679
+ (
680
+ "AI-ModelScope/stable-diffusion-v1-5",
681
+ "v1-5-pruned-emaonly.safetensors",
682
+ "models/stable_diffusion",
683
+ ),
684
+ ],
685
+ "DreamShaper_8": [
686
+ (
687
+ "sd_lora/dreamshaper_8",
688
+ "dreamshaper_8.safetensors",
689
+ "models/stable_diffusion",
690
+ ),
691
+ ],
692
+ "AingDiffusion_v12": [
693
+ (
694
+ "sd_lora/aingdiffusion_v12",
695
+ "aingdiffusion_v12.safetensors",
696
+ "models/stable_diffusion",
697
+ ),
698
+ ],
699
+ "Flat2DAnimerge_v45Sharp": [
700
+ (
701
+ "sd_lora/Flat-2D-Animerge",
702
+ "flat2DAnimerge_v45Sharp.safetensors",
703
+ "models/stable_diffusion",
704
+ ),
705
+ ],
706
+ # Textual Inversion
707
+ "TextualInversion_VeryBadImageNegative_v1.3": [
708
+ (
709
+ "sd_lora/verybadimagenegative_v1.3",
710
+ "verybadimagenegative_v1.3.pt",
711
+ "models/textual_inversion",
712
+ ),
713
+ ],
714
+ # Stable Diffusion XL
715
+ "StableDiffusionXL_v1": [
716
+ (
717
+ "AI-ModelScope/stable-diffusion-xl-base-1.0",
718
+ "sd_xl_base_1.0.safetensors",
719
+ "models/stable_diffusion_xl",
720
+ ),
721
+ ],
722
+ "BluePencilXL_v200": [
723
+ (
724
+ "sd_lora/bluePencilXL_v200",
725
+ "bluePencilXL_v200.safetensors",
726
+ "models/stable_diffusion_xl",
727
+ ),
728
+ ],
729
+ "StableDiffusionXL_Turbo": [
730
+ (
731
+ "AI-ModelScope/sdxl-turbo",
732
+ "sd_xl_turbo_1.0_fp16.safetensors",
733
+ "models/stable_diffusion_xl_turbo",
734
+ ),
735
+ ],
736
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
737
+ (
738
+ "sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0",
739
+ "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors",
740
+ "models/lora",
741
+ ),
742
+ ],
743
+ # Stable Diffusion 3
744
+ "StableDiffusion3": [
745
+ (
746
+ "AI-ModelScope/stable-diffusion-3-medium",
747
+ "sd3_medium_incl_clips_t5xxlfp16.safetensors",
748
+ "models/stable_diffusion_3",
749
+ ),
750
+ ],
751
+ "StableDiffusion3_without_T5": [
752
+ (
753
+ "AI-ModelScope/stable-diffusion-3-medium",
754
+ "sd3_medium_incl_clips.safetensors",
755
+ "models/stable_diffusion_3",
756
+ ),
757
+ ],
758
+ # ControlNet
759
+ "ControlNet_v11f1p_sd15_depth": [
760
+ (
761
+ "AI-ModelScope/ControlNet-v1-1",
762
+ "control_v11f1p_sd15_depth.pth",
763
+ "models/ControlNet",
764
+ ),
765
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
766
+ ],
767
+ "ControlNet_v11p_sd15_softedge": [
768
+ (
769
+ "AI-ModelScope/ControlNet-v1-1",
770
+ "control_v11p_sd15_softedge.pth",
771
+ "models/ControlNet",
772
+ ),
773
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
774
+ ],
775
+ "ControlNet_v11f1e_sd15_tile": [
776
+ (
777
+ "AI-ModelScope/ControlNet-v1-1",
778
+ "control_v11f1e_sd15_tile.pth",
779
+ "models/ControlNet",
780
+ )
781
+ ],
782
+ "ControlNet_v11p_sd15_lineart": [
783
+ (
784
+ "AI-ModelScope/ControlNet-v1-1",
785
+ "control_v11p_sd15_lineart.pth",
786
+ "models/ControlNet",
787
+ ),
788
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
789
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
790
+ ],
791
+ "ControlNet_union_sdxl_promax": [
792
+ (
793
+ "AI-ModelScope/controlnet-union-sdxl-1.0",
794
+ "diffusion_pytorch_model_promax.safetensors",
795
+ "models/ControlNet/controlnet_union",
796
+ ),
797
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
798
+ ],
799
+ "Annotators:Depth": [
800
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
801
+ ],
802
+ "Annotators:Softedge": [
803
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
804
+ ],
805
+ "Annotators:Lineart": [
806
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
807
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
808
+ ],
809
+ "Annotators:Normal": [
810
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
811
+ ],
812
+ "Annotators:Openpose": [
813
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
814
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
815
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
816
+ ],
817
+ # AnimateDiff
818
+ "AnimateDiff_v2": [
819
+ (
820
+ "Shanghai_AI_Laboratory/animatediff",
821
+ "mm_sd_v15_v2.ckpt",
822
+ "models/AnimateDiff",
823
+ ),
824
+ ],
825
+ "AnimateDiff_xl_beta": [
826
+ (
827
+ "Shanghai_AI_Laboratory/animatediff",
828
+ "mm_sdxl_v10_beta.ckpt",
829
+ "models/AnimateDiff",
830
+ ),
831
+ ],
832
+ # RIFE
833
+ "RIFE": [
834
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
835
+ ],
836
+ # Qwen Prompt
837
+ "QwenPrompt": {
838
+ "file_list": [
839
+ (
840
+ "qwen/Qwen2-1.5B-Instruct",
841
+ "config.json",
842
+ "models/QwenPrompt/qwen2-1.5b-instruct",
843
+ ),
844
+ (
845
+ "qwen/Qwen2-1.5B-Instruct",
846
+ "generation_config.json",
847
+ "models/QwenPrompt/qwen2-1.5b-instruct",
848
+ ),
849
+ (
850
+ "qwen/Qwen2-1.5B-Instruct",
851
+ "model.safetensors",
852
+ "models/QwenPrompt/qwen2-1.5b-instruct",
853
+ ),
854
+ (
855
+ "qwen/Qwen2-1.5B-Instruct",
856
+ "special_tokens_map.json",
857
+ "models/QwenPrompt/qwen2-1.5b-instruct",
858
+ ),
859
+ (
860
+ "qwen/Qwen2-1.5B-Instruct",
861
+ "tokenizer.json",
862
+ "models/QwenPrompt/qwen2-1.5b-instruct",
863
+ ),
864
+ (
865
+ "qwen/Qwen2-1.5B-Instruct",
866
+ "tokenizer_config.json",
867
+ "models/QwenPrompt/qwen2-1.5b-instruct",
868
+ ),
869
+ (
870
+ "qwen/Qwen2-1.5B-Instruct",
871
+ "merges.txt",
872
+ "models/QwenPrompt/qwen2-1.5b-instruct",
873
+ ),
874
+ (
875
+ "qwen/Qwen2-1.5B-Instruct",
876
+ "vocab.json",
877
+ "models/QwenPrompt/qwen2-1.5b-instruct",
878
+ ),
879
+ ],
880
+ "load_path": [
881
+ "models/QwenPrompt/qwen2-1.5b-instruct",
882
+ ],
883
+ },
884
+ # Beautiful Prompt
885
+ "BeautifulPrompt": {
886
+ "file_list": [
887
+ (
888
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
889
+ "config.json",
890
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
891
+ ),
892
+ (
893
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
894
+ "generation_config.json",
895
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
896
+ ),
897
+ (
898
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
899
+ "model.safetensors",
900
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
901
+ ),
902
+ (
903
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
904
+ "special_tokens_map.json",
905
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
906
+ ),
907
+ (
908
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
909
+ "tokenizer.json",
910
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
911
+ ),
912
+ (
913
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
914
+ "tokenizer_config.json",
915
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
916
+ ),
917
+ ],
918
+ "load_path": [
919
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
920
+ ],
921
+ },
922
+ # Omost prompt
923
+ "OmostPrompt": {
924
+ "file_list": [
925
+ (
926
+ "Omost/omost-llama-3-8b-4bits",
927
+ "model-00001-of-00002.safetensors",
928
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
929
+ ),
930
+ (
931
+ "Omost/omost-llama-3-8b-4bits",
932
+ "model-00002-of-00002.safetensors",
933
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
934
+ ),
935
+ (
936
+ "Omost/omost-llama-3-8b-4bits",
937
+ "tokenizer.json",
938
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
939
+ ),
940
+ (
941
+ "Omost/omost-llama-3-8b-4bits",
942
+ "tokenizer_config.json",
943
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
944
+ ),
945
+ (
946
+ "Omost/omost-llama-3-8b-4bits",
947
+ "config.json",
948
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
949
+ ),
950
+ (
951
+ "Omost/omost-llama-3-8b-4bits",
952
+ "generation_config.json",
953
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
954
+ ),
955
+ (
956
+ "Omost/omost-llama-3-8b-4bits",
957
+ "model.safetensors.index.json",
958
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
959
+ ),
960
+ (
961
+ "Omost/omost-llama-3-8b-4bits",
962
+ "special_tokens_map.json",
963
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
964
+ ),
965
+ ],
966
+ "load_path": [
967
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
968
+ ],
969
+ },
970
+ # Translator
971
+ "opus-mt-zh-en": {
972
+ "file_list": [
973
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
974
+ (
975
+ "moxying/opus-mt-zh-en",
976
+ "generation_config.json",
977
+ "models/translator/opus-mt-zh-en",
978
+ ),
979
+ (
980
+ "moxying/opus-mt-zh-en",
981
+ "metadata.json",
982
+ "models/translator/opus-mt-zh-en",
983
+ ),
984
+ (
985
+ "moxying/opus-mt-zh-en",
986
+ "pytorch_model.bin",
987
+ "models/translator/opus-mt-zh-en",
988
+ ),
989
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
990
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
991
+ (
992
+ "moxying/opus-mt-zh-en",
993
+ "tokenizer_config.json",
994
+ "models/translator/opus-mt-zh-en",
995
+ ),
996
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
997
+ ],
998
+ "load_path": [
999
+ "models/translator/opus-mt-zh-en",
1000
+ ],
1001
+ },
1002
+ # IP-Adapter
1003
+ "IP-Adapter-SD": [
1004
+ (
1005
+ "AI-ModelScope/IP-Adapter",
1006
+ "models/image_encoder/model.safetensors",
1007
+ "models/IpAdapter/stable_diffusion/image_encoder",
1008
+ ),
1009
+ (
1010
+ "AI-ModelScope/IP-Adapter",
1011
+ "models/ip-adapter_sd15.bin",
1012
+ "models/IpAdapter/stable_diffusion",
1013
+ ),
1014
+ ],
1015
+ "IP-Adapter-SDXL": [
1016
+ (
1017
+ "AI-ModelScope/IP-Adapter",
1018
+ "sdxl_models/image_encoder/model.safetensors",
1019
+ "models/IpAdapter/stable_diffusion_xl/image_encoder",
1020
+ ),
1021
+ (
1022
+ "AI-ModelScope/IP-Adapter",
1023
+ "sdxl_models/ip-adapter_sdxl.bin",
1024
+ "models/IpAdapter/stable_diffusion_xl",
1025
+ ),
1026
+ ],
1027
+ # Kolors
1028
+ "Kolors": {
1029
+ "file_list": [
1030
+ (
1031
+ "Kwai-Kolors/Kolors",
1032
+ "text_encoder/config.json",
1033
+ "models/kolors/Kolors/text_encoder",
1034
+ ),
1035
+ (
1036
+ "Kwai-Kolors/Kolors",
1037
+ "text_encoder/pytorch_model.bin.index.json",
1038
+ "models/kolors/Kolors/text_encoder",
1039
+ ),
1040
+ (
1041
+ "Kwai-Kolors/Kolors",
1042
+ "text_encoder/pytorch_model-00001-of-00007.bin",
1043
+ "models/kolors/Kolors/text_encoder",
1044
+ ),
1045
+ (
1046
+ "Kwai-Kolors/Kolors",
1047
+ "text_encoder/pytorch_model-00002-of-00007.bin",
1048
+ "models/kolors/Kolors/text_encoder",
1049
+ ),
1050
+ (
1051
+ "Kwai-Kolors/Kolors",
1052
+ "text_encoder/pytorch_model-00003-of-00007.bin",
1053
+ "models/kolors/Kolors/text_encoder",
1054
+ ),
1055
+ (
1056
+ "Kwai-Kolors/Kolors",
1057
+ "text_encoder/pytorch_model-00004-of-00007.bin",
1058
+ "models/kolors/Kolors/text_encoder",
1059
+ ),
1060
+ (
1061
+ "Kwai-Kolors/Kolors",
1062
+ "text_encoder/pytorch_model-00005-of-00007.bin",
1063
+ "models/kolors/Kolors/text_encoder",
1064
+ ),
1065
+ (
1066
+ "Kwai-Kolors/Kolors",
1067
+ "text_encoder/pytorch_model-00006-of-00007.bin",
1068
+ "models/kolors/Kolors/text_encoder",
1069
+ ),
1070
+ (
1071
+ "Kwai-Kolors/Kolors",
1072
+ "text_encoder/pytorch_model-00007-of-00007.bin",
1073
+ "models/kolors/Kolors/text_encoder",
1074
+ ),
1075
+ (
1076
+ "Kwai-Kolors/Kolors",
1077
+ "unet/diffusion_pytorch_model.safetensors",
1078
+ "models/kolors/Kolors/unet",
1079
+ ),
1080
+ (
1081
+ "Kwai-Kolors/Kolors",
1082
+ "vae/diffusion_pytorch_model.safetensors",
1083
+ "models/kolors/Kolors/vae",
1084
+ ),
1085
+ ],
1086
+ "load_path": [
1087
+ "models/kolors/Kolors/text_encoder",
1088
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
1089
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
1090
+ ],
1091
+ },
1092
+ "SDXL-vae-fp16-fix": [
1093
+ (
1094
+ "AI-ModelScope/sdxl-vae-fp16-fix",
1095
+ "diffusion_pytorch_model.safetensors",
1096
+ "models/sdxl-vae-fp16-fix",
1097
+ )
1098
+ ],
1099
+ # FLUX
1100
+ "FLUX.1-dev": {
1101
+ "file_list": [
1102
+ (
1103
+ "AI-ModelScope/FLUX.1-dev",
1104
+ "text_encoder/model.safetensors",
1105
+ "models/FLUX/FLUX.1-dev/text_encoder",
1106
+ ),
1107
+ (
1108
+ "AI-ModelScope/FLUX.1-dev",
1109
+ "text_encoder_2/config.json",
1110
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1111
+ ),
1112
+ (
1113
+ "AI-ModelScope/FLUX.1-dev",
1114
+ "text_encoder_2/model-00001-of-00002.safetensors",
1115
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1116
+ ),
1117
+ (
1118
+ "AI-ModelScope/FLUX.1-dev",
1119
+ "text_encoder_2/model-00002-of-00002.safetensors",
1120
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1121
+ ),
1122
+ (
1123
+ "AI-ModelScope/FLUX.1-dev",
1124
+ "text_encoder_2/model.safetensors.index.json",
1125
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1126
+ ),
1127
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
1128
+ (
1129
+ "AI-ModelScope/FLUX.1-dev",
1130
+ "flux1-dev.safetensors",
1131
+ "models/FLUX/FLUX.1-dev",
1132
+ ),
1133
+ ],
1134
+ "load_path": [
1135
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
1136
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1137
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
1138
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
1139
+ ],
1140
+ },
1141
+ "FLUX.1-schnell": {
1142
+ "file_list": [
1143
+ (
1144
+ "AI-ModelScope/FLUX.1-dev",
1145
+ "text_encoder/model.safetensors",
1146
+ "models/FLUX/FLUX.1-dev/text_encoder",
1147
+ ),
1148
+ (
1149
+ "AI-ModelScope/FLUX.1-dev",
1150
+ "text_encoder_2/config.json",
1151
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1152
+ ),
1153
+ (
1154
+ "AI-ModelScope/FLUX.1-dev",
1155
+ "text_encoder_2/model-00001-of-00002.safetensors",
1156
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1157
+ ),
1158
+ (
1159
+ "AI-ModelScope/FLUX.1-dev",
1160
+ "text_encoder_2/model-00002-of-00002.safetensors",
1161
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1162
+ ),
1163
+ (
1164
+ "AI-ModelScope/FLUX.1-dev",
1165
+ "text_encoder_2/model.safetensors.index.json",
1166
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1167
+ ),
1168
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
1169
+ (
1170
+ "AI-ModelScope/FLUX.1-schnell",
1171
+ "flux1-schnell.safetensors",
1172
+ "models/FLUX/FLUX.1-schnell",
1173
+ ),
1174
+ ],
1175
+ "load_path": [
1176
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
1177
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1178
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
1179
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors",
1180
+ ],
1181
+ },
1182
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
1183
+ (
1184
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
1185
+ "diffusion_pytorch_model.safetensors",
1186
+ "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha",
1187
+ ),
1188
+ ],
1189
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
1190
+ (
1191
+ "jasperai/Flux.1-dev-Controlnet-Depth",
1192
+ "diffusion_pytorch_model.safetensors",
1193
+ "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth",
1194
+ ),
1195
+ ],
1196
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
1197
+ (
1198
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
1199
+ "diffusion_pytorch_model.safetensors",
1200
+ "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals",
1201
+ ),
1202
+ ],
1203
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
1204
+ (
1205
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
1206
+ "diffusion_pytorch_model.safetensors",
1207
+ "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler",
1208
+ ),
1209
+ ],
1210
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
1211
+ (
1212
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
1213
+ "diffusion_pytorch_model.safetensors",
1214
+ "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
1215
+ ),
1216
+ ],
1217
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
1218
+ (
1219
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
1220
+ "diffusion_pytorch_model.safetensors",
1221
+ "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
1222
+ ),
1223
+ ],
1224
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
1225
+ (
1226
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
1227
+ "diffusion_pytorch_model.safetensors",
1228
+ "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
1229
+ ),
1230
+ ],
1231
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
1232
+ (
1233
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
1234
+ "diffusion_pytorch_model.safetensors",
1235
+ "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
1236
+ ),
1237
+ ],
1238
+ "InstantX/FLUX.1-dev-IP-Adapter": {
1239
+ "file_list": [
1240
+ (
1241
+ "InstantX/FLUX.1-dev-IP-Adapter",
1242
+ "ip-adapter.bin",
1243
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter",
1244
+ ),
1245
+ (
1246
+ "AI-ModelScope/siglip-so400m-patch14-384",
1247
+ "model.safetensors",
1248
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
1249
+ ),
1250
+ (
1251
+ "AI-ModelScope/siglip-so400m-patch14-384",
1252
+ "config.json",
1253
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
1254
+ ),
1255
+ ],
1256
+ "load_path": [
1257
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
1258
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
1259
+ ],
1260
+ },
1261
+ # ESRGAN
1262
+ "ESRGAN_x4": [
1263
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
1264
+ ],
1265
+ # RIFE
1266
+ "RIFE": [
1267
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
1268
+ ],
1269
+ # Omnigen
1270
+ "OmniGen-v1": {
1271
+ "file_list": [
1272
+ (
1273
+ "BAAI/OmniGen-v1",
1274
+ "vae/diffusion_pytorch_model.safetensors",
1275
+ "models/OmniGen/OmniGen-v1/vae",
1276
+ ),
1277
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
1278
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
1279
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
1280
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
1281
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
1282
+ ],
1283
+ "load_path": [
1284
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
1285
+ "models/OmniGen/OmniGen-v1/model.safetensors",
1286
+ ],
1287
+ },
1288
+ # CogVideo
1289
+ "CogVideoX-5B": {
1290
+ "file_list": [
1291
+ (
1292
+ "ZhipuAI/CogVideoX-5b",
1293
+ "text_encoder/config.json",
1294
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1295
+ ),
1296
+ (
1297
+ "ZhipuAI/CogVideoX-5b",
1298
+ "text_encoder/model.safetensors.index.json",
1299
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1300
+ ),
1301
+ (
1302
+ "ZhipuAI/CogVideoX-5b",
1303
+ "text_encoder/model-00001-of-00002.safetensors",
1304
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1305
+ ),
1306
+ (
1307
+ "ZhipuAI/CogVideoX-5b",
1308
+ "text_encoder/model-00002-of-00002.safetensors",
1309
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1310
+ ),
1311
+ (
1312
+ "ZhipuAI/CogVideoX-5b",
1313
+ "transformer/config.json",
1314
+ "models/CogVideo/CogVideoX-5b/transformer",
1315
+ ),
1316
+ (
1317
+ "ZhipuAI/CogVideoX-5b",
1318
+ "transformer/diffusion_pytorch_model.safetensors.index.json",
1319
+ "models/CogVideo/CogVideoX-5b/transformer",
1320
+ ),
1321
+ (
1322
+ "ZhipuAI/CogVideoX-5b",
1323
+ "transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
1324
+ "models/CogVideo/CogVideoX-5b/transformer",
1325
+ ),
1326
+ (
1327
+ "ZhipuAI/CogVideoX-5b",
1328
+ "transformer/diffusion_pytorch_model-00002-of-00002.safetensors",
1329
+ "models/CogVideo/CogVideoX-5b/transformer",
1330
+ ),
1331
+ (
1332
+ "ZhipuAI/CogVideoX-5b",
1333
+ "vae/diffusion_pytorch_model.safetensors",
1334
+ "models/CogVideo/CogVideoX-5b/vae",
1335
+ ),
1336
+ ],
1337
+ "load_path": [
1338
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1339
+ "models/CogVideo/CogVideoX-5b/transformer",
1340
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
1341
+ ],
1342
+ },
1343
+ # Stable Diffusion 3.5
1344
+ "StableDiffusion3.5-large": [
1345
+ (
1346
+ "AI-ModelScope/stable-diffusion-3.5-large",
1347
+ "sd3.5_large.safetensors",
1348
+ "models/stable_diffusion_3",
1349
+ ),
1350
+ (
1351
+ "AI-ModelScope/stable-diffusion-3.5-large",
1352
+ "text_encoders/clip_l.safetensors",
1353
+ "models/stable_diffusion_3/text_encoders",
1354
+ ),
1355
+ (
1356
+ "AI-ModelScope/stable-diffusion-3.5-large",
1357
+ "text_encoders/clip_g.safetensors",
1358
+ "models/stable_diffusion_3/text_encoders",
1359
+ ),
1360
+ (
1361
+ "AI-ModelScope/stable-diffusion-3.5-large",
1362
+ "text_encoders/t5xxl_fp16.safetensors",
1363
+ "models/stable_diffusion_3/text_encoders",
1364
+ ),
1365
+ ],
1366
+ "StableDiffusion3.5-medium": [
1367
+ (
1368
+ "AI-ModelScope/stable-diffusion-3.5-medium",
1369
+ "sd3.5_medium.safetensors",
1370
+ "models/stable_diffusion_3",
1371
+ ),
1372
+ (
1373
+ "AI-ModelScope/stable-diffusion-3.5-large",
1374
+ "text_encoders/clip_l.safetensors",
1375
+ "models/stable_diffusion_3/text_encoders",
1376
+ ),
1377
+ (
1378
+ "AI-ModelScope/stable-diffusion-3.5-large",
1379
+ "text_encoders/clip_g.safetensors",
1380
+ "models/stable_diffusion_3/text_encoders",
1381
+ ),
1382
+ (
1383
+ "AI-ModelScope/stable-diffusion-3.5-large",
1384
+ "text_encoders/t5xxl_fp16.safetensors",
1385
+ "models/stable_diffusion_3/text_encoders",
1386
+ ),
1387
+ ],
1388
+ "StableDiffusion3.5-large-turbo": [
1389
+ (
1390
+ "AI-ModelScope/stable-diffusion-3.5-large-turbo",
1391
+ "sd3.5_large_turbo.safetensors",
1392
+ "models/stable_diffusion_3",
1393
+ ),
1394
+ (
1395
+ "AI-ModelScope/stable-diffusion-3.5-large",
1396
+ "text_encoders/clip_l.safetensors",
1397
+ "models/stable_diffusion_3/text_encoders",
1398
+ ),
1399
+ (
1400
+ "AI-ModelScope/stable-diffusion-3.5-large",
1401
+ "text_encoders/clip_g.safetensors",
1402
+ "models/stable_diffusion_3/text_encoders",
1403
+ ),
1404
+ (
1405
+ "AI-ModelScope/stable-diffusion-3.5-large",
1406
+ "text_encoders/t5xxl_fp16.safetensors",
1407
+ "models/stable_diffusion_3/text_encoders",
1408
+ ),
1409
+ ],
1410
+ "HunyuanVideo": {
1411
+ "file_list": [
1412
+ (
1413
+ "AI-ModelScope/clip-vit-large-patch14",
1414
+ "model.safetensors",
1415
+ "models/HunyuanVideo/text_encoder",
1416
+ ),
1417
+ (
1418
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1419
+ "model-00001-of-00004.safetensors",
1420
+ "models/HunyuanVideo/text_encoder_2",
1421
+ ),
1422
+ (
1423
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1424
+ "model-00002-of-00004.safetensors",
1425
+ "models/HunyuanVideo/text_encoder_2",
1426
+ ),
1427
+ (
1428
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1429
+ "model-00003-of-00004.safetensors",
1430
+ "models/HunyuanVideo/text_encoder_2",
1431
+ ),
1432
+ (
1433
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1434
+ "model-00004-of-00004.safetensors",
1435
+ "models/HunyuanVideo/text_encoder_2",
1436
+ ),
1437
+ (
1438
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1439
+ "config.json",
1440
+ "models/HunyuanVideo/text_encoder_2",
1441
+ ),
1442
+ (
1443
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1444
+ "model.safetensors.index.json",
1445
+ "models/HunyuanVideo/text_encoder_2",
1446
+ ),
1447
+ (
1448
+ "AI-ModelScope/HunyuanVideo",
1449
+ "hunyuan-video-t2v-720p/vae/pytorch_model.pt",
1450
+ "models/HunyuanVideo/vae",
1451
+ ),
1452
+ (
1453
+ "AI-ModelScope/HunyuanVideo",
1454
+ "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
1455
+ "models/HunyuanVideo/transformers",
1456
+ ),
1457
+ ],
1458
+ "load_path": [
1459
+ "models/HunyuanVideo/text_encoder/model.safetensors",
1460
+ "models/HunyuanVideo/text_encoder_2",
1461
+ "models/HunyuanVideo/vae/pytorch_model.pt",
1462
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt",
1463
+ ],
1464
+ },
1465
+ "HunyuanVideo-fp8": {
1466
+ "file_list": [
1467
+ (
1468
+ "AI-ModelScope/clip-vit-large-patch14",
1469
+ "model.safetensors",
1470
+ "models/HunyuanVideo/text_encoder",
1471
+ ),
1472
+ (
1473
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1474
+ "model-00001-of-00004.safetensors",
1475
+ "models/HunyuanVideo/text_encoder_2",
1476
+ ),
1477
+ (
1478
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1479
+ "model-00002-of-00004.safetensors",
1480
+ "models/HunyuanVideo/text_encoder_2",
1481
+ ),
1482
+ (
1483
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1484
+ "model-00003-of-00004.safetensors",
1485
+ "models/HunyuanVideo/text_encoder_2",
1486
+ ),
1487
+ (
1488
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1489
+ "model-00004-of-00004.safetensors",
1490
+ "models/HunyuanVideo/text_encoder_2",
1491
+ ),
1492
+ (
1493
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1494
+ "config.json",
1495
+ "models/HunyuanVideo/text_encoder_2",
1496
+ ),
1497
+ (
1498
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1499
+ "model.safetensors.index.json",
1500
+ "models/HunyuanVideo/text_encoder_2",
1501
+ ),
1502
+ (
1503
+ "AI-ModelScope/HunyuanVideo",
1504
+ "hunyuan-video-t2v-720p/vae/pytorch_model.pt",
1505
+ "models/HunyuanVideo/vae",
1506
+ ),
1507
+ (
1508
+ "DiffSynth-Studio/HunyuanVideo-safetensors",
1509
+ "model.fp8.safetensors",
1510
+ "models/HunyuanVideo/transformers",
1511
+ ),
1512
+ ],
1513
+ "load_path": [
1514
+ "models/HunyuanVideo/text_encoder/model.safetensors",
1515
+ "models/HunyuanVideo/text_encoder_2",
1516
+ "models/HunyuanVideo/vae/pytorch_model.pt",
1517
+ "models/HunyuanVideo/transformers/model.fp8.safetensors",
1518
+ ],
1519
+ },
1520
+ }
1521
+ Preset_model_id: TypeAlias = Literal[
1522
+ "HunyuanDiT",
1523
+ "stable-video-diffusion-img2vid-xt",
1524
+ "ExVideo-SVD-128f-v1",
1525
+ "ExVideo-CogVideoX-LoRA-129f-v1",
1526
+ "StableDiffusion_v15",
1527
+ "DreamShaper_8",
1528
+ "AingDiffusion_v12",
1529
+ "Flat2DAnimerge_v45Sharp",
1530
+ "TextualInversion_VeryBadImageNegative_v1.3",
1531
+ "StableDiffusionXL_v1",
1532
+ "BluePencilXL_v200",
1533
+ "StableDiffusionXL_Turbo",
1534
+ "ControlNet_v11f1p_sd15_depth",
1535
+ "ControlNet_v11p_sd15_softedge",
1536
+ "ControlNet_v11f1e_sd15_tile",
1537
+ "ControlNet_v11p_sd15_lineart",
1538
+ "AnimateDiff_v2",
1539
+ "AnimateDiff_xl_beta",
1540
+ "RIFE",
1541
+ "BeautifulPrompt",
1542
+ "opus-mt-zh-en",
1543
+ "IP-Adapter-SD",
1544
+ "IP-Adapter-SDXL",
1545
+ "StableDiffusion3",
1546
+ "StableDiffusion3_without_T5",
1547
+ "Kolors",
1548
+ "SDXL-vae-fp16-fix",
1549
+ "ControlNet_union_sdxl_promax",
1550
+ "FLUX.1-dev",
1551
+ "FLUX.1-schnell",
1552
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
1553
+ "jasperai/Flux.1-dev-Controlnet-Depth",
1554
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
1555
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
1556
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
1557
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
1558
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
1559
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
1560
+ "InstantX/FLUX.1-dev-IP-Adapter",
1561
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
1562
+ "QwenPrompt",
1563
+ "OmostPrompt",
1564
+ "ESRGAN_x4",
1565
+ "RIFE",
1566
+ "OmniGen-v1",
1567
+ "CogVideoX-5B",
1568
+ "Annotators:Depth",
1569
+ "Annotators:Softedge",
1570
+ "Annotators:Lineart",
1571
+ "Annotators:Normal",
1572
+ "Annotators:Openpose",
1573
+ "StableDiffusion3.5-large",
1574
+ "StableDiffusion3.5-medium",
1575
+ "HunyuanVideo",
1576
+ "HunyuanVideo-fp8",
1577
+ ]
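Each preset name above maps either to a list of (model_id, origin_file_path, local_dir) tuples or to a dict with a "file_list" of such tuples plus a "load_path" that is handed back to the model manager. A minimal sketch of walking one entry, assuming the FantasyTalking directory is on the Python path and using the "RIFE" preset purely as an illustration:

    from diffsynth.configs.model_config import preset_models_on_modelscope

    entry = preset_models_on_modelscope["RIFE"]
    # presets are either a plain list of tuples or a dict carrying "file_list"
    file_data = entry if isinstance(entry, list) else entry.get("file_list", [])
    for model_id, origin_file_path, local_dir in file_data:
        print(f"{model_id}: {origin_file_path} -> {local_dir}")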
FantasyTalking/diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .video import VideoData, save_frames, save_video
FantasyTalking/diffsynth/data/video.py ADDED
@@ -0,0 +1,188 @@
1
+ import os
2
+
3
+ import imageio
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+
8
+
9
+ class LowMemoryVideo:
10
+ def __init__(self, file_name):
11
+ self.reader = imageio.get_reader(file_name)
12
+
13
+ def __len__(self):
14
+ return self.reader.count_frames()
15
+
16
+ def __getitem__(self, item):
17
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
18
+
19
+ def __del__(self):
20
+ self.reader.close()
21
+
22
+
23
+ def split_file_name(file_name):
24
+ result = []
25
+ number = -1
26
+ for i in file_name:
27
+ if ord(i) >= ord("0") and ord(i) <= ord("9"):
28
+ if number == -1:
29
+ number = 0
30
+ number = number * 10 + ord(i) - ord("0")
31
+ else:
32
+ if number != -1:
33
+ result.append(number)
34
+ number = -1
35
+ result.append(i)
36
+ if number != -1:
37
+ result.append(number)
38
+ result = tuple(result)
39
+ return result
40
+
41
+
42
+ def search_for_images(folder):
43
+ file_list = [
44
+ i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")
45
+ ]
46
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
47
+ file_list = [i[1] for i in sorted(file_list)]
48
+ file_list = [os.path.join(folder, i) for i in file_list]
49
+ return file_list
50
+
51
+
52
+ class LowMemoryImageFolder:
53
+ def __init__(self, folder, file_list=None):
54
+ if file_list is None:
55
+ self.file_list = search_for_images(folder)
56
+ else:
57
+ self.file_list = [
58
+ os.path.join(folder, file_name) for file_name in file_list
59
+ ]
60
+
61
+ def __len__(self):
62
+ return len(self.file_list)
63
+
64
+ def __getitem__(self, item):
65
+ return Image.open(self.file_list[item]).convert("RGB")
66
+
67
+ def __del__(self):
68
+ pass
69
+
70
+
71
+ def crop_and_resize(image, height, width):
72
+ image = np.array(image)
73
+ image_height, image_width, _ = image.shape
74
+ if image_height / image_width < height / width:
75
+ croped_width = int(image_height / height * width)
76
+ left = (image_width - croped_width) // 2
77
+ image = image[:, left : left + croped_width]
78
+ image = Image.fromarray(image).resize((width, height))
79
+ else:
80
+ croped_height = int(image_width / width * height)
81
+ left = (image_height - croped_height) // 2
82
+ image = image[left : left + croped_height, :]
83
+ image = Image.fromarray(image).resize((width, height))
84
+ return image
85
+
86
+
87
+ class VideoData:
88
+ def __init__(
89
+ self, video_file=None, image_folder=None, height=None, width=None, **kwargs
90
+ ):
91
+ if video_file is not None:
92
+ self.data_type = "video"
93
+ self.data = LowMemoryVideo(video_file, **kwargs)
94
+ elif image_folder is not None:
95
+ self.data_type = "images"
96
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
97
+ else:
98
+ raise ValueError("Cannot open video or image folder")
99
+ self.length = None
100
+ self.set_shape(height, width)
101
+
102
+ def raw_data(self):
103
+ frames = []
104
+ for i in range(self.__len__()):
105
+ frames.append(self.__getitem__(i))
106
+ return frames
107
+
108
+ def set_length(self, length):
109
+ self.length = length
110
+
111
+ def set_shape(self, height, width):
112
+ self.height = height
113
+ self.width = width
114
+
115
+ def __len__(self):
116
+ if self.length is None:
117
+ return len(self.data)
118
+ else:
119
+ return self.length
120
+
121
+ def shape(self):
122
+ if self.height is not None and self.width is not None:
123
+ return self.height, self.width
124
+ else:
125
+ width, height = self.__getitem__(0).size  # frames are PIL images, which expose .size, not .shape
126
+ return height, width
127
+
128
+ def __getitem__(self, item):
129
+ frame = self.data.__getitem__(item)
130
+ width, height = frame.size
131
+ if self.height is not None and self.width is not None:
132
+ if self.height != height or self.width != width:
133
+ frame = crop_and_resize(frame, self.height, self.width)
134
+ return frame
135
+
136
+ def __del__(self):
137
+ pass
138
+
139
+ def save_images(self, folder):
140
+ os.makedirs(folder, exist_ok=True)
141
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
142
+ frame = self.__getitem__(i)
143
+ frame.save(os.path.join(folder, f"{i}.png"))
144
+
145
+
146
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
147
+ writer = imageio.get_writer(
148
+ save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
149
+ )
150
+ for frame in tqdm(frames, desc="Saving video"):
151
+ frame = np.array(frame)
152
+ writer.append_data(frame)
153
+ writer.close()
154
+
155
+
156
+ # def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
157
+ # writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=["-crf", "0", "-preset", "veryslow"])
158
+ # for frame in tqdm(frames, desc="Saving video"):
159
+ # frame = np.array(frame)
160
+ # writer.append_data(frame)
161
+ # writer.close()
162
+
163
+ # def save_video_h264(frames, save_path, fps, ffmpeg_params=None):
164
+ # import imageio.v3 as iio
165
+ # from tqdm import tqdm
166
+ # import numpy as np
167
+
168
+ # if ffmpeg_params is None:
169
+ # ffmpeg_params = ["-crf", "0", "-preset", "ultrafast"]  # lossless H.264
170
+
171
+ # writer = iio.get_writer(save_path, fps=fps, codec="libx264", ffmpeg_params=ffmpeg_params)
172
+ # for frame in tqdm(frames, desc="Saving video"):
173
+ # writer.append_data(np.array(frame))
174
+ # writer.close()
175
+
176
+
177
+ def save_frames(frames, save_path):
178
+ os.makedirs(save_path, exist_ok=True)
179
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
180
+ frame.save(os.path.join(save_path, f"{i}.png"))
181
+
182
+
183
+ if __name__ == "__main__":
184
+ frames = [
185
+ Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))
186
+ for i in range(81)
187
+ ]
188
+ save_video(frames, "haha.mp4", 23, 5)
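A minimal usage sketch for the video helpers above; the file names are placeholders, not files shipped with the repository:

    from diffsynth.data.video import VideoData, save_video

    video = VideoData(video_file="input.mp4", height=480, width=832)  # frames are cropped/resized lazily on access
    frames = [video[i] for i in range(len(video))]
    save_video(frames, "output.mp4", fps=25, quality=7)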
FantasyTalking/diffsynth/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_manager import *
FantasyTalking/diffsynth/models/downloader.py ADDED
@@ -0,0 +1,124 @@
1
+ import os
2
+ import shutil
3
+ from typing import List
4
+
5
+ from huggingface_hub import hf_hub_download
6
+ from modelscope import snapshot_download
7
+ from typing_extensions import Literal, TypeAlias
8
+
9
+ from ..configs.model_config import (Preset_model_id,
10
+ preset_models_on_huggingface,
11
+ preset_models_on_modelscope)
12
+
13
+
14
+ def download_from_modelscope(model_id, origin_file_path, local_dir):
15
+ os.makedirs(local_dir, exist_ok=True)
16
+ file_name = os.path.basename(origin_file_path)
17
+ if file_name in os.listdir(local_dir):
18
+ print(f" {file_name} has been already in {local_dir}.")
19
+ else:
20
+ print(f" Start downloading {os.path.join(local_dir, file_name)}")
21
+ snapshot_download(
22
+ model_id, allow_file_pattern=origin_file_path, local_dir=local_dir
23
+ )
24
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
25
+ target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
26
+ if downloaded_file_path != target_file_path:
27
+ shutil.move(downloaded_file_path, target_file_path)
28
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
29
+
30
+
31
+ def download_from_huggingface(model_id, origin_file_path, local_dir):
32
+ os.makedirs(local_dir, exist_ok=True)
33
+ file_name = os.path.basename(origin_file_path)
34
+ if file_name in os.listdir(local_dir):
35
+ print(f" {file_name} has been already in {local_dir}.")
36
+ else:
37
+ print(f" Start downloading {os.path.join(local_dir, file_name)}")
38
+ hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
39
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
40
+ target_file_path = os.path.join(local_dir, file_name)
41
+ if downloaded_file_path != target_file_path:
42
+ shutil.move(downloaded_file_path, target_file_path)
43
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
44
+
45
+
46
+ Preset_model_website: TypeAlias = Literal[
47
+ "HuggingFace",
48
+ "ModelScope",
49
+ ]
50
+ website_to_preset_models = {
51
+ "HuggingFace": preset_models_on_huggingface,
52
+ "ModelScope": preset_models_on_modelscope,
53
+ }
54
+ website_to_download_fn = {
55
+ "HuggingFace": download_from_huggingface,
56
+ "ModelScope": download_from_modelscope,
57
+ }
58
+
59
+
60
+ def download_customized_models(
61
+ model_id,
62
+ origin_file_path,
63
+ local_dir,
64
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
65
+ ):
66
+ downloaded_files = []
67
+ for website in downloading_priority:
68
+ # Check if the file is downloaded.
69
+ file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
70
+ if file_to_download in downloaded_files:
71
+ continue
72
+ # Download
73
+ website_to_download_fn[website](model_id, origin_file_path, local_dir)
74
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
75
+ downloaded_files.append(file_to_download)
76
+ return downloaded_files
77
+
78
+
79
+ def download_models(
80
+ model_id_list: List[Preset_model_id] = [],
81
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
82
+ ):
83
+ print(f"Downloading models: {model_id_list}")
84
+ downloaded_files = []
85
+ load_files = []
86
+
87
+ for model_id in model_id_list:
88
+ for website in downloading_priority:
89
+ if model_id in website_to_preset_models[website]:
90
+ # Parse model metadata
91
+ model_metadata = website_to_preset_models[website][model_id]
92
+ if isinstance(model_metadata, list):
93
+ file_data = model_metadata
94
+ else:
95
+ file_data = model_metadata.get("file_list", [])
96
+
97
+ # Try downloading the model from this website.
98
+ model_files = []
99
+ for model_id, origin_file_path, local_dir in file_data:
100
+ # Check if the file is downloaded.
101
+ file_to_download = os.path.join(
102
+ local_dir, os.path.basename(origin_file_path)
103
+ )
104
+ if file_to_download in downloaded_files:
105
+ continue
106
+ # Download
107
+ website_to_download_fn[website](
108
+ model_id, origin_file_path, local_dir
109
+ )
110
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
111
+ downloaded_files.append(file_to_download)
112
+ model_files.append(file_to_download)
113
+
114
+ # If the model is successfully downloaded, break.
115
+ if len(model_files) > 0:
116
+ if (
117
+ isinstance(model_metadata, dict)
118
+ and "load_path" in model_metadata
119
+ ):
120
+ model_files = model_metadata["load_path"]
121
+ load_files.extend(model_files)
122
+ break
123
+
124
+ return load_files
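A minimal sketch of driving the downloader with a preset name; the preset must exist in the tables in model_config.py, and ModelScope is queried before Hugging Face by default:

    from diffsynth.models.downloader import download_models

    load_files = download_models(["RIFE"])
    print(load_files)  # file paths (or the preset's "load_path" entries) to hand to the model manager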
FantasyTalking/diffsynth/models/model_manager.py ADDED
@@ -0,0 +1,582 @@
1
+ import importlib
2
+ import json
3
+ import os
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+ from ..configs.model_config import (huggingface_model_loader_configs,
9
+ model_loader_configs,
10
+ patch_model_loader_configs)
11
+ from .downloader import (Preset_model_id, Preset_model_website,
12
+ download_customized_models, download_models)
13
+ from .utils import (hash_state_dict_keys, init_weights_on_device,
14
+ load_state_dict, split_state_dict_with_prefix)
15
+
16
+
17
+ def load_model_from_single_file(
18
+ state_dict, model_names, model_classes, model_resource, torch_dtype, device
19
+ ):
20
+ loaded_model_names, loaded_models = [], []
21
+ for model_name, model_class in zip(model_names, model_classes):
22
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
23
+ state_dict_converter = model_class.state_dict_converter()
24
+ if model_resource == "civitai":
25
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
26
+ elif model_resource == "diffusers":
27
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
28
+ if isinstance(state_dict_results, tuple):
29
+ model_state_dict, extra_kwargs = state_dict_results
30
+ print(
31
+ f" This model is initialized with extra kwargs: {extra_kwargs}"
32
+ )
33
+ else:
34
+ model_state_dict, extra_kwargs = state_dict_results, {}
35
+ torch_dtype = (
36
+ torch.float32
37
+ if extra_kwargs.get("upcast_to_float32", False)
38
+ else torch_dtype
39
+ )
40
+ with init_weights_on_device():
41
+ model = model_class(**extra_kwargs)
42
+ if hasattr(model, "eval"):
43
+ model = model.eval()
44
+ model.load_state_dict(model_state_dict, assign=True)
45
+ model = model.to(dtype=torch_dtype, device=device)
46
+ loaded_model_names.append(model_name)
47
+ loaded_models.append(model)
48
+ return loaded_model_names, loaded_models
49
+
50
+
51
+ def load_model_from_huggingface_folder(
52
+ file_path, model_names, model_classes, torch_dtype, device
53
+ ):
54
+ loaded_model_names, loaded_models = [], []
55
+ for model_name, model_class in zip(model_names, model_classes):
56
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
57
+ model = model_class.from_pretrained(
58
+ file_path, torch_dtype=torch_dtype
59
+ ).eval()
60
+ else:
61
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
62
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
63
+ model = model.half()
64
+ try:
65
+ model = model.to(device=device)
66
+ except:
67
+ pass
68
+ loaded_model_names.append(model_name)
69
+ loaded_models.append(model)
70
+ return loaded_model_names, loaded_models
71
+
72
+
73
+ def load_single_patch_model_from_single_file(
74
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device
75
+ ):
76
+ print(
77
+ f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}"
78
+ )
79
+ base_state_dict = base_model.state_dict()
80
+ base_model.to("cpu")
81
+ del base_model
82
+ model = model_class(**extra_kwargs)
83
+ model.load_state_dict(base_state_dict, strict=False)
84
+ model.load_state_dict(state_dict, strict=False)
85
+ model.to(dtype=torch_dtype, device=device)
86
+ return model
87
+
88
+
89
+ def load_patch_model_from_single_file(
90
+ state_dict,
91
+ model_names,
92
+ model_classes,
93
+ extra_kwargs,
94
+ model_manager,
95
+ torch_dtype,
96
+ device,
97
+ ):
98
+ loaded_model_names, loaded_models = [], []
99
+ for model_name, model_class in zip(model_names, model_classes):
100
+ while True:
101
+ for model_id in range(len(model_manager.model)):
102
+ base_model_name = model_manager.model_name[model_id]
103
+ if base_model_name == model_name:
104
+ base_model_path = model_manager.model_path[model_id]
105
+ base_model = model_manager.model[model_id]
106
+ print(
107
+ f" Adding patch model to {base_model_name} ({base_model_path})"
108
+ )
109
+ patched_model = load_single_patch_model_from_single_file(
110
+ state_dict,
111
+ model_name,
112
+ model_class,
113
+ base_model,
114
+ extra_kwargs,
115
+ torch_dtype,
116
+ device,
117
+ )
118
+ loaded_model_names.append(base_model_name)
119
+ loaded_models.append(patched_model)
120
+ model_manager.model.pop(model_id)
121
+ model_manager.model_path.pop(model_id)
122
+ model_manager.model_name.pop(model_id)
123
+ break
124
+ else:
125
+ break
126
+ return loaded_model_names, loaded_models
127
+
128
+
129
+ class ModelDetectorTemplate:
130
+ def __init__(self):
131
+ pass
132
+
133
+ def match(self, file_path="", state_dict={}):
134
+ return False
135
+
136
+ def load(
137
+ self,
138
+ file_path="",
139
+ state_dict={},
140
+ device="cuda",
141
+ torch_dtype=torch.float16,
142
+ **kwargs,
143
+ ):
144
+ return [], []
145
+
146
+
147
+ class ModelDetectorFromSingleFile:
148
+ def __init__(self, model_loader_configs=[]):
149
+ self.keys_hash_with_shape_dict = {}
150
+ self.keys_hash_dict = {}
151
+ for metadata in model_loader_configs:
152
+ self.add_model_metadata(*metadata)
153
+
154
+ def add_model_metadata(
155
+ self,
156
+ keys_hash,
157
+ keys_hash_with_shape,
158
+ model_names,
159
+ model_classes,
160
+ model_resource,
161
+ ):
162
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
163
+ model_names,
164
+ model_classes,
165
+ model_resource,
166
+ )
167
+ if keys_hash is not None:
168
+ self.keys_hash_dict[keys_hash] = (
169
+ model_names,
170
+ model_classes,
171
+ model_resource,
172
+ )
173
+
174
+ def match(self, file_path="", state_dict={}):
175
+ if isinstance(file_path, str) and os.path.isdir(file_path):
176
+ return False
177
+ if len(state_dict) == 0:
178
+ state_dict = load_state_dict(file_path)
179
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
180
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
181
+ return True
182
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
183
+ if keys_hash in self.keys_hash_dict:
184
+ return True
185
+ return False
186
+
187
+ def load(
188
+ self,
189
+ file_path="",
190
+ state_dict={},
191
+ device="cuda",
192
+ torch_dtype=torch.float16,
193
+ **kwargs,
194
+ ):
195
+ if len(state_dict) == 0:
196
+ state_dict = load_state_dict(file_path)
197
+
198
+ # Load models with strict matching
199
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
200
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
201
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[
202
+ keys_hash_with_shape
203
+ ]
204
+ loaded_model_names, loaded_models = load_model_from_single_file(
205
+ state_dict,
206
+ model_names,
207
+ model_classes,
208
+ model_resource,
209
+ torch_dtype,
210
+ device,
211
+ )
212
+ return loaded_model_names, loaded_models
213
+
214
+ # Load models without strict matching
215
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
216
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
217
+ if keys_hash in self.keys_hash_dict:
218
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
219
+ loaded_model_names, loaded_models = load_model_from_single_file(
220
+ state_dict,
221
+ model_names,
222
+ model_classes,
223
+ model_resource,
224
+ torch_dtype,
225
+ device,
226
+ )
227
+ return loaded_model_names, loaded_models
228
+
229
+ return [], []  # no key hash matched, so nothing was loaded
230
+
231
+
232
+ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
233
+ def __init__(self, model_loader_configs=[]):
234
+ super().__init__(model_loader_configs)
235
+
236
+ def match(self, file_path="", state_dict={}):
237
+ if isinstance(file_path, str) and os.path.isdir(file_path):
238
+ return False
239
+ if len(state_dict) == 0:
240
+ state_dict = load_state_dict(file_path)
241
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
242
+ for sub_state_dict in splited_state_dict:
243
+ if super().match(file_path, sub_state_dict):
244
+ return True
245
+ return False
246
+
247
+ def load(
248
+ self,
249
+ file_path="",
250
+ state_dict={},
251
+ device="cuda",
252
+ torch_dtype=torch.float16,
253
+ **kwargs,
254
+ ):
255
+ # Split the state_dict and load from each component
256
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
257
+ valid_state_dict = {}
258
+ for sub_state_dict in splited_state_dict:
259
+ if super().match(file_path, sub_state_dict):
260
+ valid_state_dict.update(sub_state_dict)
261
+ if super().match(file_path, valid_state_dict):
262
+ loaded_model_names, loaded_models = super().load(
263
+ file_path, valid_state_dict, device, torch_dtype
264
+ )
265
+ else:
266
+ loaded_model_names, loaded_models = [], []
267
+ for sub_state_dict in splited_state_dict:
268
+ if super().match(file_path, sub_state_dict):
269
+ loaded_model_names_, loaded_models_ = super().load(
270
+ file_path, sub_state_dict, device, torch_dtype
271
+ )
272
+ loaded_model_names += loaded_model_names_
273
+ loaded_models += loaded_models_
274
+ return loaded_model_names, loaded_models
275
+
276
+
277
+ class ModelDetectorFromHuggingfaceFolder:
278
+ def __init__(self, model_loader_configs=[]):
279
+ self.architecture_dict = {}
280
+ for metadata in model_loader_configs:
281
+ self.add_model_metadata(*metadata)
282
+
283
+ def add_model_metadata(
284
+ self, architecture, huggingface_lib, model_name, redirected_architecture
285
+ ):
286
+ self.architecture_dict[architecture] = (
287
+ huggingface_lib,
288
+ model_name,
289
+ redirected_architecture,
290
+ )
291
+
292
+ def match(self, file_path="", state_dict={}):
293
+ if not isinstance(file_path, str) or os.path.isfile(file_path):
294
+ return False
295
+ file_list = os.listdir(file_path)
296
+ if "config.json" not in file_list:
297
+ return False
298
+ with open(os.path.join(file_path, "config.json"), "r") as f:
299
+ config = json.load(f)
300
+ if "architectures" not in config and "_class_name" not in config:
301
+ return False
302
+ return True
303
+
304
+ def load(
305
+ self,
306
+ file_path="",
307
+ state_dict={},
308
+ device="cuda",
309
+ torch_dtype=torch.float16,
310
+ **kwargs,
311
+ ):
312
+ with open(os.path.join(file_path, "config.json"), "r") as f:
313
+ config = json.load(f)
314
+ loaded_model_names, loaded_models = [], []
315
+ architectures = (
316
+ config["architectures"]
317
+ if "architectures" in config
318
+ else [config["_class_name"]]
319
+ )
320
+ for architecture in architectures:
321
+ (
322
+ huggingface_lib,
323
+ model_name,
324
+ redirected_architecture,
325
+ ) = self.architecture_dict[architecture]
326
+ if redirected_architecture is not None:
327
+ architecture = redirected_architecture
328
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(
329
+ architecture
330
+ )
331
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(
332
+ file_path, [model_name], [model_class], torch_dtype, device
333
+ )
334
+ loaded_model_names += loaded_model_names_
335
+ loaded_models += loaded_models_
336
+ return loaded_model_names, loaded_models
337
+
338
+
339
+ class ModelDetectorFromPatchedSingleFile:
340
+ def __init__(self, model_loader_configs=[]):
341
+ self.keys_hash_with_shape_dict = {}
342
+ for metadata in model_loader_configs:
343
+ self.add_model_metadata(*metadata)
344
+
345
+ def add_model_metadata(
346
+ self, keys_hash_with_shape, model_name, model_class, extra_kwargs
347
+ ):
348
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
349
+ model_name,
350
+ model_class,
351
+ extra_kwargs,
352
+ )
353
+
354
+ def match(self, file_path="", state_dict={}):
355
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
356
+ return False
357
+ if len(state_dict) == 0:
358
+ state_dict = load_state_dict(file_path)
359
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
360
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
361
+ return True
362
+ return False
363
+
364
+ def load(
365
+ self,
366
+ file_path="",
367
+ state_dict={},
368
+ device="cuda",
369
+ torch_dtype=torch.float16,
370
+ model_manager=None,
371
+ **kwargs,
372
+ ):
373
+ if len(state_dict) == 0:
374
+ state_dict = load_state_dict(file_path)
375
+
376
+ # Load models with strict matching
377
+ loaded_model_names, loaded_models = [], []
378
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
379
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
380
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[
381
+ keys_hash_with_shape
382
+ ]
383
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
384
+ state_dict,
385
+ model_names,
386
+ model_classes,
387
+ extra_kwargs,
388
+ model_manager,
389
+ torch_dtype,
390
+ device,
391
+ )
392
+ loaded_model_names += loaded_model_names_
393
+ loaded_models += loaded_models_
394
+ return loaded_model_names, loaded_models
395
+
396
+
397
+ class ModelManager:
398
+ def __init__(
399
+ self,
400
+ torch_dtype=torch.float16,
401
+ device="cuda",
402
+ model_id_list: List[Preset_model_id] = [],
403
+ downloading_priority: List[Preset_model_website] = [
404
+ "ModelScope",
405
+ "HuggingFace",
406
+ ],
407
+ file_path_list: List[str] = [],
408
+ ):
409
+ self.torch_dtype = torch_dtype
410
+ self.device = device
411
+ self.model = []
412
+ self.model_path = []
413
+ self.model_name = []
414
+ downloaded_files = (
415
+ download_models(model_id_list, downloading_priority)
416
+ if len(model_id_list) > 0
417
+ else []
418
+ )
419
+ self.model_detector = [
420
+ ModelDetectorFromSingleFile(model_loader_configs),
421
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
422
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
423
+ ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
424
+ ]
425
+ self.load_models(downloaded_files + file_path_list)
426
+
427
+ def load_model_from_single_file(
428
+ self,
429
+ file_path="",
430
+ state_dict={},
431
+ model_names=[],
432
+ model_classes=[],
433
+ model_resource=None,
434
+ ):
435
+ print(f"Loading models from file: {file_path}")
436
+ if len(state_dict) == 0:
437
+ state_dict = load_state_dict(file_path)
438
+ model_names, models = load_model_from_single_file(
439
+ state_dict,
440
+ model_names,
441
+ model_classes,
442
+ model_resource,
443
+ self.torch_dtype,
444
+ self.device,
445
+ )
446
+ for model_name, model in zip(model_names, models):
447
+ self.model.append(model)
448
+ self.model_path.append(file_path)
449
+ self.model_name.append(model_name)
450
+ print(f" The following models are loaded: {model_names}.")
451
+
452
+ def load_model_from_huggingface_folder(
453
+ self, file_path="", model_names=[], model_classes=[]
454
+ ):
455
+ print(f"Loading models from folder: {file_path}")
456
+ model_names, models = load_model_from_huggingface_folder(
457
+ file_path, model_names, model_classes, self.torch_dtype, self.device
458
+ )
459
+ for model_name, model in zip(model_names, models):
460
+ self.model.append(model)
461
+ self.model_path.append(file_path)
462
+ self.model_name.append(model_name)
463
+ print(f" The following models are loaded: {model_names}.")
464
+
465
+ def load_patch_model_from_single_file(
466
+ self,
467
+ file_path="",
468
+ state_dict={},
469
+ model_names=[],
470
+ model_classes=[],
471
+ extra_kwargs={},
472
+ ):
473
+ print(f"Loading patch models from file: {file_path}")
474
+ model_names, models = load_patch_model_from_single_file(
475
+ state_dict,
476
+ model_names,
477
+ model_classes,
478
+ extra_kwargs,
479
+ self,
480
+ self.torch_dtype,
481
+ self.device,
482
+ )
483
+ for model_name, model in zip(model_names, models):
484
+ self.model.append(model)
485
+ self.model_path.append(file_path)
486
+ self.model_name.append(model_name)
487
+ print(f" The following patched models are loaded: {model_names}.")
488
+
489
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
490
+ if isinstance(file_path, list):
491
+ for file_path_ in file_path:
492
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
493
+ else:
494
+ print(f"Loading LoRA models from file: {file_path}")
495
+ if len(state_dict) == 0:
496
+ state_dict = load_state_dict(file_path)
497
+ for model_name, model, model_path in zip(
498
+ self.model_name, self.model, self.model_path
499
+ ):
500
+ for lora in get_lora_loaders():
501
+ match_results = lora.match(model, state_dict)
502
+ if match_results is not None:
503
+ print(f" Adding LoRA to {model_name} ({model_path}).")
504
+ lora_prefix, model_resource = match_results
505
+ lora.load(
506
+ model,
507
+ state_dict,
508
+ lora_prefix,
509
+ alpha=lora_alpha,
510
+ model_resource=model_resource,
511
+ )
512
+ break
513
+
514
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
515
+ print(f"Loading models from: {file_path}")
516
+ if device is None:
517
+ device = self.device
518
+ if torch_dtype is None:
519
+ torch_dtype = self.torch_dtype
520
+ if isinstance(file_path, list):
521
+ state_dict = {}
522
+ for path in file_path:
523
+ state_dict.update(load_state_dict(path))
524
+ elif os.path.isfile(file_path):
525
+ state_dict = load_state_dict(file_path)
526
+ else:
527
+ state_dict = None
528
+ for model_detector in self.model_detector:
529
+ if model_detector.match(file_path, state_dict):
530
+ model_names, models = model_detector.load(
531
+ file_path,
532
+ state_dict,
533
+ device=device,
534
+ torch_dtype=torch_dtype,
535
+ allowed_model_names=model_names,
536
+ model_manager=self,
537
+ )
538
+ for model_name, model in zip(model_names, models):
539
+ self.model.append(model)
540
+ self.model_path.append(file_path)
541
+ self.model_name.append(model_name)
542
+ print(f" The following models are loaded: {model_names}.")
543
+ break
544
+ else:
545
+ print(f" We cannot detect the model type. No models are loaded.")
546
+
547
+ def load_models(
548
+ self, file_path_list, model_names=None, device=None, torch_dtype=None
549
+ ):
550
+ for file_path in file_path_list:
551
+ self.load_model(
552
+ file_path, model_names, device=device, torch_dtype=torch_dtype
553
+ )
554
+
555
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
556
+ fetched_models = []
557
+ fetched_model_paths = []
558
+ for model, model_path, model_name_ in zip(
559
+ self.model, self.model_path, self.model_name
560
+ ):
561
+ if file_path is not None and file_path != model_path:
562
+ continue
563
+ if model_name == model_name_:
564
+ fetched_models.append(model)
565
+ fetched_model_paths.append(model_path)
566
+ if len(fetched_models) == 0:
567
+ print(f"No {model_name} models available.")
568
+ return None
569
+ if len(fetched_models) == 1:
570
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
571
+ else:
572
+ print(
573
+ f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}."
574
+ )
575
+ if require_model_path:
576
+ return fetched_models[0], fetched_model_paths[0]
577
+ else:
578
+ return fetched_models[0]
579
+
580
+ def to(self, device):
581
+ for model in self.model:
582
+ model.to(device)
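
For orientation, a minimal usage sketch of the manager class above, assuming it is exported as ModelManager from diffsynth/models/model_manager.py; the checkpoint path, the LoRA path, and the registered name "wan_video_dit" are illustrative placeholders rather than files shipped in this commit.

import torch

from diffsynth.models.model_manager import ModelManager

# Build the manager; nothing is downloaded or loaded at this point.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")

# Each file is routed through the registered detectors (single file, split file,
# Huggingface folder, patched single file) and loaded by whichever one matches.
model_manager.load_models(["models/wan/diffusion_pytorch_model.safetensors"])

# Retrieve a loaded model by its registered name; returns None (with a message) if absent.
dit = model_manager.fetch_model("wan_video_dit")

# Optionally merge a LoRA into any matching loaded model, then move everything to GPU.
model_manager.load_lora("models/lora/example_lora.safetensors", lora_alpha=1.0)
model_manager.to("cuda")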
FantasyTalking/diffsynth/models/utils.py ADDED
@@ -0,0 +1,217 @@
1
+ import hashlib
2
+ import os
3
+ from contextlib import contextmanager
4
+
5
+ import torch
6
+ from safetensors import safe_open
7
+
8
+
9
+ @contextmanager
10
+ def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
11
+ old_register_parameter = torch.nn.Module.register_parameter
12
+ if include_buffers:
13
+ old_register_buffer = torch.nn.Module.register_buffer
14
+
15
+ def register_empty_parameter(module, name, param):
16
+ old_register_parameter(module, name, param)
17
+ if param is not None:
18
+ param_cls = type(module._parameters[name])
19
+ kwargs = module._parameters[name].__dict__
20
+ kwargs["requires_grad"] = param.requires_grad
21
+ module._parameters[name] = param_cls(
22
+ module._parameters[name].to(device), **kwargs
23
+ )
24
+
25
+ def register_empty_buffer(module, name, buffer, persistent=True):
26
+ old_register_buffer(module, name, buffer, persistent=persistent)
27
+ if buffer is not None:
28
+ module._buffers[name] = module._buffers[name].to(device)
29
+
30
+ def patch_tensor_constructor(fn):
31
+ def wrapper(*args, **kwargs):
32
+ kwargs["device"] = device
33
+ return fn(*args, **kwargs)
34
+
35
+ return wrapper
36
+
37
+ if include_buffers:
38
+ tensor_constructors_to_patch = {
39
+ torch_function_name: getattr(torch, torch_function_name)
40
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
41
+ }
42
+ else:
43
+ tensor_constructors_to_patch = {}
44
+
45
+ try:
46
+ torch.nn.Module.register_parameter = register_empty_parameter
47
+ if include_buffers:
48
+ torch.nn.Module.register_buffer = register_empty_buffer
49
+ for torch_function_name in tensor_constructors_to_patch.keys():
50
+ setattr(
51
+ torch,
52
+ torch_function_name,
53
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
54
+ )
55
+ yield
56
+ finally:
57
+ torch.nn.Module.register_parameter = old_register_parameter
58
+ if include_buffers:
59
+ torch.nn.Module.register_buffer = old_register_buffer
60
+ for (
61
+ torch_function_name,
62
+ old_torch_function,
63
+ ) in tensor_constructors_to_patch.items():
64
+ setattr(torch, torch_function_name, old_torch_function)
65
+
66
+
67
+ def load_state_dict_from_folder(file_path, torch_dtype=None):
68
+ state_dict = {}
69
+ for file_name in os.listdir(file_path):
70
+ if "." in file_name and file_name.split(".")[-1] in [
71
+ "safetensors",
72
+ "bin",
73
+ "ckpt",
74
+ "pth",
75
+ "pt",
76
+ ]:
77
+ state_dict.update(
78
+ load_state_dict(
79
+ os.path.join(file_path, file_name), torch_dtype=torch_dtype
80
+ )
81
+ )
82
+ return state_dict
83
+
84
+
85
+ def load_state_dict(file_path, torch_dtype=None):
86
+ if file_path.endswith(".safetensors"):
87
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
88
+ else:
89
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
90
+
91
+
92
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
93
+ state_dict = {}
94
+ with safe_open(file_path, framework="pt", device="cpu") as f:
95
+ for k in f.keys():
96
+ state_dict[k] = f.get_tensor(k)
97
+ if torch_dtype is not None:
98
+ state_dict[k] = state_dict[k].to(torch_dtype)
99
+ return state_dict
100
+
101
+
102
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
103
+ state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
104
+ if torch_dtype is not None:
105
+ for i in state_dict:
106
+ if isinstance(state_dict[i], torch.Tensor):
107
+ state_dict[i] = state_dict[i].to(torch_dtype)
108
+ return state_dict
109
+
110
+
111
+ def search_for_embeddings(state_dict):
112
+ embeddings = []
113
+ for k in state_dict:
114
+ if isinstance(state_dict[k], torch.Tensor):
115
+ embeddings.append(state_dict[k])
116
+ elif isinstance(state_dict[k], dict):
117
+ embeddings += search_for_embeddings(state_dict[k])
118
+ return embeddings
119
+
120
+
121
+ def search_parameter(param, state_dict):
122
+ for name, param_ in state_dict.items():
123
+ if param.numel() == param_.numel():
124
+ if param.shape == param_.shape:
125
+ if torch.dist(param, param_) < 1e-3:
126
+ return name
127
+ else:
128
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
129
+ return name
130
+ return None
131
+
132
+
133
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
134
+ matched_keys = set()
135
+ with torch.no_grad():
136
+ for name in source_state_dict:
137
+ rename = search_parameter(source_state_dict[name], target_state_dict)
138
+ if rename is not None:
139
+ print(f'"{name}": "{rename}",')
140
+ matched_keys.add(rename)
141
+ elif (
142
+ split_qkv
143
+ and len(source_state_dict[name].shape) >= 1
144
+ and source_state_dict[name].shape[0] % 3 == 0
145
+ ):
146
+ length = source_state_dict[name].shape[0] // 3
147
+ rename = []
148
+ for i in range(3):
149
+ rename.append(
150
+ search_parameter(
151
+ source_state_dict[name][i * length : i * length + length],
152
+ target_state_dict,
153
+ )
154
+ )
155
+ if None not in rename:
156
+ print(f'"{name}": {rename},')
157
+ for rename_ in rename:
158
+ matched_keys.add(rename_)
159
+ for name in target_state_dict:
160
+ if name not in matched_keys:
161
+ print("Cannot find", name, target_state_dict[name].shape)
162
+
163
+
164
+ def search_for_files(folder, extensions):
165
+ files = []
166
+ if os.path.isdir(folder):
167
+ for file in sorted(os.listdir(folder)):
168
+ files += search_for_files(os.path.join(folder, file), extensions)
169
+ elif os.path.isfile(folder):
170
+ for extension in extensions:
171
+ if folder.endswith(extension):
172
+ files.append(folder)
173
+ break
174
+ return files
175
+
176
+
177
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
178
+ keys = []
179
+ for key, value in state_dict.items():
180
+ if isinstance(key, str):
181
+ if isinstance(value, torch.Tensor):
182
+ if with_shape:
183
+ shape = "_".join(map(str, list(value.shape)))
184
+ keys.append(key + ":" + shape)
185
+ keys.append(key)
186
+ elif isinstance(value, dict):
187
+ keys.append(
188
+ key
189
+ + "|"
190
+ + convert_state_dict_keys_to_single_str(
191
+ value, with_shape=with_shape
192
+ )
193
+ )
194
+ keys.sort()
195
+ keys_str = ",".join(keys)
196
+ return keys_str
197
+
198
+
199
+ def split_state_dict_with_prefix(state_dict):
200
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
201
+ prefix_dict = {}
202
+ for key in keys:
203
+ prefix = key if "." not in key else key.split(".")[0]
204
+ if prefix not in prefix_dict:
205
+ prefix_dict[prefix] = []
206
+ prefix_dict[prefix].append(key)
207
+ state_dicts = []
208
+ for prefix, keys in prefix_dict.items():
209
+ sub_state_dict = {key: state_dict[key] for key in keys}
210
+ state_dicts.append(sub_state_dict)
211
+ return state_dicts
212
+
213
+
214
+ def hash_state_dict_keys(state_dict, with_shape=True):
215
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
216
+ keys_str = keys_str.encode(encoding="UTF-8")
217
+ return hashlib.md5(keys_str).hexdigest()
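
For orientation, a short sketch of how the helpers above combine; the checkpoint path is a placeholder. hash_state_dict_keys is the fingerprint that the WanModel state-dict converter further down uses to recognise checkpoint variants.

import torch

from diffsynth.models.utils import (
    hash_state_dict_keys,
    init_weights_on_device,
    load_state_dict,
)

# Load weights from a .safetensors (or torch-pickled) checkpoint, optionally casting tensors.
state_dict = load_state_dict("models/example_checkpoint.safetensors", torch_dtype=torch.bfloat16)

# MD5 over the sorted "key:shape" strings of the state dict.
print(hash_state_dict_keys(state_dict))

# Construct modules with meta (shape-only) parameters so large models cost no host memory up front.
with init_weights_on_device():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta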
FantasyTalking/diffsynth/models/wan_video_dit.py ADDED
@@ -0,0 +1,998 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.amp as amp
5
+ import torch.nn as nn
6
+ from tqdm import tqdm
7
+
8
+ from .utils import hash_state_dict_keys
9
+
10
+ try:
11
+ import flash_attn_interface
12
+
13
+ FLASH_ATTN_3_AVAILABLE = True
14
+ except ModuleNotFoundError:
15
+ FLASH_ATTN_3_AVAILABLE = False
16
+
17
+ try:
18
+ import flash_attn
19
+
20
+ FLASH_ATTN_2_AVAILABLE = True
21
+ except ModuleNotFoundError:
22
+ FLASH_ATTN_2_AVAILABLE = False
23
+
24
+ try:
25
+ from sageattention import sageattn
26
+
27
+ SAGE_ATTN_AVAILABLE = True
28
+ except ModuleNotFoundError:
29
+ SAGE_ATTN_AVAILABLE = False
30
+
31
+ import warnings
32
+
33
+ __all__ = ["WanModel"]
34
+
35
+ def attention(
36
+ q,
37
+ k,
38
+ v,
39
+ q_lens=None,
40
+ k_lens=None,
41
+ dropout_p=0.0,
42
+ softmax_scale=None,
43
+ q_scale=None,
44
+ causal=False,
45
+ window_size=(-1, -1),
46
+ deterministic=False,
47
+ dtype=torch.bfloat16,
48
+ version=None):
49
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
50
+ x = flash_attention(
51
+ q=q,
52
+ k=k,
53
+ v=v,
54
+ q_lens=q_lens,
55
+ k_lens=k_lens,
56
+ dropout_p=dropout_p,
57
+ softmax_scale=softmax_scale,
58
+ q_scale=q_scale,
59
+ causal=causal,
60
+ window_size=window_size,
61
+ deterministic=deterministic,
62
+ dtype=dtype,
63
+ version=version,)
64
+ elif FLASH_ATTN_2_AVAILABLE:
65
+ x = flash_attention(
66
+ q=q,
67
+ k=k,
68
+ v=v,
69
+ q_lens=q_lens,
70
+ k_lens=k_lens,
71
+ dropout_p=dropout_p,
72
+ softmax_scale=softmax_scale,
73
+ q_scale=q_scale,
74
+ causal=causal,
75
+ window_size=window_size,
76
+ deterministic=deterministic,
77
+ dtype=dtype,
78
+ version=version,)
79
+ elif SAGE_ATTN_AVAILABLE:
80
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
81
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
82
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
83
+ x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
84
+ x = x.transpose(1, 2).contiguous()
85
+ else:
86
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
87
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
88
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
89
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
90
+ x = x.transpose(1, 2).contiguous()
91
+ # output
92
+ return x
93
+
94
+
95
+
96
+ def flash_attention(
97
+ q,
98
+ k,
99
+ v,
100
+ q_lens=None,
101
+ k_lens=None,
102
+ dropout_p=0.0,
103
+ softmax_scale=None,
104
+ q_scale=None,
105
+ causal=False,
106
+ window_size=(-1, -1),
107
+ deterministic=False,
108
+ dtype=torch.bfloat16,
109
+ version=None,
110
+ ):
111
+ """
112
+ q: [B, Lq, Nq, C1].
113
+ k: [B, Lk, Nk, C1].
114
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
115
+ q_lens: [B].
116
+ k_lens: [B].
117
+ dropout_p: float. Dropout probability.
118
+ softmax_scale: float. The scaling of QK^T before applying softmax.
119
+ causal: bool. Whether to apply causal attention mask.
120
+ window_size: (left, right). If not (-1, -1), apply sliding-window local attention.
121
+ deterministic: bool. If True, slightly slower and uses more memory.
122
+ dtype: torch.dtype. Used when the dtype of q/k/v is not float16/bfloat16.
123
+ """
124
+ half_dtypes = (torch.float16, torch.bfloat16)
125
+ assert dtype in half_dtypes
126
+ assert q.device.type == "cuda" and q.size(-1) <= 256
127
+
128
+ # params
129
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
130
+
131
+ def half(x):
132
+ return x if x.dtype in half_dtypes else x.to(dtype)
133
+
134
+ # preprocess query
135
+ if q_lens is None:
136
+ q = half(q.flatten(0, 1))
137
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
138
+ device=q.device, non_blocking=True
139
+ )
140
+ else:
141
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
142
+
143
+ # preprocess key, value
144
+ if k_lens is None:
145
+ k = half(k.flatten(0, 1))
146
+ v = half(v.flatten(0, 1))
147
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
148
+ device=k.device, non_blocking=True
149
+ )
150
+ else:
151
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
152
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
153
+
154
+ q = q.to(v.dtype)
155
+ k = k.to(v.dtype)
156
+
157
+ if q_scale is not None:
158
+ q = q * q_scale
159
+
160
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
161
+ warnings.warn(
162
+ "Flash attention 3 is not available, use flash attention 2 instead."
163
+ )
164
+
165
+ # apply attention
166
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
167
+ # Note: dropout_p, window_size are not supported in FA3 now.
168
+ x = flash_attn_interface.flash_attn_varlen_func(
169
+ q=q,
170
+ k=k,
171
+ v=v,
172
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
173
+ .cumsum(0, dtype=torch.int32)
174
+ .to(q.device, non_blocking=True),
175
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
176
+ .cumsum(0, dtype=torch.int32)
177
+ .to(q.device, non_blocking=True),
178
+ seqused_q=None,
179
+ seqused_k=None,
180
+ max_seqlen_q=lq,
181
+ max_seqlen_k=lk,
182
+ softmax_scale=softmax_scale,
183
+ causal=causal,
184
+ deterministic=deterministic,
185
+ )[0].unflatten(0, (b, lq))
186
+ elif FLASH_ATTN_2_AVAILABLE:
187
+ x = flash_attn.flash_attn_varlen_func(
188
+ q=q,
189
+ k=k,
190
+ v=v,
191
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
192
+ .cumsum(0, dtype=torch.int32)
193
+ .to(q.device, non_blocking=True),
194
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
195
+ .cumsum(0, dtype=torch.int32)
196
+ .to(q.device, non_blocking=True),
197
+ max_seqlen_q=lq,
198
+ max_seqlen_k=lk,
199
+ dropout_p=dropout_p,
200
+ softmax_scale=softmax_scale,
201
+ causal=causal,
202
+ window_size=window_size,
203
+ deterministic=deterministic,
204
+ ).unflatten(0, (b, lq))
205
+ elif SAGE_ATTN_AVAILABLE:
206
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
207
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
208
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
209
+ x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
210
+ x = x.transpose(1, 2).contiguous()
211
+ else:
212
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
213
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
214
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
215
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
216
+ x = x.transpose(1, 2).contiguous()
217
+
218
+ # output
219
+ return x.type(out_dtype)
220
+
221
+
222
+ def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
223
+ b, lq, lk = q.size(0), q.size(1), k.size(1)
224
+ if q_lens is None:
225
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32)
226
+ if k_lens is None:
227
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32)
228
+ attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
229
+ for i in range(b):
230
+ q_len, k_len = q_lens[i], k_lens[i]
231
+ attn_mask[i, q_len:, :] = True
232
+ attn_mask[i, :, k_len:] = True
233
+
234
+ if causal:
235
+ causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
236
+ attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask)
237
+
238
+ attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
239
+ return attn_mask
240
+
241
+
242
+ def attention(
243
+ q,
244
+ k,
245
+ v,
246
+ q_lens=None,
247
+ k_lens=None,
248
+ dropout_p=0.0,
249
+ softmax_scale=None,
250
+ q_scale=None,
251
+ causal=False,
252
+ window_size=(-1, -1),
253
+ deterministic=False,
254
+ dtype=torch.bfloat16,
255
+ fa_version=None,
256
+ ):
257
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
258
+ return flash_attention(
259
+ q=q,
260
+ k=k,
261
+ v=v,
262
+ q_lens=q_lens,
263
+ k_lens=k_lens,
264
+ dropout_p=dropout_p,
265
+ softmax_scale=softmax_scale,
266
+ q_scale=q_scale,
267
+ causal=causal,
268
+ window_size=window_size,
269
+ deterministic=deterministic,
270
+ dtype=dtype,
271
+ version=fa_version,
272
+ )
273
+ else:
274
+ if q_lens is not None or k_lens is not None:
275
+ warnings.warn(
276
+ "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
277
+ )
278
+ attn_mask = None
279
+
280
+ q = q.transpose(1, 2).to(dtype)
281
+ k = k.transpose(1, 2).to(dtype)
282
+ v = v.transpose(1, 2).to(dtype)
283
+
284
+ out = torch.nn.functional.scaled_dot_product_attention(
285
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
286
+ )
287
+
288
+ out = out.transpose(1, 2).contiguous()
289
+ return out
290
+
291
+
292
+ def sinusoidal_embedding_1d(dim, position):
293
+ # preprocess
294
+ assert dim % 2 == 0
295
+ half = dim // 2
296
+ position = position.type(torch.float64)
297
+
298
+ # calculation
299
+ sinusoid = torch.outer(
300
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half))
301
+ )
302
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
303
+ return x
304
+
305
+
306
+ @amp.autocast(enabled=False, device_type="cuda")
307
+ def rope_params(max_seq_len, dim, theta=10000):
308
+ assert dim % 2 == 0
309
+ freqs = torch.outer(
310
+ torch.arange(max_seq_len),
311
+ 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
312
+ )
313
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
314
+ return freqs
315
+
316
+
317
+ @amp.autocast(enabled=False, device_type="cuda")
318
+ def rope_apply(x, grid_sizes, freqs):
319
+ n, c = x.size(2), x.size(3) // 2
320
+
321
+ # split freqs
322
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
323
+
324
+ # loop over samples
325
+ output = []
326
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
327
+ seq_len = f * h * w
328
+
329
+ # precompute multipliers
330
+ x_i = torch.view_as_complex(
331
+ x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
332
+ )
333
+ freqs_i = torch.cat(
334
+ [
335
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
336
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
337
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
338
+ ],
339
+ dim=-1,
340
+ ).reshape(seq_len, 1, -1)
341
+
342
+ # apply rotary embedding
343
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
344
+ x_i = torch.cat([x_i, x[i, seq_len:]])
345
+
346
+ # append to collection
347
+ output.append(x_i)
348
+ return torch.stack(output).float()
349
+
350
+
351
+ class WanRMSNorm(nn.Module):
352
+ def __init__(self, dim, eps=1e-5):
353
+ super().__init__()
354
+ self.dim = dim
355
+ self.eps = eps
356
+ self.weight = nn.Parameter(torch.ones(dim))
357
+
358
+ def forward(self, x):
359
+ return self._norm(x.float()).type_as(x) * self.weight
360
+
361
+ def _norm(self, x):
362
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
363
+
364
+
365
+ class WanLayerNorm(nn.LayerNorm):
366
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
367
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
368
+
369
+ def forward(self, x):
370
+ return super().forward(x.float()).type_as(x)
371
+
372
+
373
+ class WanSelfAttention(nn.Module):
374
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
375
+ assert dim % num_heads == 0
376
+ super().__init__()
377
+ self.dim = dim
378
+ self.num_heads = num_heads
379
+ self.head_dim = dim // num_heads
380
+ self.window_size = window_size
381
+ self.qk_norm = qk_norm
382
+ self.eps = eps
383
+
384
+ # layers
385
+ self.q = nn.Linear(dim, dim)
386
+ self.k = nn.Linear(dim, dim)
387
+ self.v = nn.Linear(dim, dim)
388
+ self.o = nn.Linear(dim, dim)
389
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
390
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
391
+
392
+ def forward(self, x, seq_lens, grid_sizes, freqs):
393
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
394
+
395
+ # query, key, value function
396
+ def qkv_fn(x):
397
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
398
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
399
+ v = self.v(x).view(b, s, n, d)
400
+ return q, k, v
401
+
402
+ q, k, v = qkv_fn(x)
403
+
404
+ x = flash_attention(
405
+ q=rope_apply(q, grid_sizes, freqs),
406
+ k=rope_apply(k, grid_sizes, freqs),
407
+ v=v,
408
+ k_lens=seq_lens,
409
+ window_size=self.window_size,
410
+ )
411
+
412
+ # output
413
+ x = x.flatten(2)
414
+ x = self.o(x)
415
+ return x
416
+
417
+
418
+ class WanT2VCrossAttention(WanSelfAttention):
419
+ def forward(self, x, context, context_lens):
420
+ """
421
+ x: [B, L1, C].
422
+ context: [B, L2, C].
423
+ context_lens: [B].
424
+ """
425
+ b, n, d = x.size(0), self.num_heads, self.head_dim
426
+
427
+ # compute query, key, value
428
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
429
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
430
+ v = self.v(context).view(b, -1, n, d)
431
+
432
+ # compute attention
433
+ x = flash_attention(q, k, v, k_lens=context_lens)
434
+
435
+ # output
436
+ x = x.flatten(2)
437
+ x = self.o(x)
438
+ return x
439
+
440
+
441
+ class WanI2VCrossAttentionProcessor:
442
+ def __call__(self, attn, x, context, context_lens) -> torch.Tensor:
443
+ """
444
+ x: [B, L1, C].
445
+ context: [B, L2, C].
446
+ context_lens: [B].
447
+ """
448
+ context_img = context[:, :257]
449
+ context = context[:, 257:]
450
+ b, n, d = x.size(0), attn.num_heads, attn.head_dim
451
+
452
+ # compute query, key, value
453
+ q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
454
+ k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
455
+ v = attn.v(context).view(b, -1, n, d)
456
+ k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
457
+ v_img = attn.v_img(context_img).view(b, -1, n, d)
458
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
459
+ # compute attention
460
+ x = flash_attention(q, k, v, k_lens=context_lens)
461
+
462
+ # output
463
+ x = x.flatten(2)
464
+ img_x = img_x.flatten(2)
465
+ x = x + img_x
466
+ x = attn.o(x)
467
+ return x
468
+
469
+
470
+ class WanI2VCrossAttention(WanSelfAttention):
471
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
472
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
473
+
474
+ self.k_img = nn.Linear(dim, dim)
475
+ self.v_img = nn.Linear(dim, dim)
476
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
477
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
478
+
479
+ processor = WanI2VCrossAttentionProcessor()
480
+ self.set_processor(processor)
481
+
482
+ def set_processor(self, processor) -> None:
483
+ self.processor = processor
484
+
485
+ def get_processor(self):
486
+ return self.processor
487
+
488
+ def forward(
489
+ self,
490
+ x,
491
+ context,
492
+ context_lens,
493
+ audio_proj,
494
+ audio_context_lens,
495
+ latents_num_frames,
496
+ audio_scale: float = 1.0,
497
+ **kwargs,
498
+ ):
499
+ """
500
+ x: [B, L1, C].
501
+ context: [B, L2, C].
502
+ context_lens: [B].
503
+ """
504
+ if audio_proj is None:
505
+ return self.processor(self, x, context, context_lens)
506
+ else:
507
+ return self.processor(
508
+ self,
509
+ x,
510
+ context,
511
+ context_lens,
512
+ audio_proj,
513
+ audio_context_lens,
514
+ latents_num_frames,
515
+ audio_scale,
516
+ )
517
+
518
+
519
+ WANX_CROSSATTENTION_CLASSES = {
520
+ "t2v_cross_attn": WanT2VCrossAttention,
521
+ "i2v_cross_attn": WanI2VCrossAttention,
522
+ }
523
+
524
+
525
+ class WanAttentionBlock(nn.Module):
526
+ def __init__(
527
+ self,
528
+ cross_attn_type,
529
+ dim,
530
+ ffn_dim,
531
+ num_heads,
532
+ window_size=(-1, -1),
533
+ qk_norm=True,
534
+ cross_attn_norm=False,
535
+ eps=1e-6,
536
+ ):
537
+ super().__init__()
538
+ self.dim = dim
539
+ self.ffn_dim = ffn_dim
540
+ self.num_heads = num_heads
541
+ self.window_size = window_size
542
+ self.qk_norm = qk_norm
543
+ self.cross_attn_norm = cross_attn_norm
544
+ self.eps = eps
545
+
546
+ # layers
547
+ self.norm1 = WanLayerNorm(dim, eps)
548
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
549
+ self.norm3 = (
550
+ WanLayerNorm(dim, eps, elementwise_affine=True)
551
+ if cross_attn_norm
552
+ else nn.Identity()
553
+ )
554
+ self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
555
+ dim, num_heads, (-1, -1), qk_norm, eps
556
+ )
557
+ self.norm2 = WanLayerNorm(dim, eps)
558
+ self.ffn = nn.Sequential(
559
+ nn.Linear(dim, ffn_dim),
560
+ nn.GELU(approximate="tanh"),
561
+ nn.Linear(ffn_dim, dim),
562
+ )
563
+
564
+ # modulation
565
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
566
+
567
+ def forward(
568
+ self,
569
+ x,
570
+ e,
571
+ seq_lens,
572
+ grid_sizes,
573
+ freqs,
574
+ context,
575
+ context_lens,
576
+ **kwargs,
577
+ ):
578
+ assert e.dtype == torch.float32
579
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
580
+ e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
581
+ assert e[0].dtype == torch.float32
582
+
583
+ # self-attention
584
+ y = self.self_attn(
585
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
586
+ )
587
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
588
+ x = x + y * e[2]
589
+
590
+ # cross-attention & ffn function
591
+ def cross_attn_ffn(x, context, context_lens, e, **kwargs):
592
+ x = x + self.cross_attn(self.norm3(x), context, context_lens, **kwargs)
593
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
594
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
595
+ x = x + y * e[5]
596
+ return x
597
+
598
+ x = cross_attn_ffn(x, context, context_lens, e, **kwargs)
599
+ return x
600
+
601
+
602
+ class Head(nn.Module):
603
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
604
+ super().__init__()
605
+ self.dim = dim
606
+ self.out_dim = out_dim
607
+ self.patch_size = patch_size
608
+ self.eps = eps
609
+
610
+ # layers
611
+ out_dim = math.prod(patch_size) * out_dim
612
+ self.norm = WanLayerNorm(dim, eps)
613
+ self.head = nn.Linear(dim, out_dim)
614
+
615
+ # modulation
616
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
617
+
618
+ def forward(self, x, e):
619
+ assert e.dtype == torch.float32
620
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
621
+ e = (
622
+ self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)
623
+ ).chunk(2, dim=1)
624
+ x = self.head(self.norm(x) * (1 + e[1]) + e[0])
625
+ return x
626
+
627
+
628
+ class MLPProj(torch.nn.Module):
629
+ def __init__(self, in_dim, out_dim):
630
+ super().__init__()
631
+
632
+ self.proj = torch.nn.Sequential(
633
+ torch.nn.LayerNorm(in_dim),
634
+ torch.nn.Linear(in_dim, in_dim),
635
+ torch.nn.GELU(),
636
+ torch.nn.Linear(in_dim, out_dim),
637
+ torch.nn.LayerNorm(out_dim),
638
+ )
639
+
640
+ def forward(self, image_embeds):
641
+ clip_extra_context_tokens = self.proj(image_embeds)
642
+ return clip_extra_context_tokens
643
+
644
+
645
+ class WanModel(nn.Module):
646
+ def __init__(
647
+ self,
648
+ model_type="t2v",
649
+ patch_size=(1, 2, 2),
650
+ text_len=512,
651
+ in_dim=16,
652
+ dim=2048,
653
+ ffn_dim=8192,
654
+ freq_dim=256,
655
+ text_dim=4096,
656
+ out_dim=16,
657
+ num_heads=16,
658
+ num_layers=32,
659
+ window_size=(-1, -1),
660
+ qk_norm=True,
661
+ cross_attn_norm=False,
662
+ eps=1e-6,
663
+ ):
664
+ super().__init__()
665
+
666
+ assert model_type in ["t2v", "i2v"]
667
+ self.model_type = model_type
668
+
669
+ self.patch_size = patch_size
670
+ self.text_len = text_len
671
+ self.in_dim = in_dim
672
+ self.dim = dim
673
+ self.ffn_dim = ffn_dim
674
+ self.freq_dim = freq_dim
675
+ self.text_dim = text_dim
676
+ self.out_dim = out_dim
677
+ self.num_heads = num_heads
678
+ self.num_layers = num_layers
679
+ self.window_size = window_size
680
+ self.qk_norm = qk_norm
681
+ self.cross_attn_norm = cross_attn_norm
682
+ self.eps = eps
683
+
684
+ # embeddings
685
+ self.patch_embedding = nn.Conv3d(
686
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
687
+ )
688
+ self.text_embedding = nn.Sequential(
689
+ nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
690
+ )
691
+
692
+ self.time_embedding = nn.Sequential(
693
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
694
+ )
695
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
696
+
697
+ # blocks
698
+ cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
699
+ self.blocks = nn.ModuleList(
700
+ [
701
+ WanAttentionBlock(
702
+ cross_attn_type,
703
+ dim,
704
+ ffn_dim,
705
+ num_heads,
706
+ window_size,
707
+ qk_norm,
708
+ cross_attn_norm,
709
+ eps,
710
+ )
711
+ for _ in range(num_layers)
712
+ ]
713
+ )
714
+
715
+ # head
716
+ self.head = Head(dim, out_dim, patch_size, eps)
717
+
718
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
719
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
720
+ d = dim // num_heads
721
+ self.freqs = torch.cat(
722
+ [
723
+ rope_params(1024, d - 4 * (d // 6)),
724
+ rope_params(1024, 2 * (d // 6)),
725
+ rope_params(1024, 2 * (d // 6)),
726
+ ],
727
+ dim=1,
728
+ )
729
+
730
+ if model_type == "i2v":
731
+ self.img_emb = MLPProj(1280, dim)
732
+
733
+ # initialize weights
734
+ self.init_weights()
735
+
736
+ def forward(
737
+ self,
738
+ x,
739
+ timestep,
740
+ context,
741
+ seq_len,
742
+ clip_fea=None,
743
+ y=None,
744
+ use_gradient_checkpointing=False,
745
+ audio_proj=None,
746
+ audio_context_lens=None,
747
+ latents_num_frames=None,
748
+ audio_scale=1.0,
749
+ **kwargs,
750
+ ):
751
+ """
752
+ x: A list of videos each with shape [C, T, H, W].
753
+ timestep: [B].
754
+ context: A list of text embeddings each with shape [L, C].
755
+ """
756
+ if self.model_type == "i2v":
757
+ assert clip_fea is not None and y is not None
758
+ # params
759
+ device = x[0].device
760
+ if self.freqs.device != device:
761
+ self.freqs = self.freqs.to(device)
762
+
763
+ if y is not None:
764
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
765
+
766
+ # embeddings
767
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
768
+ grid_sizes = torch.stack(
769
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
770
+ ) # [B, 3]: (f, h, w) per sample
771
+ x = [u.flatten(2).transpose(1, 2) for u in x] # each: [1, f*h*w, dim]
772
+ # print(f"x0.shape:{x[0].shape}")
773
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
774
+ assert seq_lens.max() <= seq_len
775
+ x = torch.cat(
776
+ [
777
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
778
+ for u in x
779
+ ]
780
+ )
781
+
782
+ # time embeddings
783
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
784
+ e = self.time_embedding(
785
+ sinusoidal_embedding_1d(self.freq_dim, timestep).float()
786
+ )
787
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
788
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
789
+
790
+ # context
791
+ context_lens = None
792
+ context = self.text_embedding(
793
+ torch.stack(
794
+ [
795
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
796
+ for u in context
797
+ ]
798
+ )
799
+ )
800
+
801
+ if clip_fea is not None:
802
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
803
+ context = torch.concat([context_clip, context], dim=1)
804
+
805
+ # arguments
806
+ kwargs = dict(
807
+ e=e0,
808
+ seq_lens=seq_lens,
809
+ grid_sizes=grid_sizes,
810
+ freqs=self.freqs,
811
+ context=context,
812
+ context_lens=context_lens,
813
+ audio_proj=audio_proj,
814
+ audio_context_lens=audio_context_lens,
815
+ latents_num_frames=latents_num_frames,
816
+ audio_scale=audio_scale,
817
+ )
818
+
819
+ def create_custom_forward(module):
820
+ def custom_forward(*inputs, **kwargs):
821
+ return module(*inputs, **kwargs)
822
+
823
+ return custom_forward
824
+
825
+ for block in self.blocks:
826
+ if self.training and use_gradient_checkpointing:
827
+ x = torch.utils.checkpoint.checkpoint(
828
+ create_custom_forward(block),
829
+ x,
830
+ **kwargs,
831
+ use_reentrant=False,
832
+ )
833
+ else:
834
+ x = block(x, **kwargs)
835
+
836
+ # head
837
+ x = self.head(x, e)
838
+
839
+ # unpatchify
840
+ x = self.unpatchify(x, grid_sizes)
841
+ x = torch.stack(x).float()
842
+ return x
843
+
844
+ def unpatchify(self, x, grid_sizes):
845
+ c = self.out_dim
846
+ out = []
847
+ for u, v in zip(x, grid_sizes.tolist()):
848
+ u = u[: math.prod(v)].view(*v, *self.patch_size, c)
849
+ u = torch.einsum("fhwpqrc->cfphqwr", u)
850
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
851
+ out.append(u)
852
+ return out
853
+
854
+ def init_weights(self):
855
+ # basic init
856
+ for m in self.modules():
857
+ if isinstance(m, nn.Linear):
858
+ nn.init.xavier_uniform_(m.weight)
859
+ if m.bias is not None:
860
+ nn.init.zeros_(m.bias)
861
+
862
+ # init embeddings
863
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
864
+ for m in self.text_embedding.modules():
865
+ if isinstance(m, nn.Linear):
866
+ nn.init.normal_(m.weight, std=0.02)
867
+ for m in self.time_embedding.modules():
868
+ if isinstance(m, nn.Linear):
869
+ nn.init.normal_(m.weight, std=0.02)
870
+
871
+ # init output layer
872
+ nn.init.zeros_(self.head.head.weight)
873
+
874
+ @staticmethod
875
+ def state_dict_converter():
876
+ return WanModelStateDictConverter()
877
+
878
+ @property
879
+ def attn_processors(
880
+ self,
881
+ ): # copy from https://github.com/XLabs-AI/x-flux/blob/main/src/flux/model.py
882
+ # set recursively
883
+ processors = {}
884
+
885
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
886
+ if hasattr(module, "set_processor"):
887
+ processors[f"{name}.processor"] = module.processor
888
+
889
+ for sub_name, child in module.named_children():
890
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
891
+
892
+ return processors
893
+
894
+ for name, module in self.named_children():
895
+ fn_recursive_add_processors(name, module, processors)
896
+
897
+ return processors
898
+
899
+ def set_attn_processor(self, processor):
900
+ r"""copy from https://github.com/XLabs-AI/x-flux/blob/main/src/flux/model.py
901
+ Sets the attention processor to use to compute attention.
902
+
903
+ Parameters:
904
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
905
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
906
+ for **all** `Attention` layers.
907
+
908
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
909
+ processor. This is strongly recommended when setting trainable attention processors.
910
+
911
+ """
912
+ count = len(self.attn_processors.keys())
913
+
914
+ if isinstance(processor, dict) and len(processor) != count:
915
+ raise ValueError(
916
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
917
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
918
+ )
919
+
920
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
921
+ if hasattr(module, "set_processor"):
922
+ if not isinstance(processor, dict):
923
+ module.set_processor(processor)
924
+ else:
925
+ module.set_processor(processor.pop(f"{name}.processor"))
926
+
927
+ for sub_name, child in module.named_children():
928
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
929
+
930
+ for name, module in self.named_children():
931
+ fn_recursive_attn_processor(name, module, processor)
932
+
933
+
934
+ class WanModelStateDictConverter:
935
+ def __init__(self):
936
+ pass
937
+
938
+ def from_diffusers(self, state_dict):
939
+ return state_dict
940
+
941
+ def from_civitai(self, state_dict):
942
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
943
+ config = {
944
+ "model_type": "t2v",
945
+ "patch_size": (1, 2, 2),
946
+ "text_len": 512,
947
+ "in_dim": 16,
948
+ "dim": 1536,
949
+ "ffn_dim": 8960,
950
+ "freq_dim": 256,
951
+ "text_dim": 4096,
952
+ "out_dim": 16,
953
+ "num_heads": 12,
954
+ "num_layers": 30,
955
+ "window_size": (-1, -1),
956
+ "qk_norm": True,
957
+ "cross_attn_norm": True,
958
+ "eps": 1e-6,
959
+ }
960
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
961
+ config = {
962
+ "model_type": "t2v",
963
+ "patch_size": (1, 2, 2),
964
+ "text_len": 512,
965
+ "in_dim": 16,
966
+ "dim": 5120,
967
+ "ffn_dim": 13824,
968
+ "freq_dim": 256,
969
+ "text_dim": 4096,
970
+ "out_dim": 16,
971
+ "num_heads": 40,
972
+ "num_layers": 40,
973
+ "window_size": (-1, -1),
974
+ "qk_norm": True,
975
+ "cross_attn_norm": True,
976
+ "eps": 1e-6,
977
+ }
978
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
979
+ config = {
980
+ "model_type": "i2v",
981
+ "patch_size": (1, 2, 2),
982
+ "text_len": 512,
983
+ "in_dim": 36,
984
+ "dim": 5120,
985
+ "ffn_dim": 13824,
986
+ "freq_dim": 256,
987
+ "text_dim": 4096,
988
+ "out_dim": 16,
989
+ "num_heads": 40,
990
+ "num_layers": 40,
991
+ "window_size": (-1, -1),
992
+ "qk_norm": True,
993
+ "cross_attn_norm": True,
994
+ "eps": 1e-6,
995
+ }
996
+ else:
997
+ config = {}
998
+ return state_dict, config
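
For orientation, a CPU-only sketch of the positional-embedding helpers above, using the default dim=2048 and num_heads=16 from the WanModel constructor; it only exercises tensor shapes and copies the frequency construction from WanModel.__init__.

import torch

from diffsynth.models.wan_video_dit import rope_params, sinusoidal_embedding_1d

# Timestep embedding: one freq_dim-wide sinusoid per diffusion timestep.
timesteps = torch.tensor([0, 250, 999])
print(sinusoidal_embedding_1d(256, timesteps).shape)  # torch.Size([3, 256])

# RoPE frequency tables as WanModel.__init__ builds them; rope_apply later splits them
# back into the same three chunks for the frame, height and width axes.
dim, num_heads = 2048, 16
d = dim // num_heads  # 128 channels per attention head
freqs = torch.cat(
    [
        rope_params(1024, d - 4 * (d // 6)),  # temporal axis
        rope_params(1024, 2 * (d // 6)),      # height axis
        rope_params(1024, 2 * (d // 6)),      # width axis
    ],
    dim=1,
)
print(freqs.shape)  # torch.Size([1024, 64]), i.e. head_dim // 2 complex frequencies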
FantasyTalking/diffsynth/models/wan_video_image_encoder.py ADDED
@@ -0,0 +1,960 @@
1
+ """
2
+ Concise re-implementation of
3
+ ``https://github.com/openai/CLIP'' and
4
+ ``https://github.com/mlfoundations/open_clip''.
5
+ """
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms as T
12
+
13
+ from .wan_video_dit import flash_attention
14
+
15
+
16
+ class SelfAttention(nn.Module):
17
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
18
+ assert dim % num_heads == 0
19
+ super().__init__()
20
+ self.dim = dim
21
+ self.num_heads = num_heads
22
+ self.head_dim = dim // num_heads
23
+ self.eps = eps
24
+
25
+ # layers
26
+ self.q = nn.Linear(dim, dim)
27
+ self.k = nn.Linear(dim, dim)
28
+ self.v = nn.Linear(dim, dim)
29
+ self.o = nn.Linear(dim, dim)
30
+ self.dropout = nn.Dropout(dropout)
31
+
32
+ def forward(self, x, mask):
33
+ """
34
+ x: [B, L, C].
35
+ """
36
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
37
+
38
+ # compute query, key, value
39
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
40
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
41
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
42
+
43
+ # compute attention
44
+ p = self.dropout.p if self.training else 0.0
45
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
46
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
47
+
48
+ # output
49
+ x = self.o(x)
50
+ x = self.dropout(x)
51
+ return x
52
+
53
+
54
+ class AttentionBlock(nn.Module):
55
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
56
+ super().__init__()
57
+ self.dim = dim
58
+ self.num_heads = num_heads
59
+ self.post_norm = post_norm
60
+ self.eps = eps
61
+
62
+ # layers
63
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
64
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
65
+ self.ffn = nn.Sequential(
66
+ nn.Linear(dim, dim * 4),
67
+ nn.GELU(),
68
+ nn.Linear(dim * 4, dim),
69
+ nn.Dropout(dropout),
70
+ )
71
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
72
+
73
+ def forward(self, x, mask):
74
+ if self.post_norm:
75
+ x = self.norm1(x + self.attn(x, mask))
76
+ x = self.norm2(x + self.ffn(x))
77
+ else:
78
+ x = x + self.attn(self.norm1(x), mask)
79
+ x = x + self.ffn(self.norm2(x))
80
+ return x
81
+
82
+
83
+ class XLMRoberta(nn.Module):
84
+ """
85
+ XLMRobertaModel with no pooler and no LM head.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ vocab_size=250002,
91
+ max_seq_len=514,
92
+ type_size=1,
93
+ pad_id=1,
94
+ dim=1024,
95
+ num_heads=16,
96
+ num_layers=24,
97
+ post_norm=True,
98
+ dropout=0.1,
99
+ eps=1e-5,
100
+ ):
101
+ super().__init__()
102
+ self.vocab_size = vocab_size
103
+ self.max_seq_len = max_seq_len
104
+ self.type_size = type_size
105
+ self.pad_id = pad_id
106
+ self.dim = dim
107
+ self.num_heads = num_heads
108
+ self.num_layers = num_layers
109
+ self.post_norm = post_norm
110
+ self.eps = eps
111
+
112
+ # embeddings
113
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
114
+ self.type_embedding = nn.Embedding(type_size, dim)
115
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
116
+ self.dropout = nn.Dropout(dropout)
117
+
118
+ # blocks
119
+ self.blocks = nn.ModuleList(
120
+ [
121
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
122
+ for _ in range(num_layers)
123
+ ]
124
+ )
125
+
126
+ # norm layer
127
+ self.norm = nn.LayerNorm(dim, eps=eps)
128
+
129
+ def forward(self, ids):
130
+ """
131
+ ids: [B, L] of torch.LongTensor.
132
+ """
133
+ b, s = ids.shape
134
+ mask = ids.ne(self.pad_id).long()
135
+
136
+ # embeddings
137
+ x = (
138
+ self.token_embedding(ids)
139
+ + self.type_embedding(torch.zeros_like(ids))
140
+ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
141
+ )
142
+ if self.post_norm:
143
+ x = self.norm(x)
144
+ x = self.dropout(x)
145
+
146
+ # blocks
147
+ mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
148
+ for block in self.blocks:
149
+ x = block(x, mask)
150
+
151
+ # output
152
+ if not self.post_norm:
153
+ x = self.norm(x)
154
+ return x
155
+
156
+
157
+ def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
158
+ """
159
+ XLMRobertaLarge adapted from Huggingface.
160
+ """
161
+ # params
162
+ cfg = dict(
163
+ vocab_size=250002,
164
+ max_seq_len=514,
165
+ type_size=1,
166
+ pad_id=1,
167
+ dim=1024,
168
+ num_heads=16,
169
+ num_layers=24,
170
+ post_norm=True,
171
+ dropout=0.1,
172
+ eps=1e-5,
173
+ )
174
+ cfg.update(**kwargs)
175
+
176
+ # init model
177
+ if pretrained:
178
+ from sora import DOWNLOAD_TO_CACHE
179
+
180
+ # init a meta model
181
+ with torch.device("meta"):
182
+ model = XLMRoberta(**cfg)
183
+
184
+ # load checkpoint
185
+ model.load_state_dict(
186
+ torch.load(
187
+ DOWNLOAD_TO_CACHE("models/xlm_roberta/xlm_roberta_large.pth"),
188
+ map_location=device,
189
+ ),
190
+ assign=True,
191
+ )
192
+ else:
193
+ # init a model on device
194
+ with torch.device(device):
195
+ model = XLMRoberta(**cfg)
196
+
197
+ # init tokenizer
198
+ if return_tokenizer:
199
+ from sora.data import HuggingfaceTokenizer
200
+
201
+ tokenizer = HuggingfaceTokenizer(
202
+ name="xlm-roberta-large", seq_len=model.text_len, clean="whitespace"
203
+ )
204
+ return model, tokenizer
205
+ else:
206
+ return model
207
+
208
+
209
+ def pos_interpolate(pos, seq_len):
210
+ if pos.size(1) == seq_len:
211
+ return pos
212
+ else:
213
+ src_grid = int(math.sqrt(pos.size(1)))
214
+ tar_grid = int(math.sqrt(seq_len))
215
+ n = pos.size(1) - src_grid * src_grid
216
+ return torch.cat(
217
+ [
218
+ pos[:, :n],
219
+ F.interpolate(
220
+ pos[:, n:]
221
+ .float()
222
+ .reshape(1, src_grid, src_grid, -1)
223
+ .permute(0, 3, 1, 2),
224
+ size=(tar_grid, tar_grid),
225
+ mode="bicubic",
226
+ align_corners=False,
227
+ )
228
+ .flatten(2)
229
+ .transpose(1, 2),
230
+ ],
231
+ dim=1,
232
+ )
233
+
234
+
235
+ class QuickGELU(nn.Module):
236
+ def forward(self, x):
237
+ return x * torch.sigmoid(1.702 * x)
238
+
239
+
240
+ class LayerNorm(nn.LayerNorm):
241
+ def forward(self, x):
242
+ return super().forward(x.float()).type_as(x)
243
+
244
+
245
+ class SelfAttention(nn.Module):
246
+ def __init__(
247
+ self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0
248
+ ):
249
+ assert dim % num_heads == 0
250
+ super().__init__()
251
+ self.dim = dim
252
+ self.num_heads = num_heads
253
+ self.head_dim = dim // num_heads
254
+ self.causal = causal
255
+ self.attn_dropout = attn_dropout
256
+ self.proj_dropout = proj_dropout
257
+
258
+ # layers
259
+ self.to_qkv = nn.Linear(dim, dim * 3)
260
+ self.proj = nn.Linear(dim, dim)
261
+
262
+ def forward(self, x):
263
+ """
264
+ x: [B, L, C].
265
+ """
266
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
267
+
268
+ # compute query, key, value
269
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
270
+
271
+ # compute attention
272
+ p = self.attn_dropout if self.training else 0.0
273
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
274
+ x = x.reshape(b, s, c)
275
+
276
+ # output
277
+ x = self.proj(x)
278
+ x = F.dropout(x, self.proj_dropout, self.training)
279
+ return x
280
+
281
+
282
+ class SwiGLU(nn.Module):
283
+ def __init__(self, dim, mid_dim):
284
+ super().__init__()
285
+ self.dim = dim
286
+ self.mid_dim = mid_dim
287
+
288
+ # layers
289
+ self.fc1 = nn.Linear(dim, mid_dim)
290
+ self.fc2 = nn.Linear(dim, mid_dim)
291
+ self.fc3 = nn.Linear(mid_dim, dim)
292
+
293
+ def forward(self, x):
294
+ x = F.silu(self.fc1(x)) * self.fc2(x)
295
+ x = self.fc3(x)
296
+ return x
297
+
298
+
299
+ class AttentionBlock(nn.Module):
300
+ def __init__(
301
+ self,
302
+ dim,
303
+ mlp_ratio,
304
+ num_heads,
305
+ post_norm=False,
306
+ causal=False,
307
+ activation="quick_gelu",
308
+ attn_dropout=0.0,
309
+ proj_dropout=0.0,
310
+ norm_eps=1e-5,
311
+ ):
312
+ assert activation in ["quick_gelu", "gelu", "swi_glu"]
313
+ super().__init__()
314
+ self.dim = dim
315
+ self.mlp_ratio = mlp_ratio
316
+ self.num_heads = num_heads
317
+ self.post_norm = post_norm
318
+ self.causal = causal
319
+ self.norm_eps = norm_eps
320
+
321
+ # layers
322
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
323
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
324
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
325
+ if activation == "swi_glu":
326
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
327
+ else:
328
+ self.mlp = nn.Sequential(
329
+ nn.Linear(dim, int(dim * mlp_ratio)),
330
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
331
+ nn.Linear(int(dim * mlp_ratio), dim),
332
+ nn.Dropout(proj_dropout),
333
+ )
334
+
335
+ def forward(self, x):
336
+ if self.post_norm:
337
+ x = x + self.norm1(self.attn(x))
338
+ x = x + self.norm2(self.mlp(x))
339
+ else:
340
+ x = x + self.attn(self.norm1(x))
341
+ x = x + self.mlp(self.norm2(x))
342
+ return x
343
+
344
+
345
+ class AttentionPool(nn.Module):
346
+ def __init__(
347
+ self,
348
+ dim,
349
+ mlp_ratio,
350
+ num_heads,
351
+ activation="gelu",
352
+ proj_dropout=0.0,
353
+ norm_eps=1e-5,
354
+ ):
355
+ assert dim % num_heads == 0
356
+ super().__init__()
357
+ self.dim = dim
358
+ self.mlp_ratio = mlp_ratio
359
+ self.num_heads = num_heads
360
+ self.head_dim = dim // num_heads
361
+ self.proj_dropout = proj_dropout
362
+ self.norm_eps = norm_eps
363
+
364
+ # layers
365
+ gain = 1.0 / math.sqrt(dim)
366
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
367
+ self.to_q = nn.Linear(dim, dim)
368
+ self.to_kv = nn.Linear(dim, dim * 2)
369
+ self.proj = nn.Linear(dim, dim)
370
+ self.norm = LayerNorm(dim, eps=norm_eps)
371
+ self.mlp = nn.Sequential(
372
+ nn.Linear(dim, int(dim * mlp_ratio)),
373
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
374
+ nn.Linear(int(dim * mlp_ratio), dim),
375
+ nn.Dropout(proj_dropout),
376
+ )
377
+
378
+ def forward(self, x):
379
+ """
380
+ x: [B, L, C].
381
+ """
382
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
383
+
384
+ # compute query, key, value
385
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
386
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
387
+
388
+ # compute attention
389
+ x = flash_attention(q, k, v, version=2)
390
+ x = x.reshape(b, 1, c)
391
+
392
+ # output
393
+ x = self.proj(x)
394
+ x = F.dropout(x, self.proj_dropout, self.training)
395
+
396
+ # mlp
397
+ x = x + self.mlp(self.norm(x))
398
+ return x[:, 0]
399
+
400
+
401
+ class VisionTransformer(nn.Module):
402
+ def __init__(
403
+ self,
404
+ image_size=224,
405
+ patch_size=16,
406
+ dim=768,
407
+ mlp_ratio=4,
408
+ out_dim=512,
409
+ num_heads=12,
410
+ num_layers=12,
411
+ pool_type="token",
412
+ pre_norm=True,
413
+ post_norm=False,
414
+ activation="quick_gelu",
415
+ attn_dropout=0.0,
416
+ proj_dropout=0.0,
417
+ embedding_dropout=0.0,
418
+ norm_eps=1e-5,
419
+ ):
420
+ if image_size % patch_size != 0:
421
+ print("[WARNING] image_size is not divisible by patch_size", flush=True)
422
+ assert pool_type in ("token", "token_fc", "attn_pool")
423
+ out_dim = out_dim or dim
424
+ super().__init__()
425
+ self.image_size = image_size
426
+ self.patch_size = patch_size
427
+ self.num_patches = (image_size // patch_size) ** 2
428
+ self.dim = dim
429
+ self.mlp_ratio = mlp_ratio
430
+ self.out_dim = out_dim
431
+ self.num_heads = num_heads
432
+ self.num_layers = num_layers
433
+ self.pool_type = pool_type
434
+ self.post_norm = post_norm
435
+ self.norm_eps = norm_eps
436
+
437
+ # embeddings
438
+ gain = 1.0 / math.sqrt(dim)
439
+ self.patch_embedding = nn.Conv2d(
440
+ 3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm
441
+ )
442
+ if pool_type in ("token", "token_fc"):
443
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
444
+ self.pos_embedding = nn.Parameter(
445
+ gain
446
+ * torch.randn(
447
+ 1,
448
+ self.num_patches + (1 if pool_type in ("token", "token_fc") else 0),
449
+ dim,
450
+ )
451
+ )
452
+ self.dropout = nn.Dropout(embedding_dropout)
453
+
454
+ # transformer
455
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
456
+ self.transformer = nn.Sequential(
457
+ *[
458
+ AttentionBlock(
459
+ dim,
460
+ mlp_ratio,
461
+ num_heads,
462
+ post_norm,
463
+ False,
464
+ activation,
465
+ attn_dropout,
466
+ proj_dropout,
467
+ norm_eps,
468
+ )
469
+ for _ in range(num_layers)
470
+ ]
471
+ )
472
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
473
+
474
+ # head
475
+ if pool_type == "token":
476
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
477
+ elif pool_type == "token_fc":
478
+ self.head = nn.Linear(dim, out_dim)
479
+ elif pool_type == "attn_pool":
480
+ self.head = AttentionPool(
481
+ dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps
482
+ )
483
+
484
+ def forward(self, x, interpolation=False, use_31_block=False):
485
+ b = x.size(0)
486
+
487
+ # embeddings
488
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
489
+ if self.pool_type in ("token", "token_fc"):
490
+ x = torch.cat(
491
+ [
492
+ self.cls_embedding.expand(b, -1, -1).to(
493
+ dtype=x.dtype, device=x.device
494
+ ),
495
+ x,
496
+ ],
497
+ dim=1,
498
+ )
499
+ if interpolation:
500
+ e = pos_interpolate(self.pos_embedding, x.size(1))
501
+ else:
502
+ e = self.pos_embedding
503
+ e = e.to(dtype=x.dtype, device=x.device)
504
+ x = self.dropout(x + e)
505
+ if self.pre_norm is not None:
506
+ x = self.pre_norm(x)
507
+
508
+ # transformer
509
+ if use_31_block:
510
+ x = self.transformer[:-1](x)
511
+ return x
512
+ else:
513
+ x = self.transformer(x)
514
+ return x
515
+
516
+
517
+ class CLIP(nn.Module):
518
+ def __init__(
519
+ self,
520
+ embed_dim=512,
521
+ image_size=224,
522
+ patch_size=16,
523
+ vision_dim=768,
524
+ vision_mlp_ratio=4,
525
+ vision_heads=12,
526
+ vision_layers=12,
527
+ vision_pool="token",
528
+ vision_pre_norm=True,
529
+ vision_post_norm=False,
530
+ vocab_size=49408,
531
+ text_len=77,
532
+ text_dim=512,
533
+ text_mlp_ratio=4,
534
+ text_heads=8,
535
+ text_layers=12,
536
+ text_causal=True,
537
+ text_pool="argmax",
538
+ text_head_bias=False,
539
+ logit_bias=None,
540
+ activation="quick_gelu",
541
+ attn_dropout=0.0,
542
+ proj_dropout=0.0,
543
+ embedding_dropout=0.0,
544
+ norm_eps=1e-5,
545
+ ):
546
+ super().__init__()
547
+ self.embed_dim = embed_dim
548
+ self.image_size = image_size
549
+ self.patch_size = patch_size
550
+ self.vision_dim = vision_dim
551
+ self.vision_mlp_ratio = vision_mlp_ratio
552
+ self.vision_heads = vision_heads
553
+ self.vision_layers = vision_layers
554
+ self.vision_pool = vision_pool
555
+ self.vision_pre_norm = vision_pre_norm
556
+ self.vision_post_norm = vision_post_norm
557
+ self.vocab_size = vocab_size
558
+ self.text_len = text_len
559
+ self.text_dim = text_dim
560
+ self.text_mlp_ratio = text_mlp_ratio
561
+ self.text_heads = text_heads
562
+ self.text_layers = text_layers
563
+ self.text_causal = text_causal
564
+ self.text_pool = text_pool
565
+ self.text_head_bias = text_head_bias
566
+ self.norm_eps = norm_eps
567
+
568
+ # models
569
+ self.visual = VisionTransformer(
570
+ image_size=image_size,
571
+ patch_size=patch_size,
572
+ dim=vision_dim,
573
+ mlp_ratio=vision_mlp_ratio,
574
+ out_dim=embed_dim,
575
+ num_heads=vision_heads,
576
+ num_layers=vision_layers,
577
+ pool_type=vision_pool,
578
+ pre_norm=vision_pre_norm,
579
+ post_norm=vision_post_norm,
580
+ activation=activation,
581
+ attn_dropout=attn_dropout,
582
+ proj_dropout=proj_dropout,
583
+ embedding_dropout=embedding_dropout,
584
+ norm_eps=norm_eps,
585
+ )
586
+ self.textual = TextTransformer(
587
+ vocab_size=vocab_size,
588
+ text_len=text_len,
589
+ dim=text_dim,
590
+ mlp_ratio=text_mlp_ratio,
591
+ out_dim=embed_dim,
592
+ num_heads=text_heads,
593
+ num_layers=text_layers,
594
+ causal=text_causal,
595
+ pool_type=text_pool,
596
+ head_bias=text_head_bias,
597
+ activation=activation,
598
+ attn_dropout=attn_dropout,
599
+ proj_dropout=proj_dropout,
600
+ embedding_dropout=embedding_dropout,
601
+ norm_eps=norm_eps,
602
+ )
603
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
604
+ if logit_bias is not None:
605
+ self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
606
+
607
+ # initialize weights
608
+ self.init_weights()
609
+
610
+ def forward(self, imgs, txt_ids):
611
+ """
612
+ imgs: [B, 3, H, W] of torch.float32.
613
+ - mean: [0.48145466, 0.4578275, 0.40821073]
614
+ - std: [0.26862954, 0.26130258, 0.27577711]
615
+ txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
616
+ """
617
+ xi = self.visual(imgs)
618
+ xt = self.textual(txt_ids)
619
+ return xi, xt
620
+
621
+ def init_weights(self):
622
+ # embeddings
623
+ nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
624
+ nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
625
+
626
+ # attentions
627
+ for modality in ["visual", "textual"]:
628
+ dim = self.vision_dim if modality == "visual" else self.text_dim
629
+ transformer = getattr(self, modality).transformer
630
+ proj_gain = (1.0 / math.sqrt(dim)) * (1.0 / math.sqrt(2 * len(transformer)))
631
+ attn_gain = 1.0 / math.sqrt(dim)
632
+ mlp_gain = 1.0 / math.sqrt(2.0 * dim)
633
+ for block in transformer:
634
+ nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
635
+ nn.init.normal_(block.attn.proj.weight, std=proj_gain)
636
+ nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
637
+ nn.init.normal_(block.mlp[2].weight, std=proj_gain)
638
+
639
+ def param_groups(self):
640
+ groups = [
641
+ {
642
+ "params": [
643
+ p
644
+ for n, p in self.named_parameters()
645
+ if "norm" in n or n.endswith("bias")
646
+ ],
647
+ "weight_decay": 0.0,
648
+ },
649
+ {
650
+ "params": [
651
+ p
652
+ for n, p in self.named_parameters()
653
+ if not ("norm" in n or n.endswith("bias"))
654
+ ]
655
+ },
656
+ ]
657
+ return groups
658
+
659
+
660
+ class XLMRobertaWithHead(XLMRoberta):
661
+ def __init__(self, **kwargs):
662
+ self.out_dim = kwargs.pop("out_dim")
663
+ super().__init__(**kwargs)
664
+
665
+ # head
666
+ mid_dim = (self.dim + self.out_dim) // 2
667
+ self.head = nn.Sequential(
668
+ nn.Linear(self.dim, mid_dim, bias=False),
669
+ nn.GELU(),
670
+ nn.Linear(mid_dim, self.out_dim, bias=False),
671
+ )
672
+
673
+ def forward(self, ids):
674
+ # xlm-roberta
675
+ x = super().forward(ids)
676
+
677
+ # average pooling
678
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
679
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
680
+
681
+ # head
682
+ x = self.head(x)
683
+ return x
684
+
685
+
686
+ class XLMRobertaCLIP(nn.Module):
687
+ def __init__(
688
+ self,
689
+ embed_dim=1024,
690
+ image_size=224,
691
+ patch_size=14,
692
+ vision_dim=1280,
693
+ vision_mlp_ratio=4,
694
+ vision_heads=16,
695
+ vision_layers=32,
696
+ vision_pool="token",
697
+ vision_pre_norm=True,
698
+ vision_post_norm=False,
699
+ activation="gelu",
700
+ vocab_size=250002,
701
+ max_text_len=514,
702
+ type_size=1,
703
+ pad_id=1,
704
+ text_dim=1024,
705
+ text_heads=16,
706
+ text_layers=24,
707
+ text_post_norm=True,
708
+ text_dropout=0.1,
709
+ attn_dropout=0.0,
710
+ proj_dropout=0.0,
711
+ embedding_dropout=0.0,
712
+ norm_eps=1e-5,
713
+ ):
714
+ super().__init__()
715
+ self.embed_dim = embed_dim
716
+ self.image_size = image_size
717
+ self.patch_size = patch_size
718
+ self.vision_dim = vision_dim
719
+ self.vision_mlp_ratio = vision_mlp_ratio
720
+ self.vision_heads = vision_heads
721
+ self.vision_layers = vision_layers
722
+ self.vision_pre_norm = vision_pre_norm
723
+ self.vision_post_norm = vision_post_norm
724
+ self.activation = activation
725
+ self.vocab_size = vocab_size
726
+ self.max_text_len = max_text_len
727
+ self.type_size = type_size
728
+ self.pad_id = pad_id
729
+ self.text_dim = text_dim
730
+ self.text_heads = text_heads
731
+ self.text_layers = text_layers
732
+ self.text_post_norm = text_post_norm
733
+ self.norm_eps = norm_eps
734
+
735
+ # models
736
+ self.visual = VisionTransformer(
737
+ image_size=image_size,
738
+ patch_size=patch_size,
739
+ dim=vision_dim,
740
+ mlp_ratio=vision_mlp_ratio,
741
+ out_dim=embed_dim,
742
+ num_heads=vision_heads,
743
+ num_layers=vision_layers,
744
+ pool_type=vision_pool,
745
+ pre_norm=vision_pre_norm,
746
+ post_norm=vision_post_norm,
747
+ activation=activation,
748
+ attn_dropout=attn_dropout,
749
+ proj_dropout=proj_dropout,
750
+ embedding_dropout=embedding_dropout,
751
+ norm_eps=norm_eps,
752
+ )
753
+ self.textual = None
754
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
755
+
756
+ def forward(self, imgs, txt_ids):
757
+ """
758
+ imgs: [B, 3, H, W] of torch.float32.
759
+ - mean: [0.48145466, 0.4578275, 0.40821073]
760
+ - std: [0.26862954, 0.26130258, 0.27577711]
761
+ txt_ids: [B, L] of torch.long.
762
+ Encoded by data.CLIPTokenizer.
763
+ """
764
+ xi = self.visual(imgs)
765
+ xt = self.textual(txt_ids)
766
+ return xi, xt
767
+
768
+ def param_groups(self):
769
+ groups = [
770
+ {
771
+ "params": [
772
+ p
773
+ for n, p in self.named_parameters()
774
+ if "norm" in n or n.endswith("bias")
775
+ ],
776
+ "weight_decay": 0.0,
777
+ },
778
+ {
779
+ "params": [
780
+ p
781
+ for n, p in self.named_parameters()
782
+ if not ("norm" in n or n.endswith("bias"))
783
+ ]
784
+ },
785
+ ]
786
+ return groups
787
+
788
+
789
+ def _clip(
790
+ pretrained=False,
791
+ pretrained_name=None,
792
+ model_cls=CLIP,
793
+ return_transforms=False,
794
+ return_tokenizer=False,
795
+ tokenizer_padding="eos",
796
+ dtype=torch.float32,
797
+ device="cpu",
798
+ **kwargs,
799
+ ):
800
+ # init model
801
+ if pretrained and pretrained_name:
802
+ from sora import BUCKET, DOWNLOAD_TO_CACHE
803
+
804
+ # init a meta model
805
+ with torch.device("meta"):
806
+ model = model_cls(**kwargs)
807
+
808
+ # checkpoint path
809
+ checkpoint = f"models/clip/{pretrained_name}"
810
+ if dtype in (torch.float16, torch.bfloat16):
811
+ suffix = "-" + {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]
812
+ if object_exists(BUCKET, f"{checkpoint}{suffix}.pth"):
813
+ checkpoint = f"{checkpoint}{suffix}"
814
+ checkpoint += ".pth"
815
+
816
+ # load
817
+ model.load_state_dict(
818
+ torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
819
+ assign=True,
820
+ strict=False,
821
+ )
822
+ else:
823
+ # init a model on device
824
+ with torch.device(device):
825
+ model = model_cls(**kwargs)
826
+
827
+ # set device
828
+ output = (model,)
829
+
830
+ # init transforms
831
+ if return_transforms:
832
+ # mean and std
833
+ if "siglip" in pretrained_name.lower():
834
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
835
+ else:
836
+ mean = [0.48145466, 0.4578275, 0.40821073]
837
+ std = [0.26862954, 0.26130258, 0.27577711]
838
+
839
+ # transforms
840
+ transforms = T.Compose(
841
+ [
842
+ T.Resize(
843
+ (model.image_size, model.image_size),
844
+ interpolation=T.InterpolationMode.BICUBIC,
845
+ ),
846
+ T.ToTensor(),
847
+ T.Normalize(mean=mean, std=std),
848
+ ]
849
+ )
850
+ output += (transforms,)
851
+
852
+ # init tokenizer
853
+ if return_tokenizer:
854
+ from sora import data
855
+
856
+ if "siglip" in pretrained_name.lower():
857
+ tokenizer = data.HuggingfaceTokenizer(
858
+ name=f"timm/{pretrained_name}",
859
+ seq_len=model.text_len,
860
+ clean="canonicalize",
861
+ )
862
+ elif "xlm" in pretrained_name.lower():
863
+ tokenizer = data.HuggingfaceTokenizer(
864
+ name="xlm-roberta-large",
865
+ seq_len=model.max_text_len - 2,
866
+ clean="whitespace",
867
+ )
868
+ elif "mba" in pretrained_name.lower():
869
+ tokenizer = data.HuggingfaceTokenizer(
870
+ name="facebook/xlm-roberta-xl",
871
+ seq_len=model.max_text_len - 2,
872
+ clean="whitespace",
873
+ )
874
+ else:
875
+ tokenizer = data.CLIPTokenizer(
876
+ seq_len=model.text_len, padding=tokenizer_padding
877
+ )
878
+ output += (tokenizer,)
879
+ return output[0] if len(output) == 1 else output
880
+
881
+
882
+ def clip_xlm_roberta_vit_h_14(
883
+ pretrained=False,
884
+ pretrained_name="open-clip-xlm-roberta-large-vit-huge-14",
885
+ **kwargs,
886
+ ):
887
+ cfg = dict(
888
+ embed_dim=1024,
889
+ image_size=224,
890
+ patch_size=14,
891
+ vision_dim=1280,
892
+ vision_mlp_ratio=4,
893
+ vision_heads=16,
894
+ vision_layers=32,
895
+ vision_pool="token",
896
+ activation="gelu",
897
+ vocab_size=250002,
898
+ max_text_len=514,
899
+ type_size=1,
900
+ pad_id=1,
901
+ text_dim=1024,
902
+ text_heads=16,
903
+ text_layers=24,
904
+ text_post_norm=True,
905
+ text_dropout=0.1,
906
+ attn_dropout=0.0,
907
+ proj_dropout=0.0,
908
+ embedding_dropout=0.0,
909
+ )
910
+ cfg.update(**kwargs)
911
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
912
+
913
+
914
+ class WanImageEncoder(torch.nn.Module):
915
+ def __init__(self):
916
+ super().__init__()
917
+ # init model
918
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
919
+ pretrained=False,
920
+ return_transforms=True,
921
+ return_tokenizer=False,
922
+ dtype=torch.float32,
923
+ device="cpu",
924
+ )
925
+
926
+ def encode_image(self, videos):
927
+ # preprocess
928
+ size = (self.model.image_size,) * 2
929
+ videos = torch.cat(
930
+ [
931
+ F.interpolate(u, size=size, mode="bicubic", align_corners=False)
932
+ for u in videos
933
+ ]
934
+ )
935
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
936
+
937
+ # forward
938
+ out = self.model.visual(videos, use_31_block=True)
939
+ return out
940
+
941
+ @staticmethod
942
+ def state_dict_converter():
943
+ return WanImageEncoderStateDictConverter()
944
+
945
+
946
+ class WanImageEncoderStateDictConverter:
947
+ def __init__(self):
948
+ pass
949
+
950
+ def from_diffusers(self, state_dict):
951
+ return state_dict
952
+
953
+ def from_civitai(self, state_dict):
954
+ state_dict_ = {}
955
+ for name, param in state_dict.items():
956
+ if name.startswith("textual."):
957
+ continue
958
+ name = "model." + name
959
+ state_dict_[name] = param
960
+ return state_dict_
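A minimal, standalone sketch of what `WanImageEncoderStateDictConverter.from_civitai` does to checkpoint keys (the toy dictionary below is made up for illustration):

```python
import torch

# Text-tower weights are dropped and the remaining keys are re-rooted under
# "model." so they match the WanImageEncoder wrapper defined above.
raw = {
    "visual.patch_embedding.weight": torch.zeros(1),
    "visual.transformer.0.attn.to_qkv.weight": torch.zeros(1),
    "textual.token_embedding.weight": torch.zeros(1),  # discarded
}
converted = {
    "model." + name: param
    for name, param in raw.items()
    if not name.startswith("textual.")
}
print(sorted(converted))
# ['model.visual.patch_embedding.weight', 'model.visual.transformer.0.attn.to_qkv.weight']
```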
FantasyTalking/diffsynth/models/wan_video_text_encoder.py ADDED
@@ -0,0 +1,289 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+ def forward(self, x):
17
+ return (
18
+ 0.5
19
+ * x
20
+ * (
21
+ 1.0
22
+ + torch.tanh(
23
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
24
+ )
25
+ )
26
+ )
27
+
28
+
29
+ class T5LayerNorm(nn.Module):
30
+ def __init__(self, dim, eps=1e-6):
31
+ super(T5LayerNorm, self).__init__()
32
+ self.dim = dim
33
+ self.eps = eps
34
+ self.weight = nn.Parameter(torch.ones(dim))
35
+
36
+ def forward(self, x):
37
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
38
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
39
+ x = x.type_as(self.weight)
40
+ return self.weight * x
41
+
42
+
43
+ class T5Attention(nn.Module):
44
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
45
+ assert dim_attn % num_heads == 0
46
+ super(T5Attention, self).__init__()
47
+ self.dim = dim
48
+ self.dim_attn = dim_attn
49
+ self.num_heads = num_heads
50
+ self.head_dim = dim_attn // num_heads
51
+
52
+ # layers
53
+ self.q = nn.Linear(dim, dim_attn, bias=False)
54
+ self.k = nn.Linear(dim, dim_attn, bias=False)
55
+ self.v = nn.Linear(dim, dim_attn, bias=False)
56
+ self.o = nn.Linear(dim_attn, dim, bias=False)
57
+ self.dropout = nn.Dropout(dropout)
58
+
59
+ def forward(self, x, context=None, mask=None, pos_bias=None):
60
+ """
61
+ x: [B, L1, C].
62
+ context: [B, L2, C] or None.
63
+ mask: [B, L2] or [B, L1, L2] or None.
64
+ """
65
+ # check inputs
66
+ context = x if context is None else context
67
+ b, n, c = x.size(0), self.num_heads, self.head_dim
68
+
69
+ # compute query, key, value
70
+ q = self.q(x).view(b, -1, n, c)
71
+ k = self.k(context).view(b, -1, n, c)
72
+ v = self.v(context).view(b, -1, n, c)
73
+
74
+ # attention bias
75
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
76
+ if pos_bias is not None:
77
+ attn_bias += pos_bias
78
+ if mask is not None:
79
+ assert mask.ndim in [2, 3]
80
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
81
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
82
+
83
+ # compute attention (T5 does not use scaling)
84
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
85
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
86
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
87
+
88
+ # output
89
+ x = x.reshape(b, -1, n * c)
90
+ x = self.o(x)
91
+ x = self.dropout(x)
92
+ return x
93
+
94
+
95
+ class T5FeedForward(nn.Module):
96
+ def __init__(self, dim, dim_ffn, dropout=0.1):
97
+ super(T5FeedForward, self).__init__()
98
+ self.dim = dim
99
+ self.dim_ffn = dim_ffn
100
+
101
+ # layers
102
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
103
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
104
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
105
+ self.dropout = nn.Dropout(dropout)
106
+
107
+ def forward(self, x):
108
+ x = self.fc1(x) * self.gate(x)
109
+ x = self.dropout(x)
110
+ x = self.fc2(x)
111
+ x = self.dropout(x)
112
+ return x
113
+
114
+
115
+ class T5SelfAttention(nn.Module):
116
+ def __init__(
117
+ self,
118
+ dim,
119
+ dim_attn,
120
+ dim_ffn,
121
+ num_heads,
122
+ num_buckets,
123
+ shared_pos=True,
124
+ dropout=0.1,
125
+ ):
126
+ super(T5SelfAttention, self).__init__()
127
+ self.dim = dim
128
+ self.dim_attn = dim_attn
129
+ self.dim_ffn = dim_ffn
130
+ self.num_heads = num_heads
131
+ self.num_buckets = num_buckets
132
+ self.shared_pos = shared_pos
133
+
134
+ # layers
135
+ self.norm1 = T5LayerNorm(dim)
136
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
137
+ self.norm2 = T5LayerNorm(dim)
138
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
139
+ self.pos_embedding = (
140
+ None
141
+ if shared_pos
142
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
143
+ )
144
+
145
+ def forward(self, x, mask=None, pos_bias=None):
146
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
147
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
148
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
149
+ return x
150
+
151
+
152
+ class T5RelativeEmbedding(nn.Module):
153
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
154
+ super(T5RelativeEmbedding, self).__init__()
155
+ self.num_buckets = num_buckets
156
+ self.num_heads = num_heads
157
+ self.bidirectional = bidirectional
158
+ self.max_dist = max_dist
159
+
160
+ # layers
161
+ self.embedding = nn.Embedding(num_buckets, num_heads)
162
+
163
+ def forward(self, lq, lk):
164
+ device = self.embedding.weight.device
165
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
166
+ # torch.arange(lq).unsqueeze(1).to(device)
167
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
168
+ lq, device=device
169
+ ).unsqueeze(1)
170
+ rel_pos = self._relative_position_bucket(rel_pos)
171
+ rel_pos_embeds = self.embedding(rel_pos)
172
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
173
+ return rel_pos_embeds.contiguous()
174
+
175
+ def _relative_position_bucket(self, rel_pos):
176
+ # preprocess
177
+ if self.bidirectional:
178
+ num_buckets = self.num_buckets // 2
179
+ rel_buckets = (rel_pos > 0).long() * num_buckets
180
+ rel_pos = torch.abs(rel_pos)
181
+ else:
182
+ num_buckets = self.num_buckets
183
+ rel_buckets = 0
184
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
185
+
186
+ # embeddings for small and large positions
187
+ max_exact = num_buckets // 2
188
+ rel_pos_large = (
189
+ max_exact
190
+ + (
191
+ torch.log(rel_pos.float() / max_exact)
192
+ / math.log(self.max_dist / max_exact)
193
+ * (num_buckets - max_exact)
194
+ ).long()
195
+ )
196
+ rel_pos_large = torch.min(
197
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
198
+ )
199
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
200
+ return rel_buckets
201
+
202
+
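A standalone restatement of the bucketing rule above, for intuition (the function below is mine and mirrors `_relative_position_bucket` with the default `num_buckets=32`, `max_dist=128`): small offsets get their own bucket, larger offsets share logarithmically spaced buckets, and the sign of the offset selects one half of the bucket range.

```python
import math

def bucket(rel_pos: int, num_buckets: int = 32, max_dist: int = 128) -> int:
    half = num_buckets // 2                 # bidirectional: split buckets by sign
    out = half if rel_pos > 0 else 0
    rel = abs(rel_pos)
    max_exact = half // 2                   # exact buckets for small offsets
    if rel < max_exact:
        return out + rel
    large = max_exact + int(
        math.log(rel / max_exact) / math.log(max_dist / max_exact)
        * (half - max_exact)
    )
    return out + min(large, half - 1)       # clamp very distant offsets

print([bucket(p) for p in (-200, -20, -3, 0, 3, 20, 200)])
# [15, 10, 3, 0, 19, 26, 31]
```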
203
+ def init_weights(m):
204
+ if isinstance(m, T5LayerNorm):
205
+ nn.init.ones_(m.weight)
206
+ elif isinstance(m, T5FeedForward):
207
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
208
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
209
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
210
+ elif isinstance(m, T5Attention):
211
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
212
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
213
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
214
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
215
+ elif isinstance(m, T5RelativeEmbedding):
216
+ nn.init.normal_(
217
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
218
+ )
219
+
220
+
221
+ class WanTextEncoder(torch.nn.Module):
222
+ def __init__(
223
+ self,
224
+ vocab=256384,
225
+ dim=4096,
226
+ dim_attn=4096,
227
+ dim_ffn=10240,
228
+ num_heads=64,
229
+ num_layers=24,
230
+ num_buckets=32,
231
+ shared_pos=False,
232
+ dropout=0.1,
233
+ ):
234
+ super(WanTextEncoder, self).__init__()
235
+ self.dim = dim
236
+ self.dim_attn = dim_attn
237
+ self.dim_ffn = dim_ffn
238
+ self.num_heads = num_heads
239
+ self.num_layers = num_layers
240
+ self.num_buckets = num_buckets
241
+ self.shared_pos = shared_pos
242
+
243
+ # layers
244
+ self.token_embedding = (
245
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
246
+ )
247
+ self.pos_embedding = (
248
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
249
+ if shared_pos
250
+ else None
251
+ )
252
+ self.dropout = nn.Dropout(dropout)
253
+ self.blocks = nn.ModuleList(
254
+ [
255
+ T5SelfAttention(
256
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
257
+ )
258
+ for _ in range(num_layers)
259
+ ]
260
+ )
261
+ self.norm = T5LayerNorm(dim)
262
+
263
+ # initialize weights
264
+ self.apply(init_weights)
265
+
266
+ def forward(self, ids, mask=None):
267
+ x = self.token_embedding(ids)
268
+ x = self.dropout(x)
269
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
270
+ for block in self.blocks:
271
+ x = block(x, mask, pos_bias=e)
272
+ x = self.norm(x)
273
+ x = self.dropout(x)
274
+ return x
275
+
276
+ @staticmethod
277
+ def state_dict_converter():
278
+ return WanTextEncoderStateDictConverter()
279
+
280
+
281
+ class WanTextEncoderStateDictConverter:
282
+ def __init__(self):
283
+ pass
284
+
285
+ def from_diffusers(self, state_dict):
286
+ return state_dict
287
+
288
+ def from_civitai(self, state_dict):
289
+ return state_dict
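A minimal smoke test of the encoder above, assuming this file is importable as `diffsynth.models.wan_video_text_encoder` (the tiny hyper-parameters are illustrative only; the real configuration is the class defaults):

```python
import torch
from diffsynth.models.wan_video_text_encoder import WanTextEncoder

# A deliberately small encoder so the example runs quickly on CPU.
encoder = WanTextEncoder(
    vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
    num_heads=4, num_layers=2, num_buckets=8,
).eval()

ids = torch.randint(0, 1000, (1, 16))       # [B, L] token ids
mask = torch.ones(1, 16, dtype=torch.long)  # [B, L] attention mask
with torch.no_grad():
    emb = encoder(ids, mask)
print(emb.shape)  # torch.Size([1, 16, 64])
```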
FantasyTalking/diffsynth/models/wan_video_vae.py ADDED
@@ -0,0 +1,948 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+ from tqdm import tqdm
6
+
7
+ CACHE_T = 2
8
+
9
+
10
+ def check_is_instance(model, module_class):
11
+ if isinstance(model, module_class):
12
+ return True
13
+ if hasattr(model, "module") and isinstance(model.module, module_class):
14
+ return True
15
+ return False
16
+
17
+
18
+ def block_causal_mask(x, block_size):
19
+ # params
20
+ b, n, s, _, device = *x.size(), x.device
21
+ assert s % block_size == 0
22
+ num_blocks = s // block_size
23
+
24
+ # build mask
25
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
26
+ for i in range(num_blocks):
27
+ mask[:, :, i * block_size : (i + 1) * block_size, : (i + 1) * block_size] = 1
28
+ return mask
29
+
30
+
31
+ class CausalConv3d(nn.Conv3d):
32
+ """
33
+ Causal 3d convolusion.
34
+ """
35
+
36
+ def __init__(self, *args, **kwargs):
37
+ super().__init__(*args, **kwargs)
38
+ self._padding = (
39
+ self.padding[2],
40
+ self.padding[2],
41
+ self.padding[1],
42
+ self.padding[1],
43
+ 2 * self.padding[0],
44
+ 0,
45
+ )
46
+ self.padding = (0, 0, 0)
47
+
48
+ def forward(self, x, cache_x=None):
49
+ padding = list(self._padding)
50
+ if cache_x is not None and self._padding[4] > 0:
51
+ cache_x = cache_x.to(x.device)
52
+ x = torch.cat([cache_x, x], dim=2)
53
+ padding[4] -= cache_x.shape[2]
54
+ x = F.pad(x, padding)
55
+
56
+ return super().forward(x)
57
+
58
+
59
+ class RMS_norm(nn.Module):
60
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
61
+ super().__init__()
62
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
63
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
64
+
65
+ self.channel_first = channel_first
66
+ self.scale = dim**0.5
67
+ self.gamma = nn.Parameter(torch.ones(shape))
68
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
69
+
70
+ def forward(self, x):
71
+ return (
72
+ F.normalize(x, dim=(1 if self.channel_first else -1))
73
+ * self.scale
74
+ * self.gamma
75
+ + self.bias
76
+ )
77
+
78
+
79
+ class Upsample(nn.Upsample):
80
+ def forward(self, x):
81
+ """
82
+ Fix bfloat16 support for nearest neighbor interpolation.
83
+ """
84
+ return super().forward(x.float()).type_as(x)
85
+
86
+
87
+ class Resample(nn.Module):
88
+ def __init__(self, dim, mode):
89
+ assert mode in (
90
+ "none",
91
+ "upsample2d",
92
+ "upsample3d",
93
+ "downsample2d",
94
+ "downsample3d",
95
+ )
96
+ super().__init__()
97
+ self.dim = dim
98
+ self.mode = mode
99
+
100
+ # layers
101
+ if mode == "upsample2d":
102
+ self.resample = nn.Sequential(
103
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
104
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
105
+ )
106
+ elif mode == "upsample3d":
107
+ self.resample = nn.Sequential(
108
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
109
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
110
+ )
111
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
112
+
113
+ elif mode == "downsample2d":
114
+ self.resample = nn.Sequential(
115
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
116
+ )
117
+ elif mode == "downsample3d":
118
+ self.resample = nn.Sequential(
119
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
120
+ )
121
+ self.time_conv = CausalConv3d(
122
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
123
+ )
124
+
125
+ else:
126
+ self.resample = nn.Identity()
127
+
128
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
129
+ b, c, t, h, w = x.size()
130
+ if self.mode == "upsample3d":
131
+ if feat_cache is not None:
132
+ idx = feat_idx[0]
133
+ if feat_cache[idx] is None:
134
+ feat_cache[idx] = "Rep"
135
+ feat_idx[0] += 1
136
+ else:
137
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
138
+ if (
139
+ cache_x.shape[2] < 2
140
+ and feat_cache[idx] is not None
141
+ and feat_cache[idx] != "Rep"
142
+ ):
143
+ # prepend the last frame of the previous chunk so the cache covers two frames
144
+ cache_x = torch.cat(
145
+ [
146
+ feat_cache[idx][:, :, -1, :, :]
147
+ .unsqueeze(2)
148
+ .to(cache_x.device),
149
+ cache_x,
150
+ ],
151
+ dim=2,
152
+ )
153
+ if (
154
+ cache_x.shape[2] < 2
155
+ and feat_cache[idx] is not None
156
+ and feat_cache[idx] == "Rep"
157
+ ):
158
+ cache_x = torch.cat(
159
+ [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
160
+ dim=2,
161
+ )
162
+ if feat_cache[idx] == "Rep":
163
+ x = self.time_conv(x)
164
+ else:
165
+ x = self.time_conv(x, feat_cache[idx])
166
+ feat_cache[idx] = cache_x
167
+ feat_idx[0] += 1
168
+
169
+ x = x.reshape(b, 2, c, t, h, w)
170
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
171
+ x = x.reshape(b, c, t * 2, h, w)
172
+ t = x.shape[2]
173
+ x = rearrange(x, "b c t h w -> (b t) c h w")
174
+ x = self.resample(x)
175
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
176
+
177
+ if self.mode == "downsample3d":
178
+ if feat_cache is not None:
179
+ idx = feat_idx[0]
180
+ if feat_cache[idx] is None:
181
+ feat_cache[idx] = x.clone()
182
+ feat_idx[0] += 1
183
+ else:
184
+ cache_x = x[:, :, -1:, :, :].clone()
185
+ x = self.time_conv(
186
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
187
+ )
188
+ feat_cache[idx] = cache_x
189
+ feat_idx[0] += 1
190
+ return x
191
+
192
+ def init_weight(self, conv):
193
+ conv_weight = conv.weight
194
+ nn.init.zeros_(conv_weight)
195
+ c1, c2, t, h, w = conv_weight.size()
196
+ one_matrix = torch.eye(c1, c2)
197
+ init_matrix = one_matrix
198
+ nn.init.zeros_(conv_weight)
199
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
200
+ conv.weight.data.copy_(conv_weight)
201
+ nn.init.zeros_(conv.bias.data)
202
+
203
+ def init_weight2(self, conv):
204
+ conv_weight = conv.weight.data
205
+ nn.init.zeros_(conv_weight)
206
+ c1, c2, t, h, w = conv_weight.size()
207
+ init_matrix = torch.eye(c1 // 2, c2)
208
+ conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
209
+ conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
210
+ conv.weight.data.copy_(conv_weight)
211
+ nn.init.zeros_(conv.bias.data)
212
+
213
+
214
+ class ResidualBlock(nn.Module):
215
+ def __init__(self, in_dim, out_dim, dropout=0.0):
216
+ super().__init__()
217
+ self.in_dim = in_dim
218
+ self.out_dim = out_dim
219
+
220
+ # layers
221
+ self.residual = nn.Sequential(
222
+ RMS_norm(in_dim, images=False),
223
+ nn.SiLU(),
224
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
225
+ RMS_norm(out_dim, images=False),
226
+ nn.SiLU(),
227
+ nn.Dropout(dropout),
228
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
229
+ )
230
+ self.shortcut = (
231
+ CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
232
+ )
233
+
234
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
235
+ h = self.shortcut(x)
236
+ for layer in self.residual:
237
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
238
+ idx = feat_idx[0]
239
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
240
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
241
+ # prepend the last frame of the previous chunk so the cache covers two frames
242
+ cache_x = torch.cat(
243
+ [
244
+ feat_cache[idx][:, :, -1, :, :]
245
+ .unsqueeze(2)
246
+ .to(cache_x.device),
247
+ cache_x,
248
+ ],
249
+ dim=2,
250
+ )
251
+ x = layer(x, feat_cache[idx])
252
+ feat_cache[idx] = cache_x
253
+ feat_idx[0] += 1
254
+ else:
255
+ x = layer(x)
256
+ return x + h
257
+
258
+
259
+ class AttentionBlock(nn.Module):
260
+ """
261
+ Causal self-attention with a single head.
262
+ """
263
+
264
+ def __init__(self, dim):
265
+ super().__init__()
266
+ self.dim = dim
267
+
268
+ # layers
269
+ self.norm = RMS_norm(dim)
270
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
271
+ self.proj = nn.Conv2d(dim, dim, 1)
272
+
273
+ # zero out the last layer params
274
+ nn.init.zeros_(self.proj.weight)
275
+
276
+ def forward(self, x):
277
+ identity = x
278
+ b, c, t, h, w = x.size()
279
+ x = rearrange(x, "b c t h w -> (b t) c h w")
280
+ x = self.norm(x)
281
+ # compute query, key, value
282
+ q, k, v = (
283
+ self.to_qkv(x)
284
+ .reshape(b * t, 1, c * 3, -1)
285
+ .permute(0, 1, 3, 2)
286
+ .contiguous()
287
+ .chunk(3, dim=-1)
288
+ )
289
+
290
+ # apply attention
291
+ x = F.scaled_dot_product_attention(
292
+ q,
293
+ k,
294
+ v,
295
+ # attn_mask=block_causal_mask(q, block_size=h * w)
296
+ )
297
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
298
+
299
+ # output
300
+ x = self.proj(x)
301
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
302
+ return x + identity
303
+
304
+
305
+ class Encoder3d(nn.Module):
306
+ def __init__(
307
+ self,
308
+ dim=128,
309
+ z_dim=4,
310
+ dim_mult=[1, 2, 4, 4],
311
+ num_res_blocks=2,
312
+ attn_scales=[],
313
+ temperal_downsample=[True, True, False],
314
+ dropout=0.0,
315
+ ):
316
+ super().__init__()
317
+ self.dim = dim
318
+ self.z_dim = z_dim
319
+ self.dim_mult = dim_mult
320
+ self.num_res_blocks = num_res_blocks
321
+ self.attn_scales = attn_scales
322
+ self.temperal_downsample = temperal_downsample
323
+
324
+ # dimensions
325
+ dims = [dim * u for u in [1] + dim_mult]
326
+ scale = 1.0
327
+
328
+ # init block
329
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
330
+
331
+ # downsample blocks
332
+ downsamples = []
333
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
334
+ # residual (+attention) blocks
335
+ for _ in range(num_res_blocks):
336
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
337
+ if scale in attn_scales:
338
+ downsamples.append(AttentionBlock(out_dim))
339
+ in_dim = out_dim
340
+
341
+ # downsample block
342
+ if i != len(dim_mult) - 1:
343
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
344
+ downsamples.append(Resample(out_dim, mode=mode))
345
+ scale /= 2.0
346
+ self.downsamples = nn.Sequential(*downsamples)
347
+
348
+ # middle blocks
349
+ self.middle = nn.Sequential(
350
+ ResidualBlock(out_dim, out_dim, dropout),
351
+ AttentionBlock(out_dim),
352
+ ResidualBlock(out_dim, out_dim, dropout),
353
+ )
354
+
355
+ # output blocks
356
+ self.head = nn.Sequential(
357
+ RMS_norm(out_dim, images=False),
358
+ nn.SiLU(),
359
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
360
+ )
361
+
362
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
363
+ if feat_cache is not None:
364
+ idx = feat_idx[0]
365
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
366
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
367
+ # prepend the last frame of the previous chunk so the cache covers two frames
368
+ cache_x = torch.cat(
369
+ [
370
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
371
+ cache_x,
372
+ ],
373
+ dim=2,
374
+ )
375
+ x = self.conv1(x, feat_cache[idx])
376
+ feat_cache[idx] = cache_x
377
+ feat_idx[0] += 1
378
+ else:
379
+ x = self.conv1(x)
380
+
381
+ ## downsamples
382
+ for layer in self.downsamples:
383
+ if feat_cache is not None:
384
+ x = layer(x, feat_cache, feat_idx)
385
+ else:
386
+ x = layer(x)
387
+
388
+ ## middle
389
+ for layer in self.middle:
390
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
391
+ x = layer(x, feat_cache, feat_idx)
392
+ else:
393
+ x = layer(x)
394
+
395
+ ## head
396
+ for layer in self.head:
397
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
398
+ idx = feat_idx[0]
399
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
400
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
401
+ # prepend the last frame of the previous chunk so the cache covers two frames
402
+ cache_x = torch.cat(
403
+ [
404
+ feat_cache[idx][:, :, -1, :, :]
405
+ .unsqueeze(2)
406
+ .to(cache_x.device),
407
+ cache_x,
408
+ ],
409
+ dim=2,
410
+ )
411
+ x = layer(x, feat_cache[idx])
412
+ feat_cache[idx] = cache_x
413
+ feat_idx[0] += 1
414
+ else:
415
+ x = layer(x)
416
+ return x
417
+
418
+
419
+ class Decoder3d(nn.Module):
420
+ def __init__(
421
+ self,
422
+ dim=128,
423
+ z_dim=4,
424
+ dim_mult=[1, 2, 4, 4],
425
+ num_res_blocks=2,
426
+ attn_scales=[],
427
+ temperal_upsample=[False, True, True],
428
+ dropout=0.0,
429
+ ):
430
+ super().__init__()
431
+ self.dim = dim
432
+ self.z_dim = z_dim
433
+ self.dim_mult = dim_mult
434
+ self.num_res_blocks = num_res_blocks
435
+ self.attn_scales = attn_scales
436
+ self.temperal_upsample = temperal_upsample
437
+
438
+ # dimensions
439
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
440
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
441
+
442
+ # init block
443
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
444
+
445
+ # middle blocks
446
+ self.middle = nn.Sequential(
447
+ ResidualBlock(dims[0], dims[0], dropout),
448
+ AttentionBlock(dims[0]),
449
+ ResidualBlock(dims[0], dims[0], dropout),
450
+ )
451
+
452
+ # upsample blocks
453
+ upsamples = []
454
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
455
+ # residual (+attention) blocks
456
+ if i == 1 or i == 2 or i == 3:
457
+ in_dim = in_dim // 2
458
+ for _ in range(num_res_blocks + 1):
459
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
460
+ if scale in attn_scales:
461
+ upsamples.append(AttentionBlock(out_dim))
462
+ in_dim = out_dim
463
+
464
+ # upsample block
465
+ if i != len(dim_mult) - 1:
466
+ mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
467
+ upsamples.append(Resample(out_dim, mode=mode))
468
+ scale *= 2.0
469
+ self.upsamples = nn.Sequential(*upsamples)
470
+
471
+ # output blocks
472
+ self.head = nn.Sequential(
473
+ RMS_norm(out_dim, images=False),
474
+ nn.SiLU(),
475
+ CausalConv3d(out_dim, 3, 3, padding=1),
476
+ )
477
+
478
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
479
+ ## conv1
480
+ if feat_cache is not None:
481
+ idx = feat_idx[0]
482
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
483
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
484
+ # prepend the last frame of the previous chunk so the cache covers two frames
485
+ cache_x = torch.cat(
486
+ [
487
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
488
+ cache_x,
489
+ ],
490
+ dim=2,
491
+ )
492
+ x = self.conv1(x, feat_cache[idx])
493
+ feat_cache[idx] = cache_x
494
+ feat_idx[0] += 1
495
+ else:
496
+ x = self.conv1(x)
497
+
498
+ ## middle
499
+ for layer in self.middle:
500
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
501
+ x = layer(x, feat_cache, feat_idx)
502
+ else:
503
+ x = layer(x)
504
+
505
+ ## upsamples
506
+ for layer in self.upsamples:
507
+ if feat_cache is not None:
508
+ x = layer(x, feat_cache, feat_idx)
509
+ else:
510
+ x = layer(x)
511
+
512
+ ## head
513
+ for layer in self.head:
514
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
515
+ idx = feat_idx[0]
516
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
517
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
518
+ # prepend the last frame of the previous chunk so the cache covers two frames
519
+ cache_x = torch.cat(
520
+ [
521
+ feat_cache[idx][:, :, -1, :, :]
522
+ .unsqueeze(2)
523
+ .to(cache_x.device),
524
+ cache_x,
525
+ ],
526
+ dim=2,
527
+ )
528
+ x = layer(x, feat_cache[idx])
529
+ feat_cache[idx] = cache_x
530
+ feat_idx[0] += 1
531
+ else:
532
+ x = layer(x)
533
+ return x
534
+
535
+
536
+ def count_conv3d(model):
537
+ count = 0
538
+ for m in model.modules():
539
+ if check_is_instance(m, CausalConv3d):
540
+ count += 1
541
+ return count
542
+
543
+
544
+ class VideoVAE_(nn.Module):
545
+ def __init__(
546
+ self,
547
+ dim=96,
548
+ z_dim=16,
549
+ dim_mult=[1, 2, 4, 4],
550
+ num_res_blocks=2,
551
+ attn_scales=[],
552
+ temperal_downsample=[False, True, True],
553
+ dropout=0.0,
554
+ ):
555
+ super().__init__()
556
+ self.dim = dim
557
+ self.z_dim = z_dim
558
+ self.dim_mult = dim_mult
559
+ self.num_res_blocks = num_res_blocks
560
+ self.attn_scales = attn_scales
561
+ self.temperal_downsample = temperal_downsample
562
+ self.temperal_upsample = temperal_downsample[::-1]
563
+
564
+ # modules
565
+ self.encoder = Encoder3d(
566
+ dim,
567
+ z_dim * 2,
568
+ dim_mult,
569
+ num_res_blocks,
570
+ attn_scales,
571
+ self.temperal_downsample,
572
+ dropout,
573
+ )
574
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
575
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
576
+ self.decoder = Decoder3d(
577
+ dim,
578
+ z_dim,
579
+ dim_mult,
580
+ num_res_blocks,
581
+ attn_scales,
582
+ self.temperal_upsample,
583
+ dropout,
584
+ )
585
+
586
+ def forward(self, x):
587
+ mu, log_var = self.encode(x)
588
+ z = self.reparameterize(mu, log_var)
589
+ x_recon = self.decode(z)
590
+ return x_recon, mu, log_var
591
+
592
+ def encode(self, x, scale): # x: B, C, T, H, W
593
+ self.clear_cache()
594
+ ## cache
595
+ t = x.shape[2]
596
+ iter_ = 1 + (t - 1) // 4
597
+
598
+ for i in range(iter_):
599
+ self._enc_conv_idx = [0]
600
+ if i == 0:
601
+ out = self.encoder(
602
+ x[:, :, :1, :, :],
603
+ feat_cache=self._enc_feat_map,
604
+ feat_idx=self._enc_conv_idx,
605
+ )
606
+ else:
607
+ out_ = self.encoder(
608
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
609
+ feat_cache=self._enc_feat_map,
610
+ feat_idx=self._enc_conv_idx,
611
+ )
612
+ out = torch.cat([out, out_], 2)
613
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
614
+ if isinstance(scale[0], torch.Tensor):
615
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
616
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
617
+ 1, self.z_dim, 1, 1, 1
618
+ )
619
+ else:
620
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
621
+ mu = (mu - scale[0]) * scale[1]
622
+ return mu
623
+
624
+ def decode(self, z, scale):
625
+ self.clear_cache()
626
+ # z: [b,c,t,h,w]
627
+ if isinstance(scale[0], torch.Tensor):
628
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
629
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
630
+ 1, self.z_dim, 1, 1, 1
631
+ )
632
+ else:
633
+ scale = scale.to(dtype=z.dtype, device=z.device)
634
+ z = z / scale[1] + scale[0]
635
+ iter_ = z.shape[2]
636
+ x = self.conv2(z)
637
+ for i in range(iter_):
638
+ self._conv_idx = [0]
639
+ if i == 0:
640
+ out = self.decoder(
641
+ x[:, :, i : i + 1, :, :],
642
+ feat_cache=self._feat_map,
643
+ feat_idx=self._conv_idx,
644
+ )
645
+ else:
646
+ out_ = self.decoder(
647
+ x[:, :, i : i + 1, :, :],
648
+ feat_cache=self._feat_map,
649
+ feat_idx=self._conv_idx,
650
+ )
651
+ out = torch.cat([out, out_], 2) # may add tensor offload
652
+ return out
653
+
654
+ def reparameterize(self, mu, log_var):
655
+ std = torch.exp(0.5 * log_var)
656
+ eps = torch.randn_like(std)
657
+ return eps * std + mu
658
+
659
+ def sample(self, imgs, deterministic=False):
660
+ mu, log_var = self.encode(imgs)
661
+ if deterministic:
662
+ return mu
663
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
664
+ return mu + std * torch.randn_like(std)
665
+
666
+ def clear_cache(self):
667
+ self._conv_num = count_conv3d(self.decoder)
668
+ self._conv_idx = [0]
669
+ self._feat_map = [None] * self._conv_num
670
+ # cache encode
671
+ self._enc_conv_num = count_conv3d(self.encoder)
672
+ self._enc_conv_idx = [0]
673
+ self._enc_feat_map = [None] * self._enc_conv_num
674
+
675
+
676
+ class WanVideoVAE(nn.Module):
677
+ def __init__(self, z_dim=16):
678
+ super().__init__()
679
+
680
+ mean = [
681
+ -0.7571,
682
+ -0.7089,
683
+ -0.9113,
684
+ 0.1075,
685
+ -0.1745,
686
+ 0.9653,
687
+ -0.1517,
688
+ 1.5508,
689
+ 0.4134,
690
+ -0.0715,
691
+ 0.5517,
692
+ -0.3632,
693
+ -0.1922,
694
+ -0.9497,
695
+ 0.2503,
696
+ -0.2921,
697
+ ]
698
+ std = [
699
+ 2.8184,
700
+ 1.4541,
701
+ 2.3275,
702
+ 2.6558,
703
+ 1.2196,
704
+ 1.7708,
705
+ 2.6052,
706
+ 2.0743,
707
+ 3.2687,
708
+ 2.1526,
709
+ 2.8652,
710
+ 1.5579,
711
+ 1.6382,
712
+ 1.1253,
713
+ 2.8251,
714
+ 1.9160,
715
+ ]
716
+ self.mean = torch.tensor(mean)
717
+ self.std = torch.tensor(std)
718
+ self.scale = [self.mean, 1.0 / self.std]
719
+
720
+ # init model
721
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
722
+ self.upsampling_factor = 8
723
+
724
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
725
+ x = torch.ones((length,))
726
+ if not left_bound:
727
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
728
+ if not right_bound:
729
+ x[-border_width:] = torch.flip(
730
+ (torch.arange(border_width) + 1) / border_width, dims=(0,)
731
+ )
732
+ return x
733
+
734
+ def build_mask(self, data, is_bound, border_width):
735
+ _, _, _, H, W = data.shape
736
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
737
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
738
+
739
+ h = repeat(h, "H -> H W", H=H, W=W)
740
+ w = repeat(w, "W -> H W", H=H, W=W)
741
+
742
+ mask = torch.stack([h, w]).min(dim=0).values
743
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
744
+ return mask
745
+
746
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
747
+ _, _, T, H, W = hidden_states.shape
748
+ size_h, size_w = tile_size
749
+ stride_h, stride_w = tile_stride
750
+
751
+ # Split tasks
752
+ tasks = []
753
+ for h in range(0, H, stride_h):
754
+ if h - stride_h >= 0 and h - stride_h + size_h >= H:
755
+ continue
756
+ for w in range(0, W, stride_w):
757
+ if w - stride_w >= 0 and w - stride_w + size_w >= W:
758
+ continue
759
+ h_, w_ = h + size_h, w + size_w
760
+ tasks.append((h, h_, w, w_))
761
+
762
+ data_device = "cpu"
763
+ computation_device = device
764
+
765
+ out_T = T * 4 - 3
766
+ weight = torch.zeros(
767
+ (1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
768
+ dtype=hidden_states.dtype,
769
+ device=data_device,
770
+ )
771
+ values = torch.zeros(
772
+ (1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
773
+ dtype=hidden_states.dtype,
774
+ device=data_device,
775
+ )
776
+
777
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
778
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(
779
+ computation_device
780
+ )
781
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(
782
+ data_device
783
+ )
784
+
785
+ mask = self.build_mask(
786
+ hidden_states_batch,
787
+ is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
788
+ border_width=(
789
+ (size_h - stride_h) * self.upsampling_factor,
790
+ (size_w - stride_w) * self.upsampling_factor,
791
+ ),
792
+ ).to(dtype=hidden_states.dtype, device=data_device)
793
+
794
+ target_h = h * self.upsampling_factor
795
+ target_w = w * self.upsampling_factor
796
+ values[
797
+ :,
798
+ :,
799
+ :,
800
+ target_h : target_h + hidden_states_batch.shape[3],
801
+ target_w : target_w + hidden_states_batch.shape[4],
802
+ ] += (
803
+ hidden_states_batch * mask
804
+ )
805
+ weight[
806
+ :,
807
+ :,
808
+ :,
809
+ target_h : target_h + hidden_states_batch.shape[3],
810
+ target_w : target_w + hidden_states_batch.shape[4],
811
+ ] += mask
812
+ values = values / weight
813
+ values = values.float().clamp_(-1, 1)
814
+ return values
815
+
816
+ def tiled_encode(self, video, device, tile_size, tile_stride):
817
+ _, _, T, H, W = video.shape
818
+ size_h, size_w = tile_size
819
+ stride_h, stride_w = tile_stride
820
+
821
+ # Split tasks
822
+ tasks = []
823
+ for h in range(0, H, stride_h):
824
+ if h - stride_h >= 0 and h - stride_h + size_h >= H:
825
+ continue
826
+ for w in range(0, W, stride_w):
827
+ if w - stride_w >= 0 and w - stride_w + size_w >= W:
828
+ continue
829
+ h_, w_ = h + size_h, w + size_w
830
+ tasks.append((h, h_, w, w_))
831
+
832
+ data_device = "cpu"
833
+ computation_device = device
834
+
835
+ out_T = (T + 3) // 4
836
+ weight = torch.zeros(
837
+ (1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
838
+ dtype=video.dtype,
839
+ device=data_device,
840
+ )
841
+ values = torch.zeros(
842
+ (1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
843
+ dtype=video.dtype,
844
+ device=data_device,
845
+ )
846
+
847
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
848
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
849
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(
850
+ data_device
851
+ )
852
+
853
+ mask = self.build_mask(
854
+ hidden_states_batch,
855
+ is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
856
+ border_width=(
857
+ (size_h - stride_h) // self.upsampling_factor,
858
+ (size_w - stride_w) // self.upsampling_factor,
859
+ ),
860
+ ).to(dtype=video.dtype, device=data_device)
861
+
862
+ target_h = h // self.upsampling_factor
863
+ target_w = w // self.upsampling_factor
864
+ values[
865
+ :,
866
+ :,
867
+ :,
868
+ target_h : target_h + hidden_states_batch.shape[3],
869
+ target_w : target_w + hidden_states_batch.shape[4],
870
+ ] += (
871
+ hidden_states_batch * mask
872
+ )
873
+ weight[
874
+ :,
875
+ :,
876
+ :,
877
+ target_h : target_h + hidden_states_batch.shape[3],
878
+ target_w : target_w + hidden_states_batch.shape[4],
879
+ ] += mask
880
+ values = values / weight
881
+ values = values.float()
882
+ return values
883
+
884
+ def single_encode(self, video, device):
885
+ video = video.to(device)
886
+ x = self.model.encode(video, self.scale)
887
+ return x.float()
888
+
889
+ def single_decode(self, hidden_state, device):
890
+ hidden_state = hidden_state.to(device)
891
+ video = self.model.decode(hidden_state, self.scale)
892
+ return video.float().clamp_(-1, 1)
893
+
894
+ def encode(
895
+ self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)
896
+ ):
897
+ videos = [video.to("cpu") for video in videos]
898
+ hidden_states = []
899
+ for video in videos:
900
+ video = video.unsqueeze(0)
901
+ if tiled:
902
+ tile_size = (tile_size[0] * 8, tile_size[1] * 8)
903
+ tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
904
+ hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
905
+ else:
906
+ hidden_state = self.single_encode(video, device)
907
+ hidden_state = hidden_state.squeeze(0)
908
+ hidden_states.append(hidden_state)
909
+ hidden_states = torch.stack(hidden_states)
910
+ return hidden_states
911
+
912
+ def decode(
913
+ self,
914
+ hidden_states,
915
+ device,
916
+ tiled=False,
917
+ tile_size=(34, 34),
918
+ tile_stride=(18, 16),
919
+ ):
920
+ hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
921
+ videos = []
922
+ for hidden_state in hidden_states:
923
+ hidden_state = hidden_state.unsqueeze(0)
924
+ if tiled:
925
+ video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
926
+ else:
927
+ video = self.single_decode(hidden_state, device)
928
+ video = video.squeeze(0)
929
+ videos.append(video)
930
+ videos = torch.stack(videos)
931
+ return videos
932
+
933
+ @staticmethod
934
+ def state_dict_converter():
935
+ return WanVideoVAEStateDictConverter()
936
+
937
+
938
+ class WanVideoVAEStateDictConverter:
939
+ def __init__(self):
940
+ pass
941
+
942
+ def from_civitai(self, state_dict):
943
+ state_dict_ = {}
944
+ if "model_state" in state_dict:
945
+ state_dict = state_dict["model_state"]
946
+ for name in state_dict:
947
+ state_dict_["model." + name] = state_dict[name]
948
+ return state_dict_
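A standalone illustration of the feathered blending used by `tiled_encode`/`tiled_decode` above (the 1-D helper mirrors `build_1d_mask`; the signal and tile layout are made up): each tile is weighted by a ramp near interior borders, accumulated into `values`, and normalized by the accumulated `weight`, so overlapping tiles cross-fade instead of leaving seams.

```python
import torch

def ramp_mask(length, left_bound, right_bound, border_width):
    x = torch.ones(length)
    if not left_bound:
        x[:border_width] = (torch.arange(border_width) + 1) / border_width
    if not right_bound:
        x[-border_width:] = ((torch.arange(border_width) + 1) / border_width).flip(0)
    return x

full = torch.arange(12, dtype=torch.float32)   # the "true" 1-D signal
values = torch.zeros(12)
weight = torch.zeros(12)
for start, end in [(0, 8), (4, 12)]:           # two overlapping tiles
    mask = ramp_mask(end - start, left_bound=(start == 0),
                     right_bound=(end == 12), border_width=4)
    values[start:end] += full[start:end] * mask
    weight[start:end] += mask

print(torch.allclose(values / weight, full))   # True: no visible tile seam
```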
FantasyTalking/diffsynth/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .wan_video import WanVideoPipeline
FantasyTalking/diffsynth/pipelines/base.py ADDED
@@ -0,0 +1,173 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision.transforms import GaussianBlur
5
+
6
+
7
+ class BasePipeline(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ device="cuda",
11
+ torch_dtype=torch.float16,
12
+ height_division_factor=64,
13
+ width_division_factor=64,
14
+ ):
15
+ super().__init__()
16
+ self.device = device
17
+ self.torch_dtype = torch_dtype
18
+ self.height_division_factor = height_division_factor
19
+ self.width_division_factor = width_division_factor
20
+ self.cpu_offload = False
21
+ self.model_names = []
22
+
23
+ def check_resize_height_width(self, height, width):
24
+ if height % self.height_division_factor != 0:
25
+ height = (
26
+ (height + self.height_division_factor - 1)
27
+ // self.height_division_factor
28
+ * self.height_division_factor
29
+ )
30
+ print(
31
+ f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}."
32
+ )
33
+ if width % self.width_division_factor != 0:
34
+ width = (
35
+ (width + self.width_division_factor - 1)
36
+ // self.width_division_factor
37
+ * self.width_division_factor
38
+ )
39
+ print(
40
+ f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}."
41
+ )
42
+ return height, width
43
+
44
+ def preprocess_image(self, image):
45
+ image = (
46
+ torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1)
47
+ .permute(2, 0, 1)
48
+ .unsqueeze(0)
49
+ )
50
+ return image
51
+
52
+ def preprocess_images(self, images):
53
+ return [self.preprocess_image(image) for image in images]
54
+
55
+ def vae_output_to_image(self, vae_output):
56
+ image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
57
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
58
+ return image
59
+
60
+ def vae_output_to_video(self, vae_output):
61
+ video = vae_output.cpu().permute(1, 2, 0).numpy()
62
+ video = [
63
+ Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
64
+ for image in video
65
+ ]
66
+ return video
67
+
68
+ def merge_latents(
69
+ self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0
70
+ ):
71
+ if len(latents) > 0:
72
+ blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
73
+ height, width = value.shape[-2:]
74
+ weight = torch.ones_like(value)
75
+ for latent, mask, scale in zip(latents, masks, scales):
76
+ mask = (
77
+ self.preprocess_image(mask.resize((width, height))).mean(
78
+ dim=1, keepdim=True
79
+ )
80
+ > 0
81
+ )
82
+ mask = mask.repeat(1, latent.shape[1], 1, 1).to(
83
+ dtype=latent.dtype, device=latent.device
84
+ )
85
+ mask = blur(mask)
86
+ value += latent * mask * scale
87
+ weight += mask * scale
88
+ value /= weight
89
+ return value
90
+
91
+ def control_noise_via_local_prompts(
92
+ self,
93
+ prompt_emb_global,
94
+ prompt_emb_locals,
95
+ masks,
96
+ mask_scales,
97
+ inference_callback,
98
+ special_kwargs=None,
99
+ special_local_kwargs_list=None,
100
+ ):
101
+ if special_kwargs is None:
102
+ noise_pred_global = inference_callback(prompt_emb_global)
103
+ else:
104
+ noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
105
+ if special_local_kwargs_list is None:
106
+ noise_pred_locals = [
107
+ inference_callback(prompt_emb_local)
108
+ for prompt_emb_local in prompt_emb_locals
109
+ ]
110
+ else:
111
+ noise_pred_locals = [
112
+ inference_callback(prompt_emb_local, special_kwargs)
113
+ for prompt_emb_local, special_kwargs in zip(
114
+ prompt_emb_locals, special_local_kwargs_list
115
+ )
116
+ ]
117
+ noise_pred = self.merge_latents(
118
+ noise_pred_global, noise_pred_locals, masks, mask_scales
119
+ )
120
+ return noise_pred
121
+
122
+ def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
123
+ local_prompts = local_prompts or []
124
+ masks = masks or []
125
+ mask_scales = mask_scales or []
126
+ extended_prompt_dict = self.prompter.extend_prompt(prompt)
127
+ prompt = extended_prompt_dict.get("prompt", prompt)
128
+ local_prompts += extended_prompt_dict.get("prompts", [])
129
+ masks += extended_prompt_dict.get("masks", [])
130
+ mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
131
+ return prompt, local_prompts, masks, mask_scales
132
+
133
+ def enable_cpu_offload(self):
134
+ self.cpu_offload = True
135
+
136
+ def load_models_to_device(self, loadmodel_names=[]):
137
+ # only load models to device if cpu_offload is enabled
138
+ if not self.cpu_offload:
139
+ return
140
+ # offload the unneeded models to cpu
141
+ for model_name in self.model_names:
142
+ if model_name not in loadmodel_names:
143
+ model = getattr(self, model_name)
144
+ if model is not None:
145
+ if (
146
+ hasattr(model, "vram_management_enabled")
147
+ and model.vram_management_enabled
148
+ ):
149
+ for module in model.modules():
150
+ if hasattr(module, "offload"):
151
+ module.offload()
152
+ else:
153
+ model.cpu()
154
+ # load the needed models to device
155
+ for model_name in loadmodel_names:
156
+ model = getattr(self, model_name)
157
+ if model is not None:
158
+ if (
159
+ hasattr(model, "vram_management_enabled")
160
+ and model.vram_management_enabled
161
+ ):
162
+ for module in model.modules():
163
+ if hasattr(module, "onload"):
164
+ module.onload()
165
+ else:
166
+ model.to(self.device)
167
+ # fresh the cuda cache
168
+ torch.cuda.empty_cache()
169
+
170
+ def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
171
+ generator = None if seed is None else torch.Generator(device).manual_seed(seed)
172
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
173
+ return noise
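A standalone restatement of the rounding performed by `check_resize_height_width` above: sizes are rounded up to the next multiple of the division factor (64 by default here, 16 for the Wan pipeline, which overrides `height_division_factor`/`width_division_factor`).

```python
def round_up(size: int, factor: int) -> int:
    # Round `size` up to the nearest multiple of `factor`.
    return (size + factor - 1) // factor * factor

print(round_up(480, 16), round_up(481, 16), round_up(720, 64))  # 480 496 768
```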
FantasyTalking/diffsynth/pipelines/wan_video.py ADDED
@@ -0,0 +1,389 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from einops import rearrange
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+ from ..models import ModelManager
10
+ from ..models.wan_video_dit import WanLayerNorm, WanModel, WanRMSNorm
11
+ from ..models.wan_video_image_encoder import WanImageEncoder
12
+ from ..models.wan_video_text_encoder import (T5LayerNorm, T5RelativeEmbedding,
13
+ WanTextEncoder)
14
+ from ..models.wan_video_vae import (CausalConv3d, RMS_norm, Upsample,
15
+ WanVideoVAE)
16
+ from ..prompters import WanPrompter
17
+ from ..schedulers.flow_match import FlowMatchScheduler
18
+ from ..vram_management import (AutoWrappedLinear, AutoWrappedModule,
19
+ enable_vram_management)
20
+ from .base import BasePipeline
21
+
22
+
23
+ class WanVideoPipeline(BasePipeline):
24
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
25
+ super().__init__(device=device, torch_dtype=torch_dtype)
26
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
27
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
28
+ self.text_encoder: WanTextEncoder = None
29
+ self.image_encoder: WanImageEncoder = None
30
+ self.dit: WanModel = None
31
+ self.vae: WanVideoVAE = None
32
+ self.model_names = ["text_encoder", "dit", "vae"]
33
+ self.height_division_factor = 16
34
+ self.width_division_factor = 16
35
+
36
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
37
+ dtype = next(iter(self.text_encoder.parameters())).dtype
38
+ enable_vram_management(
39
+ self.text_encoder,
40
+ module_map={
41
+ torch.nn.Linear: AutoWrappedLinear,
42
+ torch.nn.Embedding: AutoWrappedModule,
43
+ T5RelativeEmbedding: AutoWrappedModule,
44
+ T5LayerNorm: AutoWrappedModule,
45
+ },
46
+ module_config=dict(
47
+ offload_dtype=dtype,
48
+ offload_device="cpu",
49
+ onload_dtype=dtype,
50
+ onload_device="cpu",
51
+ computation_dtype=self.torch_dtype,
52
+ computation_device=self.device,
53
+ ),
54
+ )
55
+ dtype = next(iter(self.dit.parameters())).dtype
56
+ enable_vram_management(
57
+ self.dit,
58
+ module_map={
59
+ torch.nn.Linear: AutoWrappedLinear,
60
+ torch.nn.Conv3d: AutoWrappedModule,
61
+ torch.nn.LayerNorm: AutoWrappedModule,
62
+ WanLayerNorm: AutoWrappedModule,
63
+ WanRMSNorm: AutoWrappedModule,
64
+ },
65
+ module_config=dict(
66
+ offload_dtype=dtype,
67
+ offload_device="cpu",
68
+ onload_dtype=dtype,
69
+ onload_device=self.device,
70
+ computation_dtype=self.torch_dtype,
71
+ computation_device=self.device,
72
+ ),
73
+ max_num_param=num_persistent_param_in_dit,
74
+ overflow_module_config=dict(
75
+ offload_dtype=dtype,
76
+ offload_device="cpu",
77
+ onload_dtype=dtype,
78
+ onload_device="cpu",
79
+ computation_dtype=self.torch_dtype,
80
+ computation_device=self.device,
81
+ ),
82
+ )
83
+ dtype = next(iter(self.vae.parameters())).dtype
84
+ enable_vram_management(
85
+ self.vae,
86
+ module_map={
87
+ torch.nn.Linear: AutoWrappedLinear,
88
+ torch.nn.Conv2d: AutoWrappedModule,
89
+ RMS_norm: AutoWrappedModule,
90
+ CausalConv3d: AutoWrappedModule,
91
+ Upsample: AutoWrappedModule,
92
+ torch.nn.SiLU: AutoWrappedModule,
93
+ torch.nn.Dropout: AutoWrappedModule,
94
+ },
95
+ module_config=dict(
96
+ offload_dtype=dtype,
97
+ offload_device="cpu",
98
+ onload_dtype=dtype,
99
+ onload_device=self.device,
100
+ computation_dtype=self.torch_dtype,
101
+ computation_device=self.device,
102
+ ),
103
+ )
104
+ if self.image_encoder is not None:
105
+ dtype = next(iter(self.image_encoder.parameters())).dtype
106
+ enable_vram_management(
107
+ self.image_encoder,
108
+ module_map={
109
+ torch.nn.Linear: AutoWrappedLinear,
110
+ torch.nn.Conv2d: AutoWrappedModule,
111
+ torch.nn.LayerNorm: AutoWrappedModule,
112
+ },
113
+ module_config=dict(
114
+ offload_dtype=dtype,
115
+ offload_device="cpu",
116
+ onload_dtype=dtype,
117
+ onload_device="cpu",
118
+ computation_dtype=self.torch_dtype,
119
+ computation_device=self.device,
120
+ ),
121
+ )
122
+ self.enable_cpu_offload()
123
+
124
+ def fetch_models(self, model_manager: ModelManager):
125
+ text_encoder_model_and_path = model_manager.fetch_model(
126
+ "wan_video_text_encoder", require_model_path=True
127
+ )
128
+ if text_encoder_model_and_path is not None:
129
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
130
+ self.prompter.fetch_models(self.text_encoder)
131
+ self.prompter.fetch_tokenizer(
132
+ os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl")
133
+ )
134
+ self.dit = model_manager.fetch_model("wan_video_dit")
135
+ self.vae = model_manager.fetch_model("wan_video_vae")
136
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
137
+
138
+ @staticmethod
139
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
140
+ if device is None:
141
+ device = model_manager.device
142
+ if torch_dtype is None:
143
+ torch_dtype = model_manager.torch_dtype
144
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
145
+ pipe.fetch_models(model_manager)
146
+ return pipe
147
+
148
+ def denoising_model(self):
149
+ return self.dit
150
+
151
+ def encode_prompt(self, prompt, positive=True):
152
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
153
+ return {"context": prompt_emb}
154
+
155
+ def encode_image(self, image, num_frames, height, width):
156
+ with torch.amp.autocast(
157
+ dtype=torch.bfloat16, device_type=torch.device(self.device).type
158
+ ):
159
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
160
+ clip_context = self.image_encoder.encode_image([image])
161
+ msk = torch.ones(1, num_frames, height // 8, width // 8, device=self.device)
162
+ msk[:, 1:] = 0
163
+ msk = torch.concat(
164
+ [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]],
165
+ dim=1,
166
+ )
167
+ msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
168
+ msk = msk.transpose(1, 2)[0]
169
+ y = self.vae.encode(
170
+ [
171
+ torch.concat(
172
+ [
173
+ image.transpose(0, 1),
174
+ torch.zeros(3, num_frames - 1, height, width).to(
175
+ image.device
176
+ ),
177
+ ],
178
+ dim=1,
179
+ )
180
+ ],
181
+ device=self.device,
182
+ )[0]
183
+ y = torch.concat([msk, y])
184
+ return {"clip_fea": clip_context, "y": [y]}
185
+
186
+ def tensor2video(self, frames):
187
+ frames = rearrange(frames, "C T H W -> T H W C")
188
+ frames = (
189
+ ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
190
+ )
191
+ frames = [Image.fromarray(frame) for frame in frames]
192
+ return frames
193
+
194
+ def prepare_extra_input(self, latents=None):
195
+ return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
196
+
197
+ def encode_video(
198
+ self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
199
+ ):
200
+ with torch.amp.autocast(
201
+ dtype=torch.bfloat16, device_type=torch.device(self.device).type
202
+ ):
203
+ latents = self.vae.encode(
204
+ input_video,
205
+ device=self.device,
206
+ tiled=tiled,
207
+ tile_size=tile_size,
208
+ tile_stride=tile_stride,
209
+ )
210
+ return latents
211
+
212
+ def decode_video(
213
+ self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
214
+ ):
215
+ with torch.amp.autocast(
216
+ dtype=torch.bfloat16, device_type=torch.device(self.device).type
217
+ ):
218
+ frames = self.vae.decode(
219
+ latents,
220
+ device=self.device,
221
+ tiled=tiled,
222
+ tile_size=tile_size,
223
+ tile_stride=tile_stride,
224
+ )
225
+ return frames
226
+
227
+ def set_ip(self, local_path):
228
+ pass
229
+
230
+ @torch.no_grad()
231
+ def __call__(
232
+ self,
233
+ prompt,
234
+ negative_prompt="",
235
+ input_image=None,
236
+ input_video=None,
237
+ denoising_strength=1.0,
238
+ seed=None,
239
+ rand_device="cpu",
240
+ height=480,
241
+ width=832,
242
+ num_frames=81,
243
+ cfg_scale=5.0,
244
+ audio_cfg_scale=None,
245
+ num_inference_steps=50,
246
+ sigma_shift=5.0,
247
+ tiled=True,
248
+ tile_size=(30, 52),
249
+ tile_stride=(15, 26),
250
+ progress_bar_cmd=tqdm,
251
+ progress_bar_st=None,
252
+ **kwargs,
253
+ ):
254
+ # Parameter check
255
+ height, width = self.check_resize_height_width(height, width)
256
+ if num_frames % 4 != 1:
257
+ num_frames = (num_frames + 2) // 4 * 4 + 1
258
+ print(
259
+ f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}."
260
+ )
261
+
262
+ # Tiler parameters
263
+ tiler_kwargs = {
264
+ "tiled": tiled,
265
+ "tile_size": tile_size,
266
+ "tile_stride": tile_stride,
267
+ }
268
+
269
+ # Scheduler
270
+ self.scheduler.set_timesteps(
271
+ num_inference_steps, denoising_strength, shift=sigma_shift
272
+ )
273
+
274
+ # Initialize noise
275
+ noise = self.generate_noise(
276
+ (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
277
+ seed=seed,
278
+ device=rand_device,
279
+ dtype=torch.float32,
280
+ ).to(self.device)
281
+ if input_video is not None:
282
+ self.load_models_to_device(["vae"])
283
+ input_video = self.preprocess_images(input_video)
284
+ input_video = torch.stack(input_video, dim=2)
285
+ latents = self.encode_video(input_video, **tiler_kwargs).to(
286
+ dtype=noise.dtype, device=noise.device
287
+ )
288
+ latents = self.scheduler.add_noise(
289
+ latents, noise, timestep=self.scheduler.timesteps[0]
290
+ )
291
+ else:
292
+ latents = noise
293
+
294
+ # Encode prompts
295
+ self.load_models_to_device(["text_encoder"])
296
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
297
+ if cfg_scale != 1.0:
298
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
299
+
300
+ # Encode image
301
+ if input_image is not None and self.image_encoder is not None:
302
+ self.load_models_to_device(["image_encoder", "vae"])
303
+ image_emb = self.encode_image(input_image, num_frames, height, width)
304
+ else:
305
+ image_emb = {}
306
+
307
+ # Extra input
308
+ extra_input = self.prepare_extra_input(latents)
309
+
310
+ # Denoise
311
+ self.load_models_to_device(["dit"])
312
+ with torch.amp.autocast(
313
+ dtype=torch.bfloat16, device_type=torch.device(self.device).type
314
+ ):
315
+ for progress_id, timestep in enumerate(
316
+ progress_bar_cmd(self.scheduler.timesteps)
317
+ ):
318
+ timestep = timestep.unsqueeze(0).to(
319
+ dtype=torch.float32, device=self.device
320
+ )
321
+
322
+ # Inference
323
+ noise_pred_posi = self.dit(
324
+ latents,
325
+ timestep=timestep,
326
+ **prompt_emb_posi,
327
+ **image_emb,
328
+ **extra_input,
329
+ **kwargs,
330
+ ) # (zt,audio,prompt)
331
+ if audio_cfg_scale is not None:
332
+ audio_scale = kwargs["audio_scale"]
333
+ kwargs["audio_scale"] = 0.0
334
+ noise_pred_noaudio = self.dit(
335
+ latents,
336
+ timestep=timestep,
337
+ **prompt_emb_posi,
338
+ **image_emb,
339
+ **extra_input,
340
+ **kwargs,
341
+ ) # (zt,0,prompt)
342
+ # kwargs['ip_scale'] = ip_scale
343
+ if cfg_scale != 1.0: # prompt cfg
344
+ noise_pred_no_cond = self.dit(
345
+ latents,
346
+ timestep=timestep,
347
+ **prompt_emb_nega,
348
+ **image_emb,
349
+ **extra_input,
350
+ **kwargs,
351
+ ) # (zt,0,0)
352
+ noise_pred = (
353
+ noise_pred_no_cond
354
+ + cfg_scale * (noise_pred_noaudio - noise_pred_no_cond)
355
+ + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio)
356
+ )
357
+ else:
358
+ noise_pred = noise_pred_noaudio + audio_cfg_scale * (
359
+ noise_pred_posi - noise_pred_noaudio
360
+ )
361
+ kwargs["audio_scale"] = audio_scale
362
+ else:
363
+ if cfg_scale != 1.0:
364
+ noise_pred_nega = self.dit(
365
+ latents,
366
+ timestep=timestep,
367
+ **prompt_emb_nega,
368
+ **image_emb,
369
+ **extra_input,
370
+ **kwargs,
371
+ ) # (zt,audio,0)
372
+ noise_pred = noise_pred_nega + cfg_scale * (
373
+ noise_pred_posi - noise_pred_nega
374
+ )
375
+ else:
376
+ noise_pred = noise_pred_posi
377
+
378
+ # Scheduler
379
+ latents = self.scheduler.step(
380
+ noise_pred, self.scheduler.timesteps[progress_id], latents
381
+ )
382
+
383
+ # Decode
384
+ self.load_models_to_device(["vae"])
385
+ frames = self.decode_video(latents, **tiler_kwargs)
386
+ self.load_models_to_device([])
387
+ frames = self.tensor2video(frames[0])
388
+
389
+ return frames
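
A small sketch of the two-level classifier-free guidance combination applied in the denoising loop above when both cfg_scale and audio_cfg_scale are set; the scalar "predictions" below are toy values, not real model outputs:

import torch

def combine_guidance(pred_uncond, pred_prompt_only, pred_prompt_audio,
                     cfg_scale=5.0, audio_cfg_scale=5.0):
    # Prompt guidance moves from the unconditional prediction towards the
    # prompt-conditioned, audio-free prediction; audio guidance then moves
    # from the audio-free prediction towards the fully conditioned one.
    return (pred_uncond
            + cfg_scale * (pred_prompt_only - pred_uncond)
            + audio_cfg_scale * (pred_prompt_audio - pred_prompt_only))

print(combine_guidance(torch.tensor(0.0), torch.tensor(1.0), torch.tensor(1.2)))
# 0 + 5*(1 - 0) + 5*(1.2 - 1) ≈ 6.0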
FantasyTalking/diffsynth/prompters/__init__.py ADDED
@@ -0,0 +1 @@
+ from .wan_prompter import WanPrompter
FantasyTalking/diffsynth/prompters/base_prompter.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+
3
+ from ..models.model_manager import ModelManager
4
+
5
+
6
+ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
7
+ # Get model_max_length from the tokenizer
8
+ length = tokenizer.model_max_length if max_length is None else max_length
9
+
10
+ # To avoid the truncation warning, temporarily set tokenizer.model_max_length to a very large value.
11
+ tokenizer.model_max_length = 99999999
12
+
13
+ # Tokenize it!
14
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
15
+
16
+ # Determine the real length.
17
+ max_length = (input_ids.shape[1] + length - 1) // length * length
18
+
19
+ # Restore tokenizer.model_max_length
20
+ tokenizer.model_max_length = length
21
+
22
+ # Tokenize it again with fixed length.
23
+ input_ids = tokenizer(
24
+ prompt,
25
+ return_tensors="pt",
26
+ padding="max_length",
27
+ max_length=max_length,
28
+ truncation=True,
29
+ ).input_ids
30
+
31
+ # Reshape input_ids to fit the text encoder.
32
+ num_sentence = input_ids.shape[1] // length
33
+ input_ids = input_ids.reshape((num_sentence, length))
34
+
35
+ return input_ids
36
+
37
+
38
+ class BasePrompter:
39
+ def __init__(self):
40
+ self.refiners = []
41
+ self.extenders = []
42
+
43
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
44
+ for refiner_class in refiner_classes:
45
+ refiner = refiner_class.from_model_manager(model_manager)
46
+ self.refiners.append(refiner)
47
+
48
+ def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=[]):
49
+ for extender_class in extender_classes:
50
+ extender = extender_class.from_model_manager(model_manager)
51
+ self.extenders.append(extender)
52
+
53
+ @torch.no_grad()
54
+ def process_prompt(self, prompt, positive=True):
55
+ if isinstance(prompt, list):
56
+ prompt = [
57
+ self.process_prompt(prompt_, positive=positive) for prompt_ in prompt
58
+ ]
59
+ else:
60
+ for refiner in self.refiners:
61
+ prompt = refiner(prompt, positive=positive)
62
+ return prompt
63
+
64
+ @torch.no_grad()
65
+ def extend_prompt(self, prompt: str, positive=True):
66
+ extended_prompt = dict(prompt=prompt)
67
+ for extender in self.extenders:
68
+ extended_prompt = extender(extended_prompt)
69
+ return extended_prompt
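
A quick worked example of the length handling in tokenize_long_prompt above: the real token count is rounded up to a multiple of the tokenizer's model_max_length, and the padded ids are then reshaped into (num_sentence, length) chunks. The numbers below are hypothetical:

length = 77             # hypothetical tokenizer.model_max_length
real_token_count = 180  # hypothetical token count of a long prompt

max_length = (real_token_count + length - 1) // length * length
num_sentence = max_length // length
print(max_length, num_sentence)  # 231 3 -> three 77-token chunks fed to the text encoder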
FantasyTalking/diffsynth/prompters/wan_prompter.py ADDED
@@ -0,0 +1,114 @@
1
+ import html
2
+ import os
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ import torch
8
+ from transformers import AutoTokenizer
9
+
10
+ from ..models.wan_video_text_encoder import WanTextEncoder
11
+ from .base_prompter import BasePrompter
12
+
13
+
14
+ def basic_clean(text):
15
+ text = ftfy.fix_text(text)
16
+ text = html.unescape(html.unescape(text))
17
+ return text.strip()
18
+
19
+
20
+ def whitespace_clean(text):
21
+ text = re.sub(r"\s+", " ", text)
22
+ text = text.strip()
23
+ return text
24
+
25
+
26
+ def canonicalize(text, keep_punctuation_exact_string=None):
27
+ text = text.replace("_", " ")
28
+ if keep_punctuation_exact_string:
29
+ text = keep_punctuation_exact_string.join(
30
+ part.translate(str.maketrans("", "", string.punctuation))
31
+ for part in text.split(keep_punctuation_exact_string)
32
+ )
33
+ else:
34
+ text = text.translate(str.maketrans("", "", string.punctuation))
35
+ text = text.lower()
36
+ text = re.sub(r"\s+", " ", text)
37
+ return text.strip()
38
+
39
+
40
+ class HuggingfaceTokenizer:
41
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
42
+ assert clean in (None, "whitespace", "lower", "canonicalize")
43
+ self.name = name
44
+ self.seq_len = seq_len
45
+ self.clean = clean
46
+
47
+ # init tokenizer
48
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
49
+ self.vocab_size = self.tokenizer.vocab_size
50
+
51
+ def __call__(self, sequence, **kwargs):
52
+ return_mask = kwargs.pop("return_mask", False)
53
+
54
+ # arguments
55
+ _kwargs = {"return_tensors": "pt"}
56
+ if self.seq_len is not None:
57
+ _kwargs.update(
58
+ {
59
+ "padding": "max_length",
60
+ "truncation": True,
61
+ "max_length": self.seq_len,
62
+ }
63
+ )
64
+ _kwargs.update(**kwargs)
65
+
66
+ # tokenization
67
+ if isinstance(sequence, str):
68
+ sequence = [sequence]
69
+ if self.clean:
70
+ sequence = [self._clean(u) for u in sequence]
71
+ ids = self.tokenizer(sequence, **_kwargs)
72
+
73
+ # output
74
+ if return_mask:
75
+ return ids.input_ids, ids.attention_mask
76
+ else:
77
+ return ids.input_ids
78
+
79
+ def _clean(self, text):
80
+ if self.clean == "whitespace":
81
+ text = whitespace_clean(basic_clean(text))
82
+ elif self.clean == "lower":
83
+ text = whitespace_clean(basic_clean(text)).lower()
84
+ elif self.clean == "canonicalize":
85
+ text = canonicalize(basic_clean(text))
86
+ return text
87
+
88
+
89
+ class WanPrompter(BasePrompter):
90
+ def __init__(self, tokenizer_path=None, text_len=512):
91
+ super().__init__()
92
+ self.text_len = text_len
93
+ self.text_encoder = None
94
+ self.fetch_tokenizer(tokenizer_path)
95
+
96
+ def fetch_tokenizer(self, tokenizer_path=None):
97
+ if tokenizer_path is not None:
98
+ self.tokenizer = HuggingfaceTokenizer(
99
+ name=tokenizer_path, seq_len=self.text_len, clean="whitespace"
100
+ )
101
+
102
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
103
+ self.text_encoder = text_encoder
104
+
105
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
106
+ prompt = self.process_prompt(prompt, positive=positive)
107
+
108
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
109
+ ids = ids.to(device)
110
+ mask = mask.to(device)
111
+ seq_lens = mask.gt(0).sum(dim=1).long()
112
+ prompt_emb = self.text_encoder(ids, mask)
113
+ prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
114
+ return prompt_emb
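
A self-contained sketch of the prompt cleaning that HuggingfaceTokenizer applies above (the ftfy/html fixing done by basic_clean is omitted here for brevity):

import re
import string

def whitespace_clean(text):
    return re.sub(r"\s+", " ", text).strip()

def canonicalize(text):
    text = text.replace("_", " ")
    text = text.translate(str.maketrans("", "", string.punctuation))
    return re.sub(r"\s+", " ", text.lower()).strip()

print(whitespace_clean("A  woman \n is talking."))  # 'A woman is talking.'
print(canonicalize("A_Woman,  is TALKING!"))        # 'a woman is talking'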
FantasyTalking/diffsynth/schedulers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .continuous_ode import ContinuousODEScheduler
+ from .ddim import EnhancedDDIMScheduler
+ from .flow_match import FlowMatchScheduler
FantasyTalking/diffsynth/schedulers/continuous_ode.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+
3
+
4
+ class ContinuousODEScheduler:
5
+ def __init__(
6
+ self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0
7
+ ):
8
+ self.sigma_max = sigma_max
9
+ self.sigma_min = sigma_min
10
+ self.rho = rho
11
+ self.set_timesteps(num_inference_steps)
12
+
13
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
14
+ ramp = torch.linspace(1 - denoising_strength, 1, num_inference_steps)
15
+ min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
16
+ max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
17
+ self.sigmas = torch.pow(
18
+ max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho
19
+ )
20
+ self.timesteps = torch.log(self.sigmas) * 0.25
21
+
22
+ def step(self, model_output, timestep, sample, to_final=False):
23
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
24
+ sigma = self.sigmas[timestep_id]
25
+ sample *= (sigma * sigma + 1).sqrt()
26
+ estimated_sample = (
27
+ -sigma / (sigma * sigma + 1).sqrt() * model_output
28
+ + 1 / (sigma * sigma + 1) * sample
29
+ )
30
+ if to_final or timestep_id + 1 >= len(self.timesteps):
31
+ prev_sample = estimated_sample
32
+ else:
33
+ sigma_ = self.sigmas[timestep_id + 1]
34
+ derivative = 1 / sigma * (sample - estimated_sample)
35
+ prev_sample = sample + derivative * (sigma_ - sigma)
36
+ prev_sample /= (sigma_ * sigma_ + 1).sqrt()
37
+ return prev_sample
38
+
39
+ def return_to_timestep(self, timestep, sample, sample_stablized):
40
+ # This scheduler doesn't support this function.
41
+ pass
42
+
43
+ def add_noise(self, original_samples, noise, timestep):
44
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
45
+ sigma = self.sigmas[timestep_id]
46
+ sample = (original_samples + noise * sigma) / (sigma * sigma + 1).sqrt()
47
+ return sample
48
+
49
+ def training_target(self, sample, noise, timestep):
50
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
51
+ sigma = self.sigmas[timestep_id]
52
+ target = (
53
+ -(sigma * sigma + 1).sqrt() / sigma + 1 / (sigma * sigma + 1).sqrt() / sigma
54
+ ) * sample + 1 / (sigma * sigma + 1).sqrt() * noise
55
+ return target
56
+
57
+ def training_weight(self, timestep):
58
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
59
+ sigma = self.sigmas[timestep_id]
60
+ weight = (1 + sigma * sigma).sqrt() / sigma
61
+ return weight
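
A short standalone sketch of the rho-interpolated sigma schedule built by set_timesteps above (Karras-style spacing), with denoising_strength=1.0 so the ramp runs from 0 to 1:

import torch

sigma_max, sigma_min, rho, steps = 700.0, 0.002, 7.0, 5
ramp = torch.linspace(0, 1, steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
timesteps = torch.log(sigmas) * 0.25
print(sigmas)     # decreases from 700.0 to 0.002, with steps concentrated at small sigma
print(timesteps)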
FantasyTalking/diffsynth/schedulers/ddim.py ADDED
@@ -0,0 +1,138 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+
6
+ class EnhancedDDIMScheduler:
7
+ def __init__(
8
+ self,
9
+ num_train_timesteps=1000,
10
+ beta_start=0.00085,
11
+ beta_end=0.012,
12
+ beta_schedule="scaled_linear",
13
+ prediction_type="epsilon",
14
+ rescale_zero_terminal_snr=False,
15
+ ):
16
+ self.num_train_timesteps = num_train_timesteps
17
+ if beta_schedule == "scaled_linear":
18
+ betas = torch.square(
19
+ torch.linspace(
20
+ math.sqrt(beta_start),
21
+ math.sqrt(beta_end),
22
+ num_train_timesteps,
23
+ dtype=torch.float32,
24
+ )
25
+ )
26
+ elif beta_schedule == "linear":
27
+ betas = torch.linspace(
28
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
29
+ )
30
+ else:
31
+ raise NotImplementedError(f"{beta_schedule} is not implemented")
32
+ self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
33
+ if rescale_zero_terminal_snr:
34
+ self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
35
+ self.alphas_cumprod = self.alphas_cumprod.tolist()
36
+ self.set_timesteps(10)
37
+ self.prediction_type = prediction_type
38
+
39
+ def rescale_zero_terminal_snr(self, alphas_cumprod):
40
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
41
+
42
+ # Store old values.
43
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
44
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
45
+
46
+ # Shift so the last timestep is zero.
47
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
48
+
49
+ # Scale so the first timestep is back to the old value.
50
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
51
+
52
+ # Convert alphas_bar_sqrt to betas
53
+ alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
54
+
55
+ return alphas_bar
56
+
57
+ def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
58
+ # The timesteps are aligned to 999...0, which is different from other implementations,
59
+ # but I think this implementation is more reasonable in theory.
60
+ max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
61
+ num_inference_steps = min(num_inference_steps, max_timestep + 1)
62
+ if num_inference_steps == 1:
63
+ self.timesteps = torch.Tensor([max_timestep])
64
+ else:
65
+ step_length = max_timestep / (num_inference_steps - 1)
66
+ self.timesteps = torch.Tensor(
67
+ [
68
+ round(max_timestep - i * step_length)
69
+ for i in range(num_inference_steps)
70
+ ]
71
+ )
72
+
73
+ def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
74
+ if self.prediction_type == "epsilon":
75
+ weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(
76
+ alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t
77
+ )
78
+ weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
79
+ prev_sample = sample * weight_x + model_output * weight_e
80
+ elif self.prediction_type == "v_prediction":
81
+ weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(
82
+ alpha_prod_t * (1 - alpha_prod_t_prev)
83
+ )
84
+ weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt(
85
+ (1 - alpha_prod_t) * (1 - alpha_prod_t_prev)
86
+ )
87
+ prev_sample = sample * weight_x + model_output * weight_e
88
+ else:
89
+ raise NotImplementedError(f"{self.prediction_type} is not implemented")
90
+ return prev_sample
91
+
92
+ def step(self, model_output, timestep, sample, to_final=False):
93
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
94
+ if isinstance(timestep, torch.Tensor):
95
+ timestep = timestep.cpu()
96
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
97
+ if to_final or timestep_id + 1 >= len(self.timesteps):
98
+ alpha_prod_t_prev = 1.0
99
+ else:
100
+ timestep_prev = int(self.timesteps[timestep_id + 1])
101
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
102
+
103
+ return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
104
+
105
+ def return_to_timestep(self, timestep, sample, sample_stablized):
106
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
107
+ noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(
108
+ 1 - alpha_prod_t
109
+ )
110
+ return noise_pred
111
+
112
+ def add_noise(self, original_samples, noise, timestep):
113
+ sqrt_alpha_prod = math.sqrt(
114
+ self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
115
+ )
116
+ sqrt_one_minus_alpha_prod = math.sqrt(
117
+ 1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
118
+ )
119
+ noisy_samples = (
120
+ sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
121
+ )
122
+ return noisy_samples
123
+
124
+ def training_target(self, sample, noise, timestep):
125
+ if self.prediction_type == "epsilon":
126
+ return noise
127
+ else:
128
+ sqrt_alpha_prod = math.sqrt(
129
+ self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
130
+ )
131
+ sqrt_one_minus_alpha_prod = math.sqrt(
132
+ 1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
133
+ )
134
+ target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
135
+ return target
136
+
137
+ def training_weight(self, timestep):
138
+ return 1.0
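
A minimal sketch of the forward process implemented by add_noise above, using the same scaled_linear beta schedule; x0 and noise here are toy tensors:

import math
import torch

betas = torch.square(torch.linspace(math.sqrt(0.00085), math.sqrt(0.012), 1000))
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0, noise, t = torch.randn(4), torch.randn(4), 500
alpha_bar = alphas_cumprod[t].item()
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1 - alpha_bar) * noise
print(alpha_bar, x_t)  # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise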
FantasyTalking/diffsynth/schedulers/flow_match.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+
3
+
4
+ class FlowMatchScheduler:
5
+ def __init__(
6
+ self,
7
+ num_inference_steps=100,
8
+ num_train_timesteps=1000,
9
+ shift=3.0,
10
+ sigma_max=1.0,
11
+ sigma_min=0.003 / 1.002,
12
+ inverse_timesteps=False,
13
+ extra_one_step=False,
14
+ reverse_sigmas=False,
15
+ ):
16
+ self.num_train_timesteps = num_train_timesteps
17
+ self.shift = shift
18
+ self.sigma_max = sigma_max
19
+ self.sigma_min = sigma_min
20
+ self.inverse_timesteps = inverse_timesteps
21
+ self.extra_one_step = extra_one_step
22
+ self.reverse_sigmas = reverse_sigmas
23
+ self.set_timesteps(num_inference_steps)
24
+
25
+ def set_timesteps(
26
+ self,
27
+ num_inference_steps=100,
28
+ denoising_strength=1.0,
29
+ training=False,
30
+ shift=None,
31
+ ):
32
+ if shift is not None:
33
+ self.shift = shift
34
+ sigma_start = (
35
+ self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
36
+ )
37
+ if self.extra_one_step:
38
+ self.sigmas = torch.linspace(
39
+ sigma_start, self.sigma_min, num_inference_steps + 1
40
+ )[:-1]
41
+ else:
42
+ self.sigmas = torch.linspace(
43
+ sigma_start, self.sigma_min, num_inference_steps
44
+ )
45
+ if self.inverse_timesteps:
46
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
47
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
48
+ if self.reverse_sigmas:
49
+ self.sigmas = 1 - self.sigmas
50
+ self.timesteps = self.sigmas * self.num_train_timesteps
51
+ if training:
52
+ x = self.timesteps
53
+ y = torch.exp(
54
+ -2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2
55
+ )
56
+ y_shifted = y - y.min()
57
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
58
+ self.linear_timesteps_weights = bsmntw_weighing
59
+
60
+ def step(self, model_output, timestep, sample, to_final=False):
61
+ if isinstance(timestep, torch.Tensor):
62
+ timestep = timestep.cpu()
63
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
64
+ sigma = self.sigmas[timestep_id]
65
+ if to_final or timestep_id + 1 >= len(self.timesteps):
66
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
67
+ else:
68
+ sigma_ = self.sigmas[timestep_id + 1]
69
+ prev_sample = sample + model_output * (sigma_ - sigma)
70
+ return prev_sample
71
+
72
+ def return_to_timestep(self, timestep, sample, sample_stablized):
73
+ if isinstance(timestep, torch.Tensor):
74
+ timestep = timestep.cpu()
75
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
76
+ sigma = self.sigmas[timestep_id]
77
+ model_output = (sample - sample_stablized) / sigma
78
+ return model_output
79
+
80
+ def add_noise(self, original_samples, noise, timestep):
81
+ if isinstance(timestep, torch.Tensor):
82
+ timestep = timestep.cpu()
83
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
84
+ sigma = self.sigmas[timestep_id]
85
+ sample = (1 - sigma) * original_samples + sigma * noise
86
+ return sample
87
+
88
+ def training_target(self, sample, noise, timestep):
89
+ target = noise - sample
90
+ return target
91
+
92
+ def training_weight(self, timestep):
93
+ timestep_id = torch.argmin(
94
+ (self.timesteps - timestep.to(self.timesteps.device)).abs()
95
+ )
96
+ weights = self.linear_timesteps_weights[timestep_id]
97
+ return weights
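
A standalone sketch of the shifted sigma schedule that set_timesteps above produces with the pipeline's settings (shift=5, sigma_min=0, extra_one_step=True); a denoising step is then just an Euler update, sample + model_output * (sigma_next - sigma):

import torch

num_inference_steps, shift, num_train_timesteps = 30, 5.0, 1000
sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1)[:-1]  # extra_one_step drops the final 0
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)             # the shift biases steps towards high noise
timesteps = sigmas * num_train_timesteps
print(timesteps[:3], timesteps[-1])                              # starts at 1000.0; the last of the 30 steps is about 147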
FantasyTalking/diffsynth/vram_management/__init__.py ADDED
@@ -0,0 +1 @@
+ from .layers import *
FantasyTalking/diffsynth/vram_management/layers.py ADDED
@@ -0,0 +1,177 @@
1
+ import copy
2
+
3
+ import torch
4
+
5
+ from ..models.utils import init_weights_on_device
6
+
7
+
8
+ def cast_to(weight, dtype, device):
9
+ r = torch.empty_like(weight, dtype=dtype, device=device)
10
+ r.copy_(weight)
11
+ return r
12
+
13
+
14
+ class AutoWrappedModule(torch.nn.Module):
15
+ def __init__(
16
+ self,
17
+ module: torch.nn.Module,
18
+ offload_dtype,
19
+ offload_device,
20
+ onload_dtype,
21
+ onload_device,
22
+ computation_dtype,
23
+ computation_device,
24
+ ):
25
+ super().__init__()
26
+ self.module = module.to(dtype=offload_dtype, device=offload_device)
27
+ self.offload_dtype = offload_dtype
28
+ self.offload_device = offload_device
29
+ self.onload_dtype = onload_dtype
30
+ self.onload_device = onload_device
31
+ self.computation_dtype = computation_dtype
32
+ self.computation_device = computation_device
33
+ self.state = 0
34
+
35
+ def offload(self):
36
+ if self.state == 1 and (
37
+ self.offload_dtype != self.onload_dtype
38
+ or self.offload_device != self.onload_device
39
+ ):
40
+ self.module.to(dtype=self.offload_dtype, device=self.offload_device)
41
+ self.state = 0
42
+
43
+ def onload(self):
44
+ if self.state == 0 and (
45
+ self.offload_dtype != self.onload_dtype
46
+ or self.offload_device != self.onload_device
47
+ ):
48
+ self.module.to(dtype=self.onload_dtype, device=self.onload_device)
49
+ self.state = 1
50
+
51
+ def forward(self, *args, **kwargs):
52
+ if (
53
+ self.onload_dtype == self.computation_dtype
54
+ and self.onload_device == self.computation_device
55
+ ):
56
+ module = self.module
57
+ else:
58
+ module = copy.deepcopy(self.module).to(
59
+ dtype=self.computation_dtype, device=self.computation_device
60
+ )
61
+ return module(*args, **kwargs)
62
+
63
+
64
+ class AutoWrappedLinear(torch.nn.Linear):
65
+ def __init__(
66
+ self,
67
+ module: torch.nn.Linear,
68
+ offload_dtype,
69
+ offload_device,
70
+ onload_dtype,
71
+ onload_device,
72
+ computation_dtype,
73
+ computation_device,
74
+ ):
75
+ with init_weights_on_device(device=torch.device("meta")):
76
+ super().__init__(
77
+ in_features=module.in_features,
78
+ out_features=module.out_features,
79
+ bias=module.bias is not None,
80
+ dtype=offload_dtype,
81
+ device=offload_device,
82
+ )
83
+ self.weight = module.weight
84
+ self.bias = module.bias
85
+ self.offload_dtype = offload_dtype
86
+ self.offload_device = offload_device
87
+ self.onload_dtype = onload_dtype
88
+ self.onload_device = onload_device
89
+ self.computation_dtype = computation_dtype
90
+ self.computation_device = computation_device
91
+ self.state = 0
92
+
93
+ def offload(self):
94
+ if self.state == 1 and (
95
+ self.offload_dtype != self.onload_dtype
96
+ or self.offload_device != self.onload_device
97
+ ):
98
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
99
+ self.state = 0
100
+
101
+ def onload(self):
102
+ if self.state == 0 and (
103
+ self.offload_dtype != self.onload_dtype
104
+ or self.offload_device != self.onload_device
105
+ ):
106
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
107
+ self.state = 1
108
+
109
+ def forward(self, x, *args, **kwargs):
110
+ if (
111
+ self.onload_dtype == self.computation_dtype
112
+ and self.onload_device == self.computation_device
113
+ ):
114
+ weight, bias = self.weight, self.bias
115
+ else:
116
+ weight = cast_to(
117
+ self.weight, self.computation_dtype, self.computation_device
118
+ )
119
+ bias = (
120
+ None
121
+ if self.bias is None
122
+ else cast_to(self.bias, self.computation_dtype, self.computation_device)
123
+ )
124
+ return torch.nn.functional.linear(x, weight, bias)
125
+
126
+
127
+ def enable_vram_management_recursively(
128
+ model: torch.nn.Module,
129
+ module_map: dict,
130
+ module_config: dict,
131
+ max_num_param=None,
132
+ overflow_module_config: dict = None,
133
+ total_num_param=0,
134
+ ):
135
+ for name, module in model.named_children():
136
+ for source_module, target_module in module_map.items():
137
+ if isinstance(module, source_module):
138
+ num_param = sum(p.numel() for p in module.parameters())
139
+ if (
140
+ max_num_param is not None
141
+ and total_num_param + num_param > max_num_param
142
+ ):
143
+ module_config_ = overflow_module_config
144
+ else:
145
+ module_config_ = module_config
146
+ module_ = target_module(module, **module_config_)
147
+ setattr(model, name, module_)
148
+ total_num_param += num_param
149
+ break
150
+ else:
151
+ total_num_param = enable_vram_management_recursively(
152
+ module,
153
+ module_map,
154
+ module_config,
155
+ max_num_param,
156
+ overflow_module_config,
157
+ total_num_param,
158
+ )
159
+ return total_num_param
160
+
161
+
162
+ def enable_vram_management(
163
+ model: torch.nn.Module,
164
+ module_map: dict,
165
+ module_config: dict,
166
+ max_num_param=None,
167
+ overflow_module_config: dict = None,
168
+ ):
169
+ enable_vram_management_recursively(
170
+ model,
171
+ module_map,
172
+ module_config,
173
+ max_num_param,
174
+ overflow_module_config,
175
+ total_num_param=0,
176
+ )
177
+ model.vram_management_enabled = True
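
A rough usage sketch of enable_vram_management on a toy model, assuming the repository root is on PYTHONPATH and its requirements are installed (the same assumption infer.py makes); a CPU-only config is used so no GPU is needed:

import torch
from diffsynth.vram_management import AutoWrappedLinear, AutoWrappedModule, enable_vram_management

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
cpu_config = dict(
    offload_dtype=torch.float32, offload_device="cpu",
    onload_dtype=torch.float32, onload_device="cpu",
    computation_dtype=torch.float32, computation_device="cpu",
)
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear, torch.nn.LayerNorm: AutoWrappedModule},
    module_config=cpu_config,
)
print(type(model[0]).__name__, model.vram_management_enabled)  # AutoWrappedLinear True
print(model(torch.randn(2, 8)).shape)                          # torch.Size([2, 8])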
FantasyTalking/infer.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright Alibaba Inc. All Rights Reserved.
2
+
3
+ import argparse
4
+ import os
5
+ import subprocess
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+
9
+ import cv2
10
+ import librosa
11
+ import torch
12
+ from PIL import Image
13
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
14
+
15
+ from diffsynth import ModelManager, WanVideoPipeline
16
+ from model import FantasyTalkingAudioConditionModel
17
+ from utils import get_audio_features, resize_image_by_longest_edge, save_video
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser(description="FantasyTalking inference script.")
22
+
23
+ parser.add_argument(
24
+ "--wan_model_dir",
25
+ type=str,
26
+ default="./models/Wan2.1-I2V-14B-720P",
27
+ required=False,
28
+ help="The dir of the Wan I2V 14B model.",
29
+ )
30
+ parser.add_argument(
31
+ "--fantasytalking_model_path",
32
+ type=str,
33
+ default="./models/fantasytalking_model.ckpt",
34
+ required=False,
35
+ help="The .ckpt path of fantasytalking model.",
36
+ )
37
+ parser.add_argument(
38
+ "--wav2vec_model_dir",
39
+ type=str,
40
+ default="./models/wav2vec2-base-960h",
41
+ required=False,
42
+ help="The dir of wav2vec model.",
43
+ )
44
+
45
+ parser.add_argument(
46
+ "--image_path",
47
+ type=str,
48
+ default="./assets/images/woman.png",
49
+ required=False,
50
+ help="The path of the image.",
51
+ )
52
+
53
+ parser.add_argument(
54
+ "--audio_path",
55
+ type=str,
56
+ default="./assets/audios/woman.wav",
57
+ required=False,
58
+ help="The path of the audio.",
59
+ )
60
+ parser.add_argument(
61
+ "--prompt",
62
+ type=str,
63
+ default="A woman is talking.",
64
+ required=False,
65
+ help="prompt.",
66
+ )
67
+ parser.add_argument(
68
+ "--output_dir",
69
+ type=str,
70
+ default="./output",
71
+ help="Dir to save the model.",
72
+ )
73
+ parser.add_argument(
74
+ "--image_size",
75
+ type=int,
76
+ default=512,
77
+ help="The image will be resized proportionally to this size.",
78
+ )
79
+ parser.add_argument(
80
+ "--audio_scale",
81
+ type=float,
82
+ default=1.0,
83
+ help="Audio condition injection weight",
84
+ )
85
+ parser.add_argument(
86
+ "--prompt_cfg_scale",
87
+ type=float,
88
+ default=5.0,
89
+ required=False,
90
+ help="Prompt cfg scale",
91
+ )
92
+ parser.add_argument(
93
+ "--audio_cfg_scale",
94
+ type=float,
95
+ default=5.0,
96
+ required=False,
97
+ help="Audio cfg scale",
98
+ )
99
+ parser.add_argument(
100
+ "--max_num_frames",
101
+ type=int,
102
+ default=81,
103
+ required=False,
104
+ help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
105
+ )
106
+ parser.add_argument(
107
+ "--fps",
108
+ type=int,
109
+ default=23,
110
+ required=False,
111
+ )
112
+ parser.add_argument(
113
+ "--num_persistent_param_in_dit",
114
+ type=int,
115
+ default=None,
116
+ required=False,
117
+ help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required",
118
+ )
119
+ parser.add_argument(
120
+ "--seed",
121
+ type=int,
122
+ default=1111,
123
+ required=False,
124
+ )
125
+ args = parser.parse_args()
126
+ return args
127
+
128
+
129
+ def load_models(args):
130
+ # Load Wan I2V models
131
+ model_manager = ModelManager(device="cpu")
132
+ model_manager.load_models(
133
+ [
134
+ [
135
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors",
136
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors",
137
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors",
138
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors",
139
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors",
140
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors",
141
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors",
142
+ ],
143
+ f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
144
+ f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
145
+ f"{args.wan_model_dir}/Wan2.1_VAE.pth",
146
+ ],
147
+ # torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
148
+ torch_dtype=torch.bfloat16, # bfloat16 disables FP8 quantization; switch to the commented float8_e4m3fn dtype above to reduce VRAM.
149
+ )
150
+ pipe = WanVideoPipeline.from_model_manager(
151
+ model_manager, torch_dtype=torch.bfloat16, device="cuda"
152
+ )
153
+
154
+ # Load FantasyTalking weights
155
+ fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
156
+ fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
157
+
158
+ # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
159
+ pipe.enable_vram_management(
160
+ num_persistent_param_in_dit=args.num_persistent_param_in_dit
161
+ )
162
+
163
+ # Load wav2vec models
164
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
165
+ wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
166
+
167
+ return pipe, fantasytalking, wav2vec_processor, wav2vec
168
+
169
+
170
+ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
171
+ os.makedirs(args.output_dir, exist_ok=True)
172
+
173
+ duration = librosa.get_duration(path=args.audio_path) # the "filename" keyword was removed in librosa >= 0.10
174
+ num_frames = min(int(args.fps * duration // 4) * 4 + 5, args.max_num_frames)
175
+
176
+ audio_wav2vec_fea = get_audio_features(
177
+ wav2vec, wav2vec_processor, args.audio_path, args.fps, num_frames
178
+ )
179
+ image = resize_image_by_longest_edge(args.image_path, args.image_size)
180
+ width, height = image.size
181
+
182
+ audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea)
183
+ pos_idx_ranges = fantasytalking.split_audio_sequence(
184
+ audio_proj_fea.size(1), num_frames=num_frames
185
+ )
186
+ audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(
187
+ audio_proj_fea, pos_idx_ranges, expand_length=4
188
+ ) # [b,21,9+8,768]
189
+
190
+ # Image-to-video
191
+ video_audio = pipe(
192
+ prompt=args.prompt,
193
+ negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
194
+ input_image=image,
195
+ width=width,
196
+ height=height,
197
+ num_frames=num_frames,
198
+ num_inference_steps=30,
199
+ seed=args.seed,
200
+ tiled=True,
201
+ audio_scale=args.audio_scale,
202
+ cfg_scale=args.prompt_cfg_scale,
203
+ audio_cfg_scale=args.audio_cfg_scale,
204
+ audio_proj=audio_proj_split,
205
+ audio_context_lens=audio_context_lens,
206
+ latents_num_frames=(num_frames - 1) // 4 + 1,
207
+ )
208
+ current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
209
+ save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
210
+ save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)
211
+
212
+ save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
213
+ final_command = [
214
+ "ffmpeg",
215
+ "-y",
216
+ "-i",
217
+ save_path_tmp,
218
+ "-i",
219
+ args.audio_path,
220
+ "-c:v",
221
+ "libx264",
222
+ "-c:a",
223
+ "aac",
224
+ "-shortest",
225
+ save_path,
226
+ ]
227
+ subprocess.run(final_command, check=True)
228
+ os.remove(save_path_tmp)
229
+ return save_path
230
+
231
+
232
+ if __name__ == "__main__":
233
+ args = parse_args()
234
+ pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
235
+
236
+ main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
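
A worked example of the frame-count arithmetic in main() above: the frame count is clamped to max_num_frames and always satisfies num_frames % 4 == 1, which is what WanVideoPipeline expects:

fps, max_num_frames = 23, 81
for duration in (2.0, 5.0):                      # audio length in seconds
    num_frames = min(int(fps * duration // 4) * 4 + 5, max_num_frames)
    print(duration, num_frames, num_frames % 4)  # 2.0 -> 49 (49 % 4 == 1), 5.0 -> 81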
FantasyTalking/infer.sh ADDED
@@ -0,0 +1,11 @@
+ python infer.py \
+ --image_path ./assets/images/woman.png \
+ --audio_path ./assets/audios/woman.wav \
+ --prompt "A woman is talking." \
+ --max_num_frames 81 \
+ --image_size 512 \
+ --audio_scale 1.0 \
+ --prompt_cfg_scale 5.0 \
+ --audio_cfg_scale 5.0 \
+ --fps 23 \
+ --seed 1111
FantasyTalking/infer_24G.sh ADDED
@@ -0,0 +1,12 @@
+ CUDA_VISIBLE_DEVICES=2 python infer.py \
+ --image_path ./assets/images/woman.png \
+ --audio_path ./assets/audios/woman.wav \
+ --prompt "A woman is talking." \
+ --max_num_frames 81 \
+ --image_size 512 \
+ --audio_scale 1.0 \
+ --prompt_cfg_scale 5.0 \
+ --audio_cfg_scale 5.0 \
+ --fps 23 \
+ --num_persistent_param_in_dit 7000000000 \
+ --seed 1111
FantasyTalking/model.py ADDED
@@ -0,0 +1,228 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from safetensors import safe_open
7
+
8
+ from diffsynth.models.wan_video_dit import WanModel, flash_attention, attention
9
+
10
+
11
+ class AudioProjModel(nn.Module):
12
+ def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
13
+ super().__init__()
14
+ self.cross_attention_dim = cross_attention_dim
15
+ self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
16
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
17
+
18
+ def forward(self, audio_embeds):
19
+ context_tokens = self.proj(audio_embeds)
20
+ context_tokens = self.norm(context_tokens)
21
+ return context_tokens # [B,L,C]
22
+
23
+
24
+ class WanCrossAttentionProcessor(nn.Module):
25
+ def __init__(self, context_dim, hidden_dim):
26
+ super().__init__()
27
+
28
+ self.context_dim = context_dim
29
+ self.hidden_dim = hidden_dim
30
+
31
+ self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
32
+ self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
33
+
34
+ nn.init.zeros_(self.k_proj.weight)
35
+ nn.init.zeros_(self.v_proj.weight)
36
+
37
+ def __call__(
38
+ self,
39
+ attn: nn.Module,
40
+ x: torch.Tensor,
41
+ context: torch.Tensor,
42
+ context_lens: torch.Tensor,
43
+ audio_proj: torch.Tensor,
44
+ audio_context_lens: torch.Tensor,
45
+ latents_num_frames: int = 21,
46
+ audio_scale: float = 1.0,
47
+ ) -> torch.Tensor:
48
+ """
49
+ x: [B, L1, C].
50
+ context: [B, L2, C].
51
+ context_lens: [B].
52
+ audio_proj: [B, 21, L3, C]
53
+ audio_context_lens: [B*21].
54
+ """
55
+ context_img = context[:, :257]
56
+ context = context[:, 257:]
57
+ b, n, d = x.size(0), attn.num_heads, attn.head_dim
58
+
59
+ # compute query, key, value
60
+ q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
61
+ k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
62
+ v = attn.v(context).view(b, -1, n, d)
63
+ k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
64
+ v_img = attn.v_img(context_img).view(b, -1, n, d)
65
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
66
+ # compute attention
67
+ x = flash_attention(q, k, v, k_lens=context_lens)
68
+ x = x.flatten(2)
69
+ img_x = img_x.flatten(2)
70
+
71
+ if len(audio_proj.shape) == 4:
72
+ audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
73
+ ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
74
+ ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
75
+ audio_x = attention(
76
+ audio_q, ip_key, ip_value, k_lens=audio_context_lens
77
+ )
78
+ audio_x = audio_x.view(b, q.size(1), n, d)
79
+ audio_x = audio_x.flatten(2)
80
+ elif len(audio_proj.shape) == 3:
81
+ ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
82
+ ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
83
+ audio_x = attention(q, ip_key, ip_value, k_lens=audio_context_lens)
84
+ audio_x = audio_x.flatten(2)
85
+ # output
86
+ x = x + img_x + audio_x * audio_scale
87
+ x = attn.o(x)
88
+ return x
89
+
90
+
91
+ class FantasyTalkingAudioConditionModel(nn.Module):
92
+ def __init__(self, wan_dit: WanModel, audio_in_dim: int, audio_proj_dim: int):
93
+ super().__init__()
94
+
95
+ self.audio_in_dim = audio_in_dim
96
+ self.audio_proj_dim = audio_proj_dim
97
+
98
+ # audio proj model
99
+ self.proj_model = self.init_proj(self.audio_proj_dim)
100
+ self.set_audio_processor(wan_dit)
101
+
102
+ def init_proj(self, cross_attention_dim=5120):
103
+ proj_model = AudioProjModel(
104
+ audio_in_dim=self.audio_in_dim, cross_attention_dim=cross_attention_dim
105
+ )
106
+ return proj_model
107
+
108
+ def set_audio_processor(self, wan_dit):
109
+ attn_procs = {}
110
+ for name in wan_dit.attn_processors.keys():
111
+ attn_procs[name] = WanCrossAttentionProcessor(
112
+ context_dim=self.audio_proj_dim, hidden_dim=wan_dit.dim
113
+ )
114
+ wan_dit.set_attn_processor(attn_procs)
115
+
116
+ def load_audio_processor(self, ip_ckpt: str, wan_dit):
117
+ if os.path.splitext(ip_ckpt)[-1] == ".safetensors":
118
+ state_dict = {"proj_model": {}, "audio_processor": {}}
119
+ with safe_open(ip_ckpt, framework="pt", device="cpu") as f:
120
+ for key in f.keys():
121
+ if key.startswith("proj_model."):
122
+ state_dict["proj_model"][
123
+ key.replace("proj_model.", "")
124
+ ] = f.get_tensor(key)
125
+ elif key.startswith("audio_processor."):
126
+ state_dict["audio_processor"][
127
+ key.replace("audio_processor.", "")
128
+ ] = f.get_tensor(key)
129
+ else:
130
+ state_dict = torch.load(ip_ckpt, map_location="cpu")
131
+ self.proj_model.load_state_dict(state_dict["proj_model"])
132
+ wan_dit.load_state_dict(state_dict["audio_processor"], strict=False)
133
+
134
+ def get_proj_fea(self, audio_fea=None):
135
+ return self.proj_model(audio_fea) if audio_fea is not None else None
136
+
137
+ def split_audio_sequence(self, audio_proj_length, num_frames=81):
138
+ """
139
+ Map the audio feature sequence to corresponding latent frame slices.
140
+
141
+ Args:
142
+ audio_proj_length (int): The total length of the audio feature sequence
143
+ (e.g., 173 in audio_proj[1, 173, 768]).
144
+ num_frames (int): The number of video frames in the training data (default: 81).
145
+
146
+ Returns:
147
+ list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
148
+ (within the audio feature sequence) corresponding to a latent frame.
149
+ """
150
+ # Average number of tokens per original video frame
151
+ tokens_per_frame = audio_proj_length / num_frames
152
+
153
+ # Each latent frame covers 4 video frames, and we want the center
154
+ tokens_per_latent_frame = tokens_per_frame * 4
155
+ half_tokens = int(tokens_per_latent_frame / 2)
156
+
157
+ pos_indices = []
158
+ for i in range(int((num_frames - 1) / 4) + 1):
159
+ if i == 0:
160
+ pos_indices.append(0)
161
+ else:
162
+ start_token = tokens_per_frame * ((i - 1) * 4 + 1)
163
+ end_token = tokens_per_frame * (i * 4 + 1)
164
+ center_token = int((start_token + end_token) / 2) - 1
165
+ pos_indices.append(center_token)
166
+
167
+ # Build index ranges centered around each position
168
+ pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
169
+
170
+ # Adjust the first range to avoid negative start index
171
+ pos_idx_ranges[0] = [
172
+ -(half_tokens * 2 - pos_idx_ranges[1][0]),
173
+ pos_idx_ranges[1][0],
174
+ ]
175
+
176
+ return pos_idx_ranges
177
+
178
+ def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
179
+ """
180
+ Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
181
+ if the range exceeds the input boundaries.
182
+
183
+ Args:
184
+ input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
185
+ pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
186
+ expand_length (int): Number of tokens to expand on both sides of each subsequence.
187
+
188
+ Returns:
189
+ sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
190
+ Each element is a padded subsequence.
191
+ k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
192
+ Useful for ignoring padding tokens in attention masks.
193
+ """
194
+ pos_idx_ranges = [
195
+ [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
196
+ ]
197
+ sub_sequences = []
198
+ seq_len = input_tensor.size(1) # 173
199
+ max_valid_idx = seq_len - 1 # 172
200
+ k_lens_list = []
201
+ for start, end in pos_idx_ranges:
202
+ # Calculate the padding needed on each side
203
+ pad_front = max(-start, 0)
204
+ pad_back = max(end - max_valid_idx, 0)
205
+
206
+ # Calculate the start and end indices of the valid part
207
+ valid_start = max(start, 0)
208
+ valid_end = min(end, max_valid_idx)
209
+
210
+ # Extract the valid part
211
+ if valid_start <= valid_end:
212
+ valid_part = input_tensor[:, valid_start : valid_end + 1, :]
213
+ else:
214
+ valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
215
+
216
+ # Pad along the sequence dimension (dim 1)
217
+ padded_subseq = F.pad(
218
+ valid_part,
219
+ (0, 0, 0, pad_back + pad_front, 0, 0),
220
+ mode="constant",
221
+ value=0,
222
+ )
223
+ k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
224
+
225
+ sub_sequences.append(padded_subseq)
226
+ return torch.stack(sub_sequences, dim=1), torch.tensor(
227
+ k_lens_list, dtype=torch.long
228
+ )
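
A standalone re-derivation of the index arithmetic in split_audio_sequence above, using the docstring's example of a 173-token audio sequence and 81 video frames; it mirrors the method rather than calling it, so no model weights are needed:

audio_proj_length, num_frames = 173, 81
tokens_per_frame = audio_proj_length / num_frames      # ~2.14 audio tokens per video frame
half_tokens = int(tokens_per_frame * 4 / 2)            # each latent frame covers 4 video frames

pos_indices = [0]
for i in range(1, (num_frames - 1) // 4 + 1):
    start = tokens_per_frame * ((i - 1) * 4 + 1)
    end = tokens_per_frame * (i * 4 + 1)
    pos_indices.append(int((start + end) / 2) - 1)

ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
ranges[0] = [-(half_tokens * 2 - ranges[1][0]), ranges[1][0]]
print(len(ranges), ranges[:2])  # 21 ranges; the first is [-7, 1] and its negative part gets zero-padded later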
FantasyTalking/requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch>=2.0.0
+ torchvision
+ cupy-cuda12x
+ transformers==4.46.2
+ controlnet-aux==0.0.7
+ imageio
+ imageio[ffmpeg]
+ safetensors
+ einops
+ sentencepiece
+ protobuf
+ modelscope
+ ftfy
+ librosa
FantasyTalking/utils.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright Alibaba Inc. All Rights Reserved.
+
+ import imageio
+ import librosa
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ def resize_image_by_longest_edge(image_path, target_size):
+     image = Image.open(image_path).convert("RGB")
+     width, height = image.size
+     scale = target_size / max(width, height)
+     new_size = (int(width * scale), int(height * scale))
+     return image.resize(new_size, Image.LANCZOS)
+
+
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
+     writer = imageio.get_writer(
+         save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
+     )
+     for frame in tqdm(frames, desc="Saving video"):
+         frame = np.array(frame)
+         writer.append_data(frame)
+     writer.close()
+
+
+ def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
+     sr = 16000
+     audio_input, sample_rate = librosa.load(audio_path, sr=sr)  # resample the audio to 16 kHz
+
+     start_time = 0
+     # end_time = (0 + (num_frames - 1) * 1) / fps
+     end_time = num_frames / fps
+
+     start_sample = int(start_time * sr)
+     end_sample = int(end_time * sr)
+
+     try:
+         audio_segment = audio_input[start_sample:end_sample]
+     except Exception:  # fall back to the full clip if slicing fails
+         audio_segment = audio_input
+
+     input_values = audio_processor(
+         audio_segment, sampling_rate=sample_rate, return_tensors="pt"
+     ).input_values.to("cuda")
+
+     with torch.no_grad():
+         fea = wav2vec(input_values).last_hidden_state
+
+     return fea