Scorpion06 committed on
Commit 56c3ab7 · verified · 1 parent: 31ab001

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set of changes.
Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +177 -0
  3. LICENSE +407 -0
  4. README.md +180 -12
  5. defaults.ini +56 -0
  6. example/V2A_sample-1.mp4 +3 -0
  7. example/V2A_sample-2.mp4 +3 -0
  8. example/V2A_sample-3.mp4 +3 -0
  9. example/V2M_sample-1.mp4 +3 -0
  10. example/V2M_sample-2.mp4 +3 -0
  11. example/V2M_sample-3.mp4 +3 -0
  12. pyproject.toml +3 -0
  13. run.bat +8 -0
  14. run_gradio.py +32 -0
  15. setup.py +46 -0
  16. stable_audio_tools/__init__.py +2 -0
  17. stable_audio_tools/data/__init__.py +0 -0
  18. stable_audio_tools/data/dataset.py +876 -0
  19. stable_audio_tools/data/utils.py +199 -0
  20. stable_audio_tools/inference/__init__.py +0 -0
  21. stable_audio_tools/inference/generation.py +275 -0
  22. stable_audio_tools/inference/sampling.py +235 -0
  23. stable_audio_tools/inference/utils.py +35 -0
  24. stable_audio_tools/interface/__init__.py +0 -0
  25. stable_audio_tools/interface/gradio.py +495 -0
  26. stable_audio_tools/models/__init__.py +1 -0
  27. stable_audio_tools/models/adp.py +1588 -0
  28. stable_audio_tools/models/autoencoders.py +794 -0
  29. stable_audio_tools/models/blocks.py +339 -0
  30. stable_audio_tools/models/bottleneck.py +355 -0
  31. stable_audio_tools/models/codebook_patterns.py +545 -0
  32. stable_audio_tools/models/conditioners.py +710 -0
  33. stable_audio_tools/models/diffusion.py +704 -0
  34. stable_audio_tools/models/discriminators.py +546 -0
  35. stable_audio_tools/models/dit.py +379 -0
  36. stable_audio_tools/models/factory.py +153 -0
  37. stable_audio_tools/models/lm.py +542 -0
  38. stable_audio_tools/models/local_attention.py +278 -0
  39. stable_audio_tools/models/pqmf.py +393 -0
  40. stable_audio_tools/models/pretrained.py +25 -0
  41. stable_audio_tools/models/pretransforms.py +258 -0
  42. stable_audio_tools/models/temptransformer.py +190 -0
  43. stable_audio_tools/models/transformer.py +812 -0
  44. stable_audio_tools/models/utils.py +92 -0
  45. stable_audio_tools/models/wavelets.py +82 -0
  46. stable_audio_tools/training/__init__.py +1 -0
  47. stable_audio_tools/training/autoencoders.py +476 -0
  48. stable_audio_tools/training/diffusion.py +1656 -0
  49. stable_audio_tools/training/factory.py +240 -0
  50. stable_audio_tools/training/lm.py +267 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example/V2A_sample-1.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ example/V2A_sample-2.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ example/V2A_sample-3.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ example/V2M_sample-1.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ example/V2M_sample-2.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ example/V2M_sample-3.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,177 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+
153
+
154
+ # PyCharm
155
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
156
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
157
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
158
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
159
+ #.idea/
160
+
161
+ *.ckpt
162
+ *.wav
163
+ # *.mp4
164
+ *.mp3
165
+ *.jsonl
166
+ wandb/*
167
+
168
+
169
+
170
+
171
+ model/
172
+ logs/
173
+ log/
174
+ saved_ckpt/
175
+ wandb/
176
+ demo_result/
177
+ model/
LICENSE ADDED
@@ -0,0 +1,407 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+
142
+ Section 2 -- Scope.
143
+
144
+ a. License grant.
145
+
146
+ 1. Subject to the terms and conditions of this Public License,
147
+ the Licensor hereby grants You a worldwide, royalty-free,
148
+ non-sublicensable, non-exclusive, irrevocable license to
149
+ exercise the Licensed Rights in the Licensed Material to:
150
+
151
+ a. reproduce and Share the Licensed Material, in whole or
152
+ in part, for NonCommercial purposes only; and
153
+
154
+ b. produce, reproduce, and Share Adapted Material for
155
+ NonCommercial purposes only.
156
+
157
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
158
+ Exceptions and Limitations apply to Your use, this Public
159
+ License does not apply, and You do not need to comply with
160
+ its terms and conditions.
161
+
162
+ 3. Term. The term of this Public License is specified in Section
163
+ 6(a).
164
+
165
+ 4. Media and formats; technical modifications allowed. The
166
+ Licensor authorizes You to exercise the Licensed Rights in
167
+ all media and formats whether now known or hereafter created,
168
+ and to make technical modifications necessary to do so. The
169
+ Licensor waives and/or agrees not to assert any right or
170
+ authority to forbid You from making technical modifications
171
+ necessary to exercise the Licensed Rights, including
172
+ technical modifications necessary to circumvent Effective
173
+ Technological Measures. For purposes of this Public License,
174
+ simply making modifications authorized by this Section 2(a)
175
+ (4) never produces Adapted Material.
176
+
177
+ 5. Downstream recipients.
178
+
179
+ a. Offer from the Licensor -- Licensed Material. Every
180
+ recipient of the Licensed Material automatically
181
+ receives an offer from the Licensor to exercise the
182
+ Licensed Rights under the terms and conditions of this
183
+ Public License.
184
+
185
+ b. No downstream restrictions. You may not offer or impose
186
+ any additional or different terms or conditions on, or
187
+ apply any Effective Technological Measures to, the
188
+ Licensed Material if doing so restricts exercise of the
189
+ Licensed Rights by any recipient of the Licensed
190
+ Material.
191
+
192
+ 6. No endorsement. Nothing in this Public License constitutes or
193
+ may be construed as permission to assert or imply that You
194
+ are, or that Your use of the Licensed Material is, connected
195
+ with, or sponsored, endorsed, or granted official status by,
196
+ the Licensor or others designated to receive attribution as
197
+ provided in Section 3(a)(1)(A)(i).
198
+
199
+ b. Other rights.
200
+
201
+ 1. Moral rights, such as the right of integrity, are not
202
+ licensed under this Public License, nor are publicity,
203
+ privacy, and/or other similar personality rights; however, to
204
+ the extent possible, the Licensor waives and/or agrees not to
205
+ assert any such rights held by the Licensor to the limited
206
+ extent necessary to allow You to exercise the Licensed
207
+ Rights, but not otherwise.
208
+
209
+ 2. Patent and trademark rights are not licensed under this
210
+ Public License.
211
+
212
+ 3. To the extent possible, the Licensor waives any right to
213
+ collect royalties from You for the exercise of the Licensed
214
+ Rights, whether directly or through a collecting society
215
+ under any voluntary or waivable statutory or compulsory
216
+ licensing scheme. In all other cases the Licensor expressly
217
+ reserves any right to collect such royalties, including when
218
+ the Licensed Material is used other than for NonCommercial
219
+ purposes.
220
+
221
+
222
+ Section 3 -- License Conditions.
223
+
224
+ Your exercise of the Licensed Rights is expressly made subject to the
225
+ following conditions.
226
+
227
+ a. Attribution.
228
+
229
+ 1. If You Share the Licensed Material (including in modified
230
+ form), You must:
231
+
232
+ a. retain the following if it is supplied by the Licensor
233
+ with the Licensed Material:
234
+
235
+ i. identification of the creator(s) of the Licensed
236
+ Material and any others designated to receive
237
+ attribution, in any reasonable manner requested by
238
+ the Licensor (including by pseudonym if
239
+ designated);
240
+
241
+ ii. a copyright notice;
242
+
243
+ iii. a notice that refers to this Public License;
244
+
245
+ iv. a notice that refers to the disclaimer of
246
+ warranties;
247
+
248
+ v. a URI or hyperlink to the Licensed Material to the
249
+ extent reasonably practicable;
250
+
251
+ b. indicate if You modified the Licensed Material and
252
+ retain an indication of any previous modifications; and
253
+
254
+ c. indicate the Licensed Material is licensed under this
255
+ Public License, and include the text of, or the URI or
256
+ hyperlink to, this Public License.
257
+
258
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
259
+ reasonable manner based on the medium, means, and context in
260
+ which You Share the Licensed Material. For example, it may be
261
+ reasonable to satisfy the conditions by providing a URI or
262
+ hyperlink to a resource that includes the required
263
+ information.
264
+
265
+ 3. If requested by the Licensor, You must remove any of the
266
+ information required by Section 3(a)(1)(A) to the extent
267
+ reasonably practicable.
268
+
269
+ 4. If You Share Adapted Material You produce, the Adapter's
270
+ License You apply must not prevent recipients of the Adapted
271
+ Material from complying with this Public License.
272
+
273
+
274
+ Section 4 -- Sui Generis Database Rights.
275
+
276
+ Where the Licensed Rights include Sui Generis Database Rights that
277
+ apply to Your use of the Licensed Material:
278
+
279
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280
+ to extract, reuse, reproduce, and Share all or a substantial
281
+ portion of the contents of the database for NonCommercial purposes
282
+ only;
283
+
284
+ b. if You include all or a substantial portion of the database
285
+ contents in a database in which You have Sui Generis Database
286
+ Rights, then the database in which You have Sui Generis Database
287
+ Rights (but not its individual contents) is Adapted Material; and
288
+
289
+ c. You must comply with the conditions in Section 3(a) if You Share
290
+ all or a substantial portion of the contents of the database.
291
+
292
+ For the avoidance of doubt, this Section 4 supplements and does not
293
+ replace Your obligations under this Public License where the Licensed
294
+ Rights include other Copyright and Similar Rights.
295
+
296
+
297
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298
+
299
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309
+
310
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319
+
320
+ c. The disclaimer of warranties and limitation of liability provided
321
+ above shall be interpreted in a manner that, to the extent
322
+ possible, most closely approximates an absolute disclaimer and
323
+ waiver of all liability.
324
+
325
+
326
+ Section 6 -- Term and Termination.
327
+
328
+ a. This Public License applies for the term of the Copyright and
329
+ Similar Rights licensed here. However, if You fail to comply with
330
+ this Public License, then Your rights under this Public License
331
+ terminate automatically.
332
+
333
+ b. Where Your right to use the Licensed Material has terminated under
334
+ Section 6(a), it reinstates:
335
+
336
+ 1. automatically as of the date the violation is cured, provided
337
+ it is cured within 30 days of Your discovery of the
338
+ violation; or
339
+
340
+ 2. upon express reinstatement by the Licensor.
341
+
342
+ For the avoidance of doubt, this Section 6(b) does not affect any
343
+ right the Licensor may have to seek remedies for Your violations
344
+ of this Public License.
345
+
346
+ c. For the avoidance of doubt, the Licensor may also offer the
347
+ Licensed Material under separate terms or conditions or stop
348
+ distributing the Licensed Material at any time; however, doing so
349
+ will not terminate this Public License.
350
+
351
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352
+ License.
353
+
354
+
355
+ Section 7 -- Other Terms and Conditions.
356
+
357
+ a. The Licensor shall not be bound by any additional or different
358
+ terms or conditions communicated by You unless expressly agreed.
359
+
360
+ b. Any arrangements, understandings, or agreements regarding the
361
+ Licensed Material not stated herein are separate from and
362
+ independent of the terms and conditions of this Public License.
363
+
364
+
365
+ Section 8 -- Interpretation.
366
+
367
+ a. For the avoidance of doubt, this Public License does not, and
368
+ shall not be interpreted to, reduce, limit, restrict, or impose
369
+ conditions on any use of the Licensed Material that could lawfully
370
+ be made without permission under this Public License.
371
+
372
+ b. To the extent possible, if any provision of this Public License is
373
+ deemed unenforceable, it shall be automatically reformed to the
374
+ minimum extent necessary to make it enforceable. If the provision
375
+ cannot be reformed, it shall be severed from this Public License
376
+ without affecting the enforceability of the remaining terms and
377
+ conditions.
378
+
379
+ c. No term or condition of this Public License will be waived and no
380
+ failure to comply consented to unless expressly agreed to by the
381
+ Licensor.
382
+
383
+ d. Nothing in this Public License constitutes or may be interpreted
384
+ as a limitation upon, or waiver of, any privileges and immunities
385
+ that apply to the Licensor or You, including from the legal
386
+ processes of any jurisdiction or authority.
387
+
388
+ =======================================================================
389
+
390
+ Creative Commons is not a party to its public
391
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
392
+ its public licenses to material it publishes and in those instances
393
+ will be considered the “Licensor.” The text of the Creative Commons
394
+ public licenses is dedicated to the public domain under the CC0 Public
395
+ Domain Dedication. Except for the limited purpose of indicating that
396
+ material is shared under a Creative Commons public license or as
397
+ otherwise permitted by the Creative Commons policies published at
398
+ creativecommons.org/policies, Creative Commons does not authorize the
399
+ use of the trademark "Creative Commons" or any other trademark or logo
400
+ of Creative Commons without its prior written consent including,
401
+ without limitation, in connection with any unauthorized modifications
402
+ to any of its public licenses or any other arrangements,
403
+ understandings, or agreements concerning use of licensed material. For
404
+ the avoidance of doubt, this paragraph does not form part of the
405
+ public licenses.
406
+
407
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,12 +1,180 @@
1
- ---
2
- title: AudioX Viet
3
- emoji: 🐢
4
- colorFrom: pink
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.29.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ title: AudioX-Viet
3
+ app_file: run_gradio.py
4
+ sdk: gradio
5
+ sdk_version: 4.44.1
6
+ ---
7
+ # 🎧 AudioX: Diffusion Transformer for Anything-to-Audio Generation
8
+
9
+ [![arXiv](https://img.shields.io/badge/arXiv-2503.10522-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2503.10522)
10
+ [![Project Page](https://img.shields.io/badge/GitHub.io-Project-blue?logo=Github&style=flat-square)](https://zeyuet.github.io/AudioX/)
11
+ [![🤗 Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/HKUSTAudio/AudioX)
12
+ [![🤗 Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue)](https://huggingface.co/spaces/Zeyue7/AudioX)
13
+
14
+ ---
15
+
16
+ **This is the official repository for "[AudioX: Diffusion Transformer for Anything-to-Audio Generation](https://arxiv.org/pdf/2503.10522)".**
17
+
18
+
19
+ ## 📺 Demo Video
20
+
21
+ https://github.com/user-attachments/assets/0d8dd927-ff0f-4b35-ab1f-b3c3915017be
22
+
23
+ ---
24
+
25
+
26
+ ## ✨ Abstract
27
+
28
+ Audio and music generation have emerged as crucial tasks in many applications, yet existing approaches face significant limitations: they operate in isolation without unified capabilities across modalities, suffer from scarce high-quality, multi-modal training data, and struggle to effectively integrate diverse inputs. In this work, we propose AudioX, a unified Diffusion Transformer model for Anything-to-Audio and Music Generation. Unlike previous domain-specific models, AudioX can generate both general audio and music with high quality, while offering flexible natural language control and seamless processing of various modalities including text, video, image, music, and audio. Its key innovation is a multi-modal masked training strategy that masks inputs across modalities and forces the model to learn from masked inputs, yielding robust and unified cross-modal representations. To address data scarcity, we curate two comprehensive datasets: vggsound-caps with 190K audio captions based on the VGGSound dataset, and V2M-caps with 6 million music captions derived from the V2M dataset. Extensive experiments demonstrate that AudioX not only matches or outperforms state-of-the-art specialized models, but also offers remarkable versatility in handling diverse input modalities and generation tasks within a unified architecture.
29
+
30
+
31
+ ## ✨ Teaser
32
+
33
+ <p align="center">
34
+ <img src="https://github.com/user-attachments/assets/ea723225-f9c8-4ca2-8837-2c2c08189bdd" alt="method">
35
+ </p>
36
+ <p style="text-align: left;">(a) Overview of AudioX, illustrating its capabilities across various tasks. (b) Radar chart comparing the performance of different methods across multiple benchmarks. AudioX demonstrates superior Inception Scores (IS) across a diverse set of datasets in audio and music generation tasks.</p>
37
+
38
+
39
+ ## ✨ Method
40
+
41
+ <p align="center">
42
+ <img src="https://github.com/user-attachments/assets/94ea3df0-8c66-4259-b681-791ee41bada8" alt="method">
43
+ </p>
44
+ <p align="center">Overview of the AudioX Framework.</p>
45
+
46
+
47
+
48
+ ## Code
49
+
50
+
51
+ ### 🛠️ Environment Setup
52
+
53
+ ```bash
54
+ git clone https://github.com/ZeyueT/AudioX.git
55
+ cd AudioX
56
+ conda create -n AudioX python=3.8.20
57
+ conda activate AudioX
58
+ pip install git+https://github.com/ZeyueT/AudioX.git
59
+ conda install -c conda-forge ffmpeg libsndfile
60
+
61
+ ```
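A quick way to confirm the environment resolved correctly is to import the core runtime packages. This is a hedged addition (not part of the committed README); it only assumes the `torch`/`torchaudio` installs from the step above.

```python
# Minimal sanity check after the install above.
import torch
import torchaudio

print("torch:", torch.__version__, "| torchaudio:", torchaudio.__version__)
print("CUDA available:", torch.cuda.is_available())
```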
62
+
63
+ ## 🪄 Pretrained Checkpoints
64
+
65
+ Download the pretrained model from 🤗 [AudioX on Hugging Face](https://huggingface.co/HKUSTAudio/AudioX):
66
+
67
+ ```bash
68
+ mkdir -p model
69
+ wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/model.ckpt -O model/model.ckpt
70
+ wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/config.json -O model/config.json
71
+ ```
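If `wget` is unavailable, the same files can be fetched with the `huggingface_hub` client that the package already depends on. This is a hedged alternative, not the repository's documented path; the repo id and filenames mirror the URLs above.

```python
# Hedged alternative to the wget commands above, using huggingface_hub.
from huggingface_hub import hf_hub_download

for filename in ("model.ckpt", "config.json"):
    hf_hub_download(repo_id="HKUSTAudio/AudioX", filename=filename, local_dir="model")
```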
72
+
73
+ ### 🤗 Gradio Demo
74
+
75
+ To launch the Gradio demo locally, run:
76
+
77
+ ```bash
78
+ python3 run_gradio.py \
79
+ --model-config model/config.json \
80
+ --share
81
+ ```
82
+
83
+
84
+ ### 🎯 Prompt Configuration Examples
85
+
86
+ | Task | `video_path` | `text_prompt` | `audio_path` |
87
+ |:---------------------|:-------------------|:----------------------------------------------|:-------------|
88
+ | Text-to-Audio (T2A) | `None` | `"Typing on a keyboard"` | `None` |
89
+ | Text-to-Music (T2M) | `None` | `"A music with piano and violin"` | `None` |
90
+ | Video-to-Audio (V2A) | `"video_path.mp4"` | `"Generate general audio for the video"` | `None` |
91
+ | Video-to-Music (V2M) | `"video_path.mp4"` | `"Generate music for the video"` | `None` |
92
+ | TV-to-Audio (TV2A) | `"video_path.mp4"` | `"Ocean waves crashing with people laughing"` | `None` |
93
+ | TV-to-Music (TV2M) | `"video_path.mp4"` | `"Generate music with piano instrument"` | `None` |
94
+
95
+ ### 🖥️ Script Inference
96
+
97
+ ```python
98
+ import torch
99
+ import torchaudio
100
+ from einops import rearrange
101
+ from stable_audio_tools import get_pretrained_model
102
+ from stable_audio_tools.inference.generation import generate_diffusion_cond
103
+ from stable_audio_tools.data.utils import read_video, merge_video_audio
104
+ from stable_audio_tools.data.utils import load_and_process_audio
105
+ import os
106
+
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+
109
+ # Download model
110
+ model, model_config = get_pretrained_model("HKUSTAudio/AudioX")
111
+ sample_rate = model_config["sample_rate"]
112
+ sample_size = model_config["sample_size"]
113
+ target_fps = model_config["video_fps"]
114
+ seconds_start = 0
115
+ seconds_total = 10
116
+
117
+ model = model.to(device)
118
+
119
+ # for video-to-music generation
120
+ video_path = "example/V2M_sample-1.mp4"
121
+ text_prompt = "Generate music for the video"
122
+ audio_path = None
123
+
124
+ video_tensor = read_video(video_path, seek_time=0, duration=seconds_total, target_fps=target_fps)
125
+ audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
126
+
127
+ conditioning = [{
128
+ "video_prompt": [video_tensor.unsqueeze(0)],
129
+ "text_prompt": text_prompt,
130
+ "audio_prompt": audio_tensor.unsqueeze(0),
131
+ "seconds_start": seconds_start,
132
+ "seconds_total": seconds_total
133
+ }]
134
+
135
+ # Generate stereo audio
136
+ output = generate_diffusion_cond(
137
+ model,
138
+ steps=250,
139
+ cfg_scale=7,
140
+ conditioning=conditioning,
141
+ sample_size=sample_size,
142
+ sigma_min=0.3,
143
+ sigma_max=500,
144
+ sampler_type="dpmpp-3m-sde",
145
+ device=device
146
+ )
147
+
148
+ # Rearrange audio batch to a single sequence
149
+ output = rearrange(output, "b d n -> d (b n)")
150
+
151
+ # Peak normalize, clip, convert to int16, and save to file
152
+ output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
153
+ torchaudio.save("output.wav", output, sample_rate)
154
+
155
+ if video_path is not None and os.path.exists(video_path):
156
+ merge_video_audio(video_path, "output.wav", "output.mp4", 0, seconds_total)
157
+
158
+ ```
159
+
160
+
161
+ ## 🚀 Citation
162
+
163
+ If you find our work useful, please consider citing:
164
+
165
+ ```
166
+ @article{tian2025audiox,
167
+ title={AudioX: Diffusion Transformer for Anything-to-Audio Generation},
168
+ author={Tian, Zeyue and Jin, Yizhu and Liu, Zhaoyang and Yuan, Ruibin and Tan, Xu and Chen, Qifeng and Xue, Wei and Guo, Yike},
169
+ journal={arXiv preprint arXiv:2503.10522},
170
+ year={2025}
171
+ }
172
+ ```
173
+
174
+ ## 📭 Contact
175
+
176
+ If you have any comments or questions, feel free to contact Zeyue Tian ([email protected]).
177
+
178
+ ## License
179
+
180
+ Please follow [CC-BY-NC](./LICENSE).
defaults.ini ADDED
@@ -0,0 +1,56 @@
1
+
2
+ [DEFAULTS]
3
+
4
+ #name of the run
5
+ name = stable_audio_tools
6
+
7
+ # the batch size
8
+ batch_size = 8
9
+
10
+ # number of GPUs to use for training
11
+ num_gpus = 1
12
+
13
+ # number of nodes to use for training
14
+ num_nodes = 1
15
+
16
+ # Multi-GPU strategy for PyTorch Lightning
17
+ strategy = ""
18
+
19
+ # Precision to use for training
20
+ precision = "16-mixed"
21
+
22
+ # number of CPU workers for the DataLoader
23
+ num_workers = 8
24
+
25
+ # the random seed
26
+ seed = 42
27
+
28
+ # Batches for gradient accumulation
29
+ accum_batches = 1
30
+
31
+ # Number of steps between checkpoints
32
+ checkpoint_every = 10000
33
+
34
+ # trainer checkpoint file to restart training from
35
+ ckpt_path = ''
36
+
37
+ # model checkpoint file to start a new training run from
38
+ pretrained_ckpt_path = ''
39
+
40
+ # Checkpoint path for the pretransform model if needed
41
+ pretransform_ckpt_path = ''
42
+
43
+ # configuration file specifying model hyperparameters
44
+ model_config = ''
45
+
46
+ # configuration for datasets
47
+ dataset_config = ''
48
+
49
+ # directory to save the checkpoints in
50
+ save_dir = ''
51
+
52
+ # gradient_clip_val passed into PyTorch Lightning Trainer
53
+ gradient_clip_val = 0.0
54
+
55
+ # remove the weight norm from the pretransform model
56
+ remove_pretransform_weight_norm = ''
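The `[DEFAULTS]` block above is a plain INI file, so its role is easiest to see with a small loader. The sketch below is a hedged illustration of how such defaults can be read and overridden from the command line; it is not the repository's own training entry point, and the CLI wiring is an assumption.

```python
# Hedged sketch: read defaults.ini and let CLI flags override individual values.
import argparse
import configparser

def load_defaults(path="defaults.ini"):
    cp = configparser.ConfigParser()
    cp.read(path)
    # Strip the quote characters used for string values in the file.
    return {key: value.strip("'\"") for key, value in cp["DEFAULTS"].items()}

if __name__ == "__main__":
    defaults = load_defaults()
    parser = argparse.ArgumentParser(description="Override defaults.ini values")
    for key, value in defaults.items():
        parser.add_argument(f"--{key.replace('_', '-')}", default=value)
    args = parser.parse_args()
    print(vars(args))  # e.g. batch_size stays "8" unless --batch-size is passed
```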
example/V2A_sample-1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bcb04e7f53461f2420078122338447a18a2baea5e14397cd7099cd97bec6260
3
+ size 4101390
example/V2A_sample-2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87850f3803ed2aba8928322b9eb703fc4653638e3ce4bd4a2dc179bbbe9c0542
3
+ size 2434915
example/V2A_sample-3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:339d929f307f01adf0ce139a3ecc19a1b95bea711cc718114099102fb2280200
3
+ size 2772307
example/V2M_sample-1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:489488199943d4154430e0adeeede9ae41e070a63d3a4bbb01de3247e8817a2e
3
+ size 7634025
example/V2M_sample-2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f225024c10bfeed527f666d5c00f5232a8822b57abdbaecc2cc105fc8a7d509f
3
+ size 1705374
example/V2M_sample-3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b49940836aa92b24bcb82627780081ff2cfd5e3e977a5529c64a6a5e15b96e5
3
+ size 2654967
pyproject.toml ADDED
@@ -0,0 +1,3 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
run.bat ADDED
@@ -0,0 +1,8 @@
1
+ REM 1) Activate your Conda environment (must use CALL)
2
+ call conda activate AudioX
3
+
4
+ REM 2) Run the Gradio script
5
+ python run_gradio.py --model-config model/config.json --share
6
+
7
+ REM 3) Exit the batch file and close the window
8
+ exit /B 0
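run.bat is Windows-specific; the equivalent launch can also be scripted in Python for other platforms. A hedged sketch that simply mirrors the command above and assumes the AudioX environment is already active in the current shell:

```python
# Hedged cross-platform counterpart to run.bat.
import subprocess
import sys

subprocess.run(
    [sys.executable, "run_gradio.py", "--model-config", "model/config.json", "--share"],
    check=True,
)
```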
run_gradio.py ADDED
@@ -0,0 +1,32 @@
1
+ from stable_audio_tools import get_pretrained_model
2
+ from stable_audio_tools.interface.gradio import create_ui
3
+ import json
4
+
5
+ import torch
6
+
7
+ def main(args):
8
+ torch.manual_seed(42)
9
+
10
+ interface = create_ui(
11
+ model_config_path = args.model_config,
12
+ ckpt_path=args.ckpt_path,
13
+ pretrained_name=args.pretrained_name,
14
+ pretransform_ckpt_path=args.pretransform_ckpt_path,
15
+ model_half=args.model_half
16
+ )
17
+ interface.queue()
18
+ interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None)
19
+
20
+ if __name__ == "__main__":
21
+ import argparse
22
+ parser = argparse.ArgumentParser(description='Run gradio interface')
23
+ parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
24
+ parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
25
+ parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
26
+ parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional path to the model pretransform checkpoint', required=False)
27
+ parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False)
28
+ parser.add_argument('--username', type=str, help='Gradio username', required=False)
29
+ parser.add_argument('--password', type=str, help='Gradio password', required=False)
30
+ parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
31
+ args = parser.parse_args()
32
+ main(args)
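The same interface can be constructed programmatically with the `create_ui` arguments that `run_gradio.py` exposes on the command line. A hedged sketch: the server settings are illustrative, and it assumes the arguments not passed here default to `None` as in the script above.

```python
# Hedged sketch: build and launch the UI without going through argparse.
from stable_audio_tools.interface.gradio import create_ui

ui = create_ui(pretrained_name="HKUSTAudio/AudioX", model_half=False)
ui.queue()
ui.launch(server_name="0.0.0.0", server_port=7860, share=False)
```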
setup.py ADDED
@@ -0,0 +1,46 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='AudioX',
5
+ version='0.1.0',
6
+ url='https://github.com/ZeyueT/AudioX.git',
7
+ author='AudioX, HKUST',
8
+ description='Training and inference tools for generative audio models from AudioX',
9
+ packages=find_packages(),
10
+ install_requires=[
11
+ 'aeiou',
12
+ 'alias-free-torch==0.0.6',
13
+ 'auraloss==0.4.0',
14
+ 'descript-audio-codec==1.0.0',
15
+ 'decord==0.6.0',
16
+ 'einops',
17
+ 'einops_exts',
18
+ 'ema-pytorch==0.2.3',
19
+ 'encodec==0.1.1',
20
+ 'gradio==4.44.1',
21
+ 'gradio_client==1.3.0',
22
+ 'huggingface_hub',
23
+ 'importlib-resources==5.12.0',
24
+ 'k-diffusion==0.1.1',
25
+ 'laion-clap==1.1.6',
26
+ 'local-attention==1.8.6',
27
+ 'pandas==2.0.2',
28
+ 'pedalboard==0.9.14',
29
+ 'prefigure==0.0.9',
30
+ 'pytorch_lightning==2.4.0',
31
+ 'PyWavelets==1.4.1',
32
+ 'safetensors',
33
+ 'sentencepiece==0.1.99',
34
+ 'torch>=2.0.1',
35
+ 'torchaudio>=2.0.2',
36
+ 'torchmetrics==0.11.4',
37
+ 'tqdm',
38
+ 'transformers',
39
+ 'v-diffusion-pytorch==0.0.2',
40
+ 'vector-quantize-pytorch==1.9.14',
41
+ 'wandb',
42
+ 'webdataset==0.2.48',
43
+ 'x-transformers<1.27.0',
44
+ ],
45
+
46
+ )
stable_audio_tools/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .models.factory import create_model_from_config, create_model_from_config_path
2
+ from .models.pretrained import get_pretrained_model
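Beyond `get_pretrained_model` (used in the README), the exported `create_model_from_config` factory can build the model from the locally downloaded `model/config.json`. The snippet below is a hedged sketch: the `state_dict` nesting of the checkpoint file is an assumption, and `strict=False` is used defensively.

```python
# Hedged sketch: build the model from the local config and load the local checkpoint.
import json
import torch
from stable_audio_tools import create_model_from_config

with open("model/config.json") as f:
    model_config = json.load(f)

model = create_model_from_config(model_config)
state = torch.load("model/model.ckpt", map_location="cpu")
# Lightning-style checkpoints often nest weights under "state_dict".
model.load_state_dict(state.get("state_dict", state), strict=False)
model.eval()
```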
stable_audio_tools/data/__init__.py ADDED
File without changes
stable_audio_tools/data/dataset.py ADDED
@@ -0,0 +1,876 @@
1
+ import importlib
2
+ import numpy as np
3
+ import io
4
+ import os
5
+ import posixpath
6
+ import random
7
+ import re
8
+ import subprocess
9
+ import time
10
+ import torch
11
+ import torchaudio
12
+ import webdataset as wds
13
+
14
+ from aeiou.core import is_silence
15
+ from os import path
16
+ from pedalboard.io import AudioFile
17
+ from torchaudio import transforms as T
18
+ from typing import Optional, Callable, List
19
+ from torchdata.datapipes.iter import IterDataPipe, IterableWrapper
20
+ from torchdata.datapipes.iter import Prefetcher
21
+
22
+ from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
23
+ import json
24
+
25
+
26
+ import os
27
+ import datetime
28
+ from memory_profiler import profile
29
+
30
+
31
+ AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
32
+
33
+ # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
34
+
35
+ def fast_scandir(
36
+ dir:str, # top-level directory at which to begin scanning
37
+ ext:list, # list of allowed file extensions,
38
+ #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
39
+ ):
40
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
41
+ subfolders, files = [], []
42
+ ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
43
+ try: # hope to avoid 'permission denied' by this try
44
+ for f in os.scandir(dir):
45
+ try: # 'hope to avoid too many levels of symbolic links' error
46
+ if f.is_dir():
47
+ subfolders.append(f.path)
48
+ elif f.is_file():
49
+ file_ext = os.path.splitext(f.name)[1].lower()
50
+ is_hidden = os.path.basename(f.path).startswith(".")
51
+
52
+ if file_ext in ext and not is_hidden:
53
+ files.append(f.path)
54
+ except:
55
+ pass
56
+ except:
57
+ pass
58
+
59
+ for dir in list(subfolders):
60
+ sf, f = fast_scandir(dir, ext)
61
+ subfolders.extend(sf)
62
+ files.extend(f)
63
+ return subfolders, files
64
+
65
+ def extract_audio_paths(jsonl_file, exts):
66
+ audio_paths = []
67
+ video_paths = []
68
+ text_prompts = []
69
+ data_types = []
70
+ with open(jsonl_file, 'r') as file:
71
+ for line in file:
72
+ try:
73
+ data = json.loads(line.strip())
74
+ path = data.get('path', '')
75
+ video_path = data.get('video_path', '')
76
+ text_prompt = data.get('caption', '')
77
+ data_type = data.get('type', None)
78
+ if any(path.endswith(ext) for ext in exts):
79
+ audio_paths.append(path)
80
+ video_paths.append(video_path)
81
+ text_prompts.append(text_prompt)
82
+ data_types.append(data_type)
83
+ except json.JSONDecodeError:
84
+ print(f"Error decoding JSON line: {line.strip()}")
85
+ return audio_paths, video_paths, text_prompts, data_types
86
+
87
+ def keyword_scandir(
88
+ dir: str, # top-level directory at which to begin scanning
89
+ ext: list, # list of allowed file extensions
90
+ keywords: list, # list of keywords to search for in the file name
91
+ ):
92
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
93
+ subfolders, files = [], []
94
+ # make keywords case insensitive
95
+ keywords = [keyword.lower() for keyword in keywords]
96
+ # add starting period to extensions if needed
97
+ ext = ['.'+x if x[0] != '.' else x for x in ext]
98
+ banned_words = ["paxheader", "__macosx"]
99
+ try: # hope to avoid 'permission denied' by this try
100
+ for f in os.scandir(dir):
101
+ try: # 'hope to avoid too many levels of symbolic links' error
102
+ if f.is_dir():
103
+ subfolders.append(f.path)
104
+ elif f.is_file():
105
+ is_hidden = f.name.split("/")[-1][0] == '.'
106
+ has_ext = os.path.splitext(f.name)[1].lower() in ext
107
+ name_lower = f.name.lower()
108
+ has_keyword = any(
109
+ [keyword in name_lower for keyword in keywords])
110
+ has_banned = any(
111
+ [banned_word in name_lower for banned_word in banned_words])
112
+ if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
113
+ files.append(f.path)
114
+ except:
115
+ pass
116
+ except:
117
+ pass
118
+
119
+ for dir in list(subfolders):
120
+ sf, f = keyword_scandir(dir, ext, keywords)
121
+ subfolders.extend(sf)
122
+ files.extend(f)
123
+ return subfolders, files
124
+
125
+ def get_audio_filenames(
126
+ paths: list, # directories in which to search
127
+ keywords=None,
128
+ exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
129
+ ):
130
+
131
+ "recursively get a list of audio filenames"
132
+ filenames = []
133
+ video_filenames = []
134
+ text_prompts = []
135
+ data_types = []
136
+
137
+ if type(paths) is str:
138
+ paths = [paths]
139
+
140
+
141
+ if os.path.isdir(paths[0]):
142
+ for path in paths: # get a list of relevant filenames
143
+ if keywords is not None:
144
+ subfolders, files = keyword_scandir(path, exts, keywords)
145
+ else:
146
+ subfolders, files = fast_scandir(path, exts)
147
+ filenames.extend(files)
148
+ return filenames
149
+
150
+ elif os.path.isfile(paths[0]):
151
+ assert paths[0].endswith('.jsonl')
152
+ for path in paths:
153
+ audio_paths, video_paths, text_prompt, data_type = extract_audio_paths(path, exts)
154
+ filenames.extend(audio_paths)
155
+ video_filenames.extend(video_paths)
156
+ text_prompts.extend(text_prompt)
157
+ data_types.extend(data_type)
158
+
159
+ return filenames, video_filenames, text_prompts, data_types
160
+
161
+
162
+ class LocalDatasetConfig:
163
+ def __init__(
164
+ self,
165
+ id: str,
166
+ path: str,
167
+ video_fps: int,
168
+ custom_metadata_fn: Optional[Callable[[str], str]] = None
169
+ ):
170
+ self.id = id
171
+ self.path = path
172
+ self.video_fps = video_fps
173
+ self.custom_metadata_fn = custom_metadata_fn
174
+
175
+
176
+ # @profile
177
+ class SampleDataset(torch.utils.data.Dataset):
178
+ # @profile
179
+ def __init__(
180
+ self,
181
+ configs,
182
+ sample_size=65536,
183
+ sample_rate=48000,
184
+ keywords=None,
185
+ random_crop=True,
186
+ force_channels="stereo",
187
+ video_fps=5
188
+ ):
189
+ super().__init__()
190
+ self.filenames = []
191
+ self.video_filenames = []
192
+ self.text_prompts = []
193
+ self.data_types = []
194
+
195
+ self.augs = torch.nn.Sequential(
196
+ PhaseFlipper(),
197
+ )
198
+
199
+ self.root_paths = []
200
+
201
+ self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
202
+
203
+ self.force_channels = force_channels
204
+
205
+ self.encoding = torch.nn.Sequential(
206
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
207
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
208
+ )
209
+
210
+ self.sr = sample_rate
211
+
212
+ self.custom_metadata_fns = {}
213
+
214
+ for config in configs:
215
+ self.video_fps = config.video_fps
216
+
217
+ self.root_paths.append(config.path)
218
+ audio_files, video_files, text_prompt, data_types = get_audio_filenames(config.path, keywords)
219
+
220
+ self.filenames.extend(audio_files)
221
+ self.video_filenames.extend(video_files)
222
+ self.text_prompts.extend(text_prompt)
223
+ self.data_types.extend(data_types)
224
+ if config.custom_metadata_fn is not None:
225
+ self.custom_metadata_fns[config.path] = config.custom_metadata_fn
226
+
227
+ print(f'Found {len(self.filenames)} files')
228
+
229
+
230
+ def load_file(self, filename):
231
+ ext = filename.split(".")[-1]
232
+
233
+ if ext == "mp3":
234
+ with AudioFile(filename) as f:
235
+ audio = f.read(f.frames)
236
+ audio = torch.from_numpy(audio)
237
+ in_sr = f.samplerate
238
+ else:
239
+ audio, in_sr = torchaudio.load(filename, format=ext)
240
+
241
+ if in_sr != self.sr:
242
+ resample_tf = T.Resample(in_sr, self.sr)
243
+ audio = resample_tf(audio)
244
+
245
+ return audio
246
+
247
+ def __len__(self):
248
+ return len(self.filenames)
249
+
250
+
251
+ def __getitem__(self, idx):
252
+ audio_filename = self.filenames[idx]
253
+ video_filename = self.video_filenames[idx]
254
+ text_prompt = self.text_prompts[idx]
255
+ data_type = self.data_types[idx]
256
+
257
+ try:
258
+
259
+ start_time = time.time()
260
+ audio = self.load_file(audio_filename)
261
+
262
+
263
+ if data_type in ["text_condition-audio", "text_condition-music",
264
+ "video_condition-audio", "video_condition-music",
265
+ "text+video_condition-audio","text+video_condition-music"]:
266
+ if_audio_contition = False
267
+ audio_prompt = torch.zeros((2, self.sr * 10))
268
+ elif data_type in ["audio_condition-audio", "audio_condition-music",
269
+ "uni_condition-audio", "uni_condition-music"]:
270
+ if_audio_contition = True
271
+
272
+ if if_audio_contition:
273
+ audio_org = audio.clamp(-1, 1)
274
+
275
+
276
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
277
+
278
+ if self.augs is not None:
279
+ audio = self.augs(audio)
280
+
281
+ audio = audio.clamp(-1, 1)
282
+
283
+ if if_audio_contition:
284
+ if data_type.split("-")[-1] == "audio":
285
+ start_index = max(0, int((seconds_start) * self.sr))
286
+ end_index = int((seconds_start+10) * self.sr)
287
+ audio_prompt = audio_org[:, start_index:end_index]
288
+
289
+ elif data_type.split("-")[-1] == "music":
290
+ if seconds_start < 10:
291
+ start_index = 0
292
+ end_index = int(10 * self.sr)
293
+ else:
294
+ start_index = max(0, int((seconds_start - 10) * self.sr))
295
+ end_index = int(seconds_start * self.sr)
296
+ audio_prompt = audio_org[:, start_index:end_index]
297
+
298
+ # Encode the file to assist in prediction
299
+ if self.encoding is not None:
300
+ audio = self.encoding(audio)
301
+
302
+ info = {}
303
+
304
+
305
+ info["path"] = audio_filename
306
+ info["video_path"] = video_filename
307
+ info["text_prompt"] = text_prompt
308
+ info["audio_prompt"] = audio_prompt
309
+ info["data_type"] = data_type
310
+
311
+ for root_path in self.root_paths:
312
+ if root_path in audio_filename:
313
+ info["relpath"] = path.relpath(audio_filename, root_path)
314
+
315
+ info["timestamps"] = (t_start, t_end)
316
+ info["seconds_start"] = seconds_start
317
+ info["seconds_total"] = seconds_total
318
+ info["padding_mask"] = padding_mask
319
+ info["video_fps"] = self.video_fps
320
+ end_time = time.time()
321
+
322
+ info["load_time"] = end_time - start_time
323
+
324
+ for custom_md_path in self.custom_metadata_fns.keys():
325
+ if os.path.isdir(custom_md_path):
326
+ if custom_md_path in audio_filename:
327
+ custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
328
+ custom_metadata = custom_metadata_fn(info, audio)
329
+ info.update(custom_metadata)
330
+ elif os.path.isfile(custom_md_path):
331
+ custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
332
+ custom_metadata = custom_metadata_fn(info, audio)
333
+ info.update(custom_metadata)
334
+
335
+ if "__reject__" in info and info["__reject__"]:
336
+ return self[random.randrange(len(self))]
337
+
338
+ file_name = audio_filename.split('/')[-1]
339
+
340
+ return (audio, info)
341
+ except Exception as e:
342
+ print(f'Couldn\'t load file {audio_filename}: {e}')
343
+ return self[random.randrange(len(self))]
344
+
345
+ def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
346
+ """Return function over iterator that groups key, value pairs into samples.
347
+ :param keys: function that splits the key into key and extension (base_plus_ext)
348
+ :param lcase: convert suffixes to lower case (Default value = True)
349
+ """
350
+ current_sample = None
351
+ for filesample in data:
352
+ assert isinstance(filesample, dict)
353
+ fname, value = filesample["fname"], filesample["data"]
354
+ prefix, suffix = keys(fname)
355
+ if wds.tariterators.trace:
356
+ print(
357
+ prefix,
358
+ suffix,
359
+ current_sample.keys() if isinstance(current_sample, dict) else None,
360
+ )
361
+ if prefix is None:
362
+ continue
363
+ if lcase:
364
+ suffix = suffix.lower()
365
+ if current_sample is None or prefix != current_sample["__key__"]:
366
+ if wds.tariterators.valid_sample(current_sample):
367
+ yield current_sample
368
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
369
+ if suffix in current_sample:
370
+ print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
371
+ if suffixes is None or suffix in suffixes:
372
+ current_sample[suffix] = value
373
+ if wds.tariterators.valid_sample(current_sample):
374
+ yield current_sample
375
+
376
+ wds.tariterators.group_by_keys = group_by_keys
377
+
378
+ # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
379
+
380
+ def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
381
+ """
382
+ Returns a list of full S3 paths to files in a given S3 bucket and directory path.
383
+ """
384
+ # Ensure dataset_path ends with a trailing slash
385
+ if dataset_path != '' and not dataset_path.endswith('/'):
386
+ dataset_path += '/'
387
+ # Use posixpath to construct the S3 URL path
388
+ bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
389
+ # Construct the `aws s3 ls` command
390
+ cmd = ['aws', 's3', 'ls', bucket_path]
391
+
392
+ if profile is not None:
393
+ cmd.extend(['--profile', profile])
394
+
395
+ if recursive:
396
+ # Add the --recursive flag if requested
397
+ cmd.append('--recursive')
398
+
399
+ # Run the `aws s3 ls` command and capture the output
400
+ run_ls = subprocess.run(cmd, capture_output=True, check=True)
401
+ # Split the output into lines and strip whitespace from each line
402
+ contents = run_ls.stdout.decode('utf-8').split('\n')
403
+ contents = [x.strip() for x in contents if x]
404
+ # Remove the timestamp from lines that begin with a timestamp
405
+ contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
406
+ if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
407
+ # Construct a full S3 path for each file in the contents list
408
+ contents = [posixpath.join(s3_url_prefix or '', x)
409
+ for x in contents if not x.endswith('/')]
410
+ # Apply the filter, if specified
411
+ if filter:
412
+ contents = [x for x in contents if filter in x]
413
+ # Remove redundant directory names in the S3 URL
414
+ if recursive:
415
+ # Get the main directory name from the S3 URL
416
+ main_dir = "/".join(bucket_path.split('/')[3:])
417
+ # Remove the redundant directory names from each file path
418
+ contents = [x.replace(f'{main_dir}', '').replace(
419
+ '//', '/') for x in contents]
420
+ # Print debugging information, if requested
421
+ if debug:
422
+ print("contents = \n", contents)
423
+ # Return the list of S3 paths to files
424
+ return contents
425
+
426
+
427
+ def get_all_s3_urls(
428
+ names=[], # list of all valid [LAION AudioDataset] dataset names
429
+ # list of subsets you want from those datasets, e.g. ['train','valid']
430
+ subsets=[''],
431
+ s3_url_prefix=None, # prefix for those dataset names
432
+ recursive=True, # recursively list all tar files in all subdirs
433
+ filter_str='tar', # only grab files with this substring
434
+ # print debugging info -- note: info displayed likely to change at dev's whims
435
+ debug=False,
436
+ profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
437
+ ):
438
+ "get urls of shards (tar files) for multiple datasets in one s3 bucket"
439
+ urls = []
440
+ for name in names:
441
+ # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
442
+ if s3_url_prefix is None:
443
+ contents_str = name
444
+ else:
445
+ # Construct the S3 path using the s3_url_prefix and the current name value
446
+ contents_str = posixpath.join(s3_url_prefix, name)
447
+ if debug:
448
+ print(f"get_all_s3_urls: {contents_str}:")
449
+ for subset in subsets:
450
+ subset_str = posixpath.join(contents_str, subset)
451
+ if debug:
452
+ print(f"subset_str = {subset_str}")
453
+ # Get the list of tar files in the current subset directory
454
+ profile = profiles.get(name, None)
455
+ tar_list = get_s3_contents(
456
+ subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
457
+ for tar in tar_list:
458
+ # Escape spaces and parentheses in the tar filename for use in the shell command
459
+ tar = tar.replace(" ", "\\ ").replace(
460
+ "(", "\\(").replace(")", "\\)")
461
+ # Construct the S3 path to the current tar file
462
+ s3_path = posixpath.join(name, subset, tar) + " -"
463
+ # Construct the AWS CLI command to download the current tar file
464
+ if s3_url_prefix is None:
465
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
466
+ else:
467
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
468
+ if profiles.get(name):
469
+ request_str += f" --profile {profiles.get(name)}"
470
+ if debug:
471
+ print("request_str = ", request_str)
472
+ # Add the constructed URL to the list of URLs
473
+ urls.append(request_str)
474
+ return urls
475
+
476
+
477
+ def log_and_continue(exn):
478
+ """Call in an exception handler to ignore any exception, issue a warning, and continue."""
479
+ print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
480
+ return True
481
+
482
+
483
+ def is_valid_sample(sample):
484
+ has_json = "json" in sample
485
+ has_audio = "audio" in sample
486
+ is_silent = is_silence(sample["audio"])
487
+ is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
488
+
489
+ return has_json and has_audio and not is_silent and not is_rejected
490
+
491
+ class S3DatasetConfig:
492
+ def __init__(
493
+ self,
494
+ id: str,
495
+ s3_path: str,
496
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
497
+ profile: Optional[str] = None,
498
+ ):
499
+ self.id = id
500
+ self.path = s3_path
501
+ self.custom_metadata_fn = custom_metadata_fn
502
+ self.profile = profile
503
+ self.urls = []
504
+
505
+ def load_data_urls(self):
506
+ self.urls = get_all_s3_urls(
507
+ names=[self.path],
508
+ s3_url_prefix=None,
509
+ recursive=True,
510
+ profiles={self.path: self.profile} if self.profile else {},
511
+ )
512
+
513
+ return self.urls
514
+
515
+ class LocalWebDatasetConfig:
516
+ def __init__(
517
+ self,
518
+ id: str,
519
+ path: str,
520
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
521
+ profile: Optional[str] = None,
522
+ ):
523
+ self.id = id
524
+ self.path = path
525
+ self.custom_metadata_fn = custom_metadata_fn
526
+ self.urls = []
527
+
528
+ def load_data_urls(self):
529
+
530
+ self.urls = fast_scandir(self.path, ["tar"])[1]
531
+
532
+ return self.urls
533
+
534
+ def audio_decoder(key, value):
535
+ # Get file extension from key
536
+ ext = key.split(".")[-1]
537
+
538
+ if ext in AUDIO_KEYS:
539
+ return torchaudio.load(io.BytesIO(value))
540
+ else:
541
+ return None
542
+
543
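+ # Collate a batch of samples field-by-field: stack tensors, convert numeric scalars and ndarrays to numpy arrays, and pass anything else (e.g. the metadata dicts) through unchanged.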
+ def collation_fn(samples):
544
+ batched = list(zip(*samples))
545
+ result = []
546
+ for b in batched:
547
+ if isinstance(b[0], (int, float)):
548
+ b = np.array(b)
549
+ elif isinstance(b[0], torch.Tensor):
550
+ b = torch.stack(b)
551
+ elif isinstance(b[0], np.ndarray):
552
+ b = np.array(b)
553
+ else:
554
+ b = b
555
+ result.append(b)
556
+ return result
557
+
558
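+ # Wraps a WebDataset pipeline: resample shards, unpack tar samples, decode audio, run wds_preprocess (resample / pad-crop / channel and phase augmentation), drop invalid samples, and batch with collation_fn.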
+ class WebDatasetDataLoader():
559
+ def __init__(
560
+ self,
561
+ datasets: List[S3DatasetConfig],
562
+ batch_size,
563
+ sample_size,
564
+ sample_rate=48000,
565
+ num_workers=8,
566
+ epoch_steps=1000,
567
+ random_crop=True,
568
+ force_channels="stereo",
569
+ augment_phase=True,
570
+ **data_loader_kwargs
571
+ ):
572
+
573
+ self.datasets = datasets
574
+
575
+ self.sample_size = sample_size
576
+ self.sample_rate = sample_rate
577
+ self.random_crop = random_crop
578
+ self.force_channels = force_channels
579
+ self.augment_phase = augment_phase
580
+
581
+ urls = [dataset.load_data_urls() for dataset in datasets]
582
+
583
+ # Flatten the list of lists of URLs
584
+ urls = [url for dataset_urls in urls for url in dataset_urls]
585
+
586
+ # Shuffle the urls
587
+ random.shuffle(urls)
588
+
589
+ self.dataset = wds.DataPipeline(
590
+ wds.ResampledShards(urls),
591
+ wds.tarfile_to_samples(handler=log_and_continue),
592
+ wds.decode(audio_decoder, handler=log_and_continue),
593
+ wds.map(self.wds_preprocess, handler=log_and_continue),
594
+ wds.select(is_valid_sample),
595
+ wds.to_tuple("audio", "json", handler=log_and_continue),
596
+ #wds.shuffle(bufsize=1000, initial=5000),
597
+ wds.batched(batch_size, partial=False, collation_fn=collation_fn),
598
+ ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
599
+
600
+ self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
601
+
602
+ def wds_preprocess(self, sample):
603
+
604
+ found_key, rewrite_key = '', ''
605
+ for k, v in sample.items(): # print the all entries in dict
606
+ for akey in AUDIO_KEYS:
607
+ if k.endswith(akey):
608
+ # to rename long/weird key with its simpler counterpart
609
+ found_key, rewrite_key = k, akey
610
+ break
611
+ if '' != found_key:
612
+ break
613
+ if '' == found_key: # got no audio!
614
+ return None # try returning None to tell WebDataset to skip this one
615
+
616
+ audio, in_sr = sample[found_key]
617
+ if in_sr != self.sample_rate:
618
+ resample_tf = T.Resample(in_sr, self.sample_rate)
619
+ audio = resample_tf(audio)
620
+
621
+ if self.sample_size is not None:
622
+ # Pad/crop and get the relative timestamp
623
+ pad_crop = PadCrop_Normalized_T(
624
+ self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
625
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
626
+ audio)
627
+ sample["json"]["seconds_start"] = seconds_start
628
+ sample["json"]["seconds_total"] = seconds_total
629
+ sample["json"]["padding_mask"] = padding_mask
630
+ else:
631
+ t_start, t_end = 0, 1
632
+
633
+ # Check if audio is length zero, initialize to a single zero if so
634
+ if audio.shape[-1] == 0:
635
+ audio = torch.zeros(1, 1)
636
+
637
+ # Make the audio stereo and augment by randomly inverting phase
638
+ augs = torch.nn.Sequential(
639
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
640
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
641
+ PhaseFlipper() if self.augment_phase else torch.nn.Identity()
642
+ )
643
+
644
+ audio = augs(audio)
645
+
646
+ sample["json"]["timestamps"] = (t_start, t_end)
647
+
648
+ if "text" in sample["json"]:
649
+ sample["json"]["prompt"] = sample["json"]["text"]
650
+
651
+ # Check for custom metadata functions
652
+ for dataset in self.datasets:
653
+ if dataset.custom_metadata_fn is None:
654
+ continue
655
+
656
+ if dataset.path in sample["__url__"]:
657
+ custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
658
+ sample["json"].update(custom_metadata)
659
+
660
+ if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
661
+ del sample[found_key]
662
+
663
+ sample["audio"] = audio
664
+
665
+ # Add audio to the metadata as well for conditioning
666
+ sample["json"]["audio"] = audio
667
+
668
+ return sample
669
+
670
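+ # Example dataset_config layout accepted below (a sketch inferred from the parsing code; ids, paths and the metadata module name are placeholders):
+ #   {
+ #     "dataset_type": "audio_dir",
+ #     "datasets": [{"id": "my_audio", "path": "/data/audio", "custom_metadata_module": "custom_md.py"}],
+ #     "random_crop": true
+ #   }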
+ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, video_fps=5):
671
+
672
+ dataset_type = dataset_config.get("dataset_type", None)
673
+
674
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
675
+
676
+ if audio_channels == 1:
677
+ force_channels = "mono"
678
+ else:
679
+ force_channels = "stereo"
680
+
681
+ if dataset_type == "audio_dir":
682
+
683
+ audio_dir_configs = dataset_config.get("datasets", None)
684
+
685
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
686
+
687
+ configs = []
688
+
689
+ for audio_dir_config in audio_dir_configs:
690
+ audio_dir_path = audio_dir_config.get("path", None)
691
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
692
+
693
+ custom_metadata_fn = None
694
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
695
+
696
+ if custom_metadata_module_path is not None:
697
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
698
+ metadata_module = importlib.util.module_from_spec(spec)
699
+ spec.loader.exec_module(metadata_module)
700
+
701
+ custom_metadata_fn = metadata_module.get_custom_metadata
702
+
703
+ configs.append(
704
+ LocalDatasetConfig(
705
+ id=audio_dir_config["id"],
706
+ path=audio_dir_path,
707
+ custom_metadata_fn=custom_metadata_fn,
708
+ video_fps=video_fps
709
+ )
710
+ )
711
+
712
+ train_set = SampleDataset(
713
+ configs,
714
+ sample_rate=sample_rate,
715
+ sample_size=sample_size,
716
+ random_crop=dataset_config.get("random_crop", True),
717
+ force_channels=force_channels,
718
+ video_fps=video_fps
719
+ )
720
+
721
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
722
+ num_workers=num_workers, persistent_workers=True, pin_memory=False, drop_last=True, collate_fn=collation_fn)
723
+
724
+ elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
725
+ wds_configs = []
726
+
727
+ for wds_config in dataset_config["datasets"]:
728
+
729
+ custom_metadata_fn = None
730
+ custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
731
+
732
+ if custom_metadata_module_path is not None:
733
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
734
+ metadata_module = importlib.util.module_from_spec(spec)
735
+ spec.loader.exec_module(metadata_module)
736
+
737
+ custom_metadata_fn = metadata_module.get_custom_metadata
738
+
739
+ if "s3_path" in wds_config:
740
+
741
+ wds_configs.append(
742
+ S3DatasetConfig(
743
+ id=wds_config["id"],
744
+ s3_path=wds_config["s3_path"],
745
+ custom_metadata_fn=custom_metadata_fn,
746
+ profile=wds_config.get("profile", None),
747
+ )
748
+ )
749
+
750
+ elif "path" in wds_config:
751
+
752
+ wds_configs.append(
753
+ LocalWebDatasetConfig(
754
+ id=wds_config["id"],
755
+ path=wds_config["path"],
756
+ custom_metadata_fn=custom_metadata_fn
757
+ )
758
+ )
759
+
760
+ return WebDatasetDataLoader(
761
+ wds_configs,
762
+ sample_rate=sample_rate,
763
+ sample_size=sample_size,
764
+ batch_size=batch_size,
765
+ random_crop=dataset_config.get("random_crop", True),
766
+ num_workers=num_workers,
767
+ persistent_workers=True,
768
+ force_channels=force_channels,
769
+ epoch_steps=dataset_config.get("epoch_steps", 2000)
770
+ ).data_loader
771
+
772
+
773
+
774
+
775
+ def create_dataloader_from_config_valid(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
776
+
777
+
778
+ dataset_type = dataset_config.get("dataset_type", None)
779
+
780
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
781
+
782
+ if audio_channels == 1:
783
+ force_channels = "mono"
784
+ else:
785
+ force_channels = "stereo"
786
+
787
+ if dataset_type == "audio_dir":
788
+
789
+ audio_dir_configs = dataset_config.get("datasets", None)
790
+
791
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
792
+
793
+ configs = []
794
+
795
+ for audio_dir_config in audio_dir_configs:
796
+ audio_dir_path = audio_dir_config.get("path", None)
797
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
798
+
799
+ custom_metadata_fn = None
800
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
801
+
802
+ if custom_metadata_module_path is not None:
803
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
804
+ metadata_module = importlib.util.module_from_spec(spec)
805
+ spec.loader.exec_module(metadata_module)
806
+
807
+ custom_metadata_fn = metadata_module.get_custom_metadata
808
+
809
+ configs.append(
810
+ LocalDatasetConfig(
811
+ id=audio_dir_config["id"],
812
+ path=audio_dir_path,
813
+ custom_metadata_fn=custom_metadata_fn
814
+ )
815
+ )
816
+
817
+ valid_set = SampleDataset(
818
+ configs,
819
+ sample_rate=sample_rate,
820
+ sample_size=sample_size,
821
+ random_crop=dataset_config.get("random_crop", True),
822
+ force_channels=force_channels
823
+ )
824
+
825
+
826
+ return torch.utils.data.DataLoader(valid_set, batch_size, shuffle=False,
827
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
828
+
829
+ elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
830
+ wds_configs = []
831
+
832
+ for wds_config in dataset_config["datasets"]:
833
+
834
+ custom_metadata_fn = None
835
+ custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
836
+
837
+ if custom_metadata_module_path is not None:
838
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
839
+ metadata_module = importlib.util.module_from_spec(spec)
840
+ spec.loader.exec_module(metadata_module)
841
+
842
+ custom_metadata_fn = metadata_module.get_custom_metadata
843
+
844
+ if "s3_path" in wds_config:
845
+
846
+ wds_configs.append(
847
+ S3DatasetConfig(
848
+ id=wds_config["id"],
849
+ s3_path=wds_config["s3_path"],
850
+ custom_metadata_fn=custom_metadata_fn,
851
+ profile=wds_config.get("profile", None),
852
+ )
853
+ )
854
+
855
+ elif "path" in wds_config:
856
+
857
+ wds_configs.append(
858
+ LocalWebDatasetConfig(
859
+ id=wds_config["id"],
860
+ path=wds_config["path"],
861
+ custom_metadata_fn=custom_metadata_fn
862
+ )
863
+ )
864
+
865
+ return WebDatasetDataLoader(
866
+ wds_configs,
867
+ sample_rate=sample_rate,
868
+ sample_size=sample_size,
869
+ batch_size=batch_size,
870
+ random_crop=dataset_config.get("random_crop", True),
871
+ num_workers=num_workers,
872
+ persistent_workers=True,
873
+ force_channels=force_channels,
874
+ epoch_steps=dataset_config.get("epoch_steps", 2000)
875
+ ).data_loader
876
+
stable_audio_tools/data/utils.py ADDED
@@ -0,0 +1,199 @@
1
+ import math
2
+ import random
3
+ import torch
4
+
5
+ from torch import nn
6
+ from typing import Tuple
7
+ import os
8
+ import subprocess as sp
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from decord import VideoReader, cpu
+ import torchaudio  # needed by load_and_process_audio below
+ import torch.nn.functional as F  # needed for F.pad in load_and_process_audio
12
+
13
+ class PadCrop(nn.Module):
14
+ def __init__(self, n_samples, randomize=True):
15
+ super().__init__()
16
+ self.n_samples = n_samples
17
+ self.randomize = randomize
18
+
19
+ def __call__(self, signal):
20
+ n, s = signal.shape
21
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
22
+ end = start + self.n_samples
23
+ output = signal.new_zeros([n, self.n_samples])
24
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
25
+ return output
26
+
27
+
28
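+ # Pad or crop audio to exactly n_samples. When randomize is enabled, the crop offset is chosen from 10-second boundaries that still fit a full chunk (and, for clips longer than 20 s, leave at least 15 s after the offset). Also returns relative timestamps, start/total seconds, and a padding mask.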
+ class PadCrop_Normalized_T(nn.Module):
29
+
30
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
31
+ super().__init__()
32
+ self.n_samples = n_samples
33
+ self.sample_rate = sample_rate
34
+ self.randomize = randomize
35
+
36
+ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
37
+ n_channels, n_samples = source.shape
38
+
39
+ # Calculate the duration of the audio in seconds
40
+ total_duration = n_samples // self.sample_rate
41
+
42
+ # If the audio is shorter than the desired length, pad it
43
+ upper_bound = max(0, n_samples - self.n_samples)
44
+
45
+ # If randomize is False, always start at the beginning of the audio
46
+ offset = 0
47
+
48
+ if self.randomize and n_samples > self.n_samples:
49
+ valid_offsets = [
50
+ i * self.sample_rate for i in range(0, total_duration, 10)
51
+ if i * self.sample_rate + self.n_samples <= n_samples and
52
+ (total_duration <= 20 or total_duration - i >= 15)
53
+ ]
54
+ if valid_offsets:
55
+ offset = random.choice(valid_offsets)
56
+
57
+ # Calculate the start and end times of the chunk
58
+ t_start = offset / (upper_bound + self.n_samples)
59
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
60
+
61
+ # Create the chunk
62
+ chunk = source.new_zeros([n_channels, self.n_samples])
63
+
64
+ # Copy the audio into the chunk
65
+ chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
66
+
67
+ # Calculate the start and end times of the chunk in seconds
68
+ seconds_start = math.floor(offset / self.sample_rate)
69
+ seconds_total = math.ceil(n_samples / self.sample_rate)
70
+
71
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
72
+ padding_mask = torch.zeros([self.n_samples])
73
+ padding_mask[:min(n_samples, self.n_samples)] = 1
74
+
75
+ return (
76
+ chunk,
77
+ t_start,
78
+ t_end,
79
+ seconds_start,
80
+ seconds_total,
81
+ padding_mask
82
+ )
83
+
84
+
85
+ class PhaseFlipper(nn.Module):
86
+ "Randomly invert the phase of a signal"
87
+ def __init__(self, p=0.5):
88
+ super().__init__()
89
+ self.p = p
90
+ def __call__(self, signal):
91
+ return -signal if (random.random() < self.p) else signal
92
+
93
+ class Mono(nn.Module):
94
+ def __call__(self, signal):
95
+ return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
96
+
97
+ class Stereo(nn.Module):
98
+ def __call__(self, signal):
99
+ signal_shape = signal.shape
100
+ # Check if it's mono
101
+ if len(signal_shape) == 1: # s -> 2, s
102
+ signal = signal.unsqueeze(0).repeat(2, 1)
103
+ elif len(signal_shape) == 2:
104
+ if signal_shape[0] == 1: #1, s -> 2, s
105
+ signal = signal.repeat(2, 1)
106
+ elif signal_shape[0] > 2: #?, s -> 2,s
107
+ signal = signal[:2, :]
108
+
109
+ return signal
110
+
111
+
112
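+ # Force a (frames, C, H, W) video tensor to exactly duration * target_fps frames by truncating or repeating the last frame.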
+ def adjust_video_duration(video_tensor, duration, target_fps):
113
+ current_duration = video_tensor.shape[0]
114
+ target_duration = duration * target_fps
115
+ if current_duration > target_duration:
116
+ video_tensor = video_tensor[:target_duration]
117
+ elif current_duration < target_duration:
118
+ last_frame = video_tensor[-1:]
119
+ repeat_times = target_duration - current_duration
120
+ video_tensor = torch.cat((video_tensor, last_frame.repeat(repeat_times, 1, 1, 1)), dim=0)
121
+ return video_tensor
122
+
123
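+ # Decode a video (or a still image, or None -> zeros) into a (duration * target_fps, 3, 224, 224) frame tensor, sampling frames at roughly target_fps and resizing to 224x224.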
+ def read_video(filepath, seek_time=0., duration=-1, target_fps=2):
124
+ if filepath is None:
125
+ return torch.zeros((int(duration * target_fps), 3, 224, 224))
126
+
127
+ ext = os.path.splitext(filepath)[1].lower()
128
+ if ext in ['.jpg', '.jpeg', '.png']:
129
+ resize_transform = transforms.Resize((224, 224))
130
+ image = Image.open(filepath).convert("RGB")
131
+ frame = transforms.ToTensor()(image).unsqueeze(0)
132
+ frame = resize_transform(frame)
133
+ target_frames = int(duration * target_fps)
134
+ frame = frame.repeat(int(math.ceil(target_frames / frame.shape[0])), 1, 1, 1)[:target_frames]
135
+ assert frame.shape[0] == target_frames, f"The shape of frame is {frame.shape}"
136
+ return frame
137
+
138
+ vr = VideoReader(filepath, ctx=cpu(0))
139
+ fps = vr.get_avg_fps()
140
+ total_frames = len(vr)
141
+
142
+ seek_frame = int(seek_time * fps)
143
+ if duration > 0:
144
+ total_frames_to_read = int(target_fps * duration)
145
+ frame_interval = int(math.ceil(fps / target_fps))
146
+ end_frame = min(seek_frame + total_frames_to_read * frame_interval, total_frames)
147
+ frame_ids = list(range(seek_frame, end_frame, frame_interval))
148
+ else:
149
+ frame_interval = int(math.ceil(fps / target_fps))
150
+ frame_ids = list(range(0, total_frames, frame_interval))
151
+
152
+ frames = vr.get_batch(frame_ids).asnumpy()
153
+ frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
154
+
155
+ if frames.shape[2] != 224 or frames.shape[3] != 224:
156
+ resize_transform = transforms.Resize((224, 224))
157
+ frames = resize_transform(frames)
158
+
159
+ video_tensor = adjust_video_duration(frames, duration, target_fps)
160
+ assert video_tensor.shape[0] == duration * target_fps, f"The shape of video_tensor is {video_tensor.shape}"
161
+ return video_tensor
162
+
163
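+ # Mux a [start_time, start_time + duration] slice of the video with the generated audio via the ffmpeg CLI (video stream copied, audio encoded as AAC); returns output_path on success, None on failure.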
+ def merge_video_audio(video_path, audio_path, output_path, start_time, duration):
164
+ command = [
165
+ 'ffmpeg',
166
+ '-y',
167
+ '-ss', str(start_time),
168
+ '-t', str(duration),
169
+ '-i', video_path,
170
+ '-i', audio_path,
171
+ '-c:v', 'copy',
172
+ '-c:a', 'aac',
173
+ '-map', '0:v:0',
174
+ '-map', '1:a:0',
175
+ '-shortest',
176
+ '-strict', 'experimental',
177
+ output_path
178
+ ]
179
+
180
+ try:
181
+ sp.run(command, check=True)
182
+ print(f"Successfully merged audio and video into {output_path}")
183
+ return output_path
184
+ except sp.CalledProcessError as e:
185
+ print(f"Error merging audio and video: {e}")
186
+ return None
187
+
188
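+ # Load an audio prompt, take the [seconds_start, seconds_start + seconds_total) slice, and left-pad with zeros up to the target length; returns stereo silence when no path is given.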
+ def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
189
+ if audio_path is None:
190
+ return torch.zeros((2, int(sample_rate * seconds_total)))
191
+ audio_tensor, sr = torchaudio.load(audio_path)
192
+ start_index = int(sample_rate * seconds_start)
193
+ target_length = int(sample_rate * seconds_total)
194
+ end_index = start_index + target_length
195
+ audio_tensor = audio_tensor[:, start_index:end_index]
196
+ if audio_tensor.shape[1] < target_length:
197
+ pad_length = target_length - audio_tensor.shape[1]
198
+ audio_tensor = F.pad(audio_tensor, (pad_length, 0))
199
+ return audio_tensor
stable_audio_tools/inference/__init__.py ADDED
File without changes
stable_audio_tools/inference/generation.py ADDED
@@ -0,0 +1,275 @@
1
+ import numpy as np
2
+ import torch
3
+ import typing as tp
4
+ import math
5
+ from torchaudio import transforms as T
6
+
7
+ from .utils import prepare_audio
8
+ from .sampling import sample, sample_k, sample_rf
9
+ from ..data.utils import PadCrop
10
+
11
+ def generate_diffusion_uncond(
12
+ model,
13
+ steps: int = 250,
14
+ batch_size: int = 1,
15
+ sample_size: int = 2097152,
16
+ seed: int = -1,
17
+ device: str = "cuda",
18
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
19
+ init_noise_level: float = 1.0,
20
+ return_latents = False,
21
+ **sampler_kwargs
22
+ ) -> torch.Tensor:
23
+
24
+ # The length of the output in audio samples
25
+ audio_sample_size = sample_size
26
+
27
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
28
+ if model.pretransform is not None:
29
+ sample_size = sample_size // model.pretransform.downsampling_ratio
30
+
31
+ # Seed
32
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
33
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
34
+ # seed = 777
35
+ print(seed)
36
+ torch.manual_seed(seed)
37
+ # Define the initial noise immediately after setting the seed
38
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
39
+
40
+ if init_audio is not None:
41
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
42
+ in_sr, init_audio = init_audio
43
+
44
+ io_channels = model.io_channels
45
+
46
+ # For latent models, set the io_channels to the autoencoder's io_channels
47
+ if model.pretransform is not None:
48
+ io_channels = model.pretransform.io_channels
49
+
50
+ # Prepare the initial audio for use by the model
51
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
52
+
53
+ # For latent models, encode the initial audio into latents
54
+ if model.pretransform is not None:
55
+ init_audio = model.pretransform.encode(init_audio)
56
+
57
+ init_audio = init_audio.repeat(batch_size, 1, 1)
58
+ else:
59
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
60
+ init_audio = None
61
+ init_noise_level = None
62
+
63
+ # Inpainting mask
64
+
65
+ if init_audio is not None:
66
+ # variations
67
+ sampler_kwargs["sigma_max"] = init_noise_level
68
+ mask = None
69
+ else:
70
+ mask = None
71
+
72
+ # Now the generative AI part:
73
+
74
+ diff_objective = model.diffusion_objective
75
+
76
+ if diff_objective == "v":
77
+ # k-diffusion denoising process go!
78
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
79
+ elif diff_objective == "rectified_flow":
80
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
81
+
82
+ # Denoising process done.
83
+ # If this is latent diffusion, decode latents back into audio
84
+ if model.pretransform is not None and not return_latents:
85
+ sampled = model.pretransform.decode(sampled)
86
+
87
+ # Return audio
88
+ return sampled
89
+
90
+
91
+ def generate_diffusion_cond(
92
+ model,
93
+ steps: int = 250,
94
+ cfg_scale=6,
95
+ conditioning: dict = None,
96
+ conditioning_tensors: tp.Optional[dict] = None,
97
+ negative_conditioning: dict = None,
98
+ negative_conditioning_tensors: tp.Optional[dict] = None,
99
+ batch_size: int = 1,
100
+ sample_size: int = 2097152,
101
+ sample_rate: int = 48000,
102
+ seed: int = -1,
103
+ device: str = "cuda",
104
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
105
+ init_noise_level: float = 1.0,
106
+ mask_args: dict = None,
107
+ return_latents = False,
108
+ **sampler_kwargs
109
+ ) -> torch.Tensor:
110
+ """
111
+ Generate audio from a prompt using a diffusion model.
112
+
113
+ Args:
114
+ model: The diffusion model to use for generation.
115
+ steps: The number of diffusion steps to use.
116
+ cfg_scale: Classifier-free guidance scale
117
+ conditioning: A dictionary of conditioning parameters to use for generation.
118
+ conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
119
+ batch_size: The batch size to use for generation.
120
+ sample_size: The length of the audio to generate, in samples.
121
+ sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
122
+ seed: The random seed to use for generation, or -1 to use a random seed.
123
+ device: The device to use for generation.
124
+ init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
125
+ init_noise_level: The noise level to use when generating from an initial audio sample.
126
+ return_latents: Whether to return the latents used for generation instead of the decoded audio.
127
+ **sampler_kwargs: Additional keyword arguments to pass to the sampler.
128
+ """
129
+
130
+ # The length of the output in audio samples
131
+ audio_sample_size = sample_size
132
+
133
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
134
+ if model.pretransform is not None:
135
+ sample_size = sample_size // model.pretransform.downsampling_ratio
136
+
137
+ # Seed
138
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
139
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
140
+ # seed = 777
141
+ # print(seed)
142
+ torch.manual_seed(seed)
143
+ # Define the initial noise immediately after setting the seed
144
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
145
+
146
+ torch.backends.cuda.matmul.allow_tf32 = False
147
+ torch.backends.cudnn.allow_tf32 = False
148
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
149
+ torch.backends.cudnn.benchmark = False
150
+
151
+ # Conditioning
152
+ assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
153
+ if conditioning_tensors is None:
154
+ conditioning_tensors = model.conditioner(conditioning, device)
155
+ conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
156
+
157
+ if negative_conditioning is not None or negative_conditioning_tensors is not None:
158
+
159
+ if negative_conditioning_tensors is None:
160
+ negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
161
+
162
+ negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
163
+ else:
164
+ negative_conditioning_tensors = {}
165
+
166
+ if init_audio is not None:
167
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
168
+ in_sr, init_audio = init_audio
169
+
170
+ io_channels = model.io_channels
171
+
172
+ # For latent models, set the io_channels to the autoencoder's io_channels
173
+ if model.pretransform is not None:
174
+ io_channels = model.pretransform.io_channels
175
+
176
+ # Prepare the initial audio for use by the model
177
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
178
+
179
+ # For latent models, encode the initial audio into latents
180
+ if model.pretransform is not None:
181
+ init_audio = model.pretransform.encode(init_audio)
182
+
183
+ init_audio = init_audio.repeat(batch_size, 1, 1)
184
+ else:
185
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
186
+ init_audio = None
187
+ init_noise_level = None
188
+ mask_args = None
189
+
190
+ # Inpainting mask
191
+ if init_audio is not None and mask_args is not None:
192
+ # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
193
+ # This is helpful for forward and reverse outpainting
194
+ cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
195
+ pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
196
+ pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
197
+ assert pastefrom < pasteto, "Paste From should be less than Paste To"
198
+ croplen = pasteto - pastefrom
199
+ if cropfrom + croplen > sample_size:
200
+ croplen = sample_size - cropfrom
201
+ cropto = cropfrom + croplen
202
+ pasteto = pastefrom + croplen
203
+ cutpaste = init_audio.new_zeros(init_audio.shape)
204
+ cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
205
+ #print(cropfrom, cropto, pastefrom, pasteto)
206
+ init_audio = cutpaste
207
+ # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
208
+ mask = build_mask(sample_size, mask_args)
209
+ mask = mask.to(device)
210
+ elif init_audio is not None and mask_args is None:
211
+ # variations
212
+ sampler_kwargs["sigma_max"] = init_noise_level
213
+ mask = None
214
+ else:
215
+ mask = None
216
+
217
+ model_dtype = next(model.model.parameters()).dtype
218
+ noise = noise.type(model_dtype)
219
+ conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}
220
+ # Now the generative AI part:
221
+ # k-diffusion denoising process go!
222
+
223
+ diff_objective = model.diffusion_objective
224
+
225
+ if diff_objective == "v":
226
+ # k-diffusion denoising process go!
227
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
228
+
229
+ elif diff_objective == "rectified_flow":
230
+
231
+ if "sigma_min" in sampler_kwargs:
232
+ del sampler_kwargs["sigma_min"]
233
+
234
+ if "sampler_type" in sampler_kwargs:
235
+ del sampler_kwargs["sampler_type"]
236
+
237
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
238
+
239
+ # v-diffusion:
240
+ del noise
241
+ del conditioning_tensors
242
+ del conditioning_inputs
243
+ torch.cuda.empty_cache()
244
+ # Denoising process done.
245
+ # If this is latent diffusion, decode latents back into audio
246
+
247
+ if model.pretransform is not None and not return_latents:
248
+ #cast sampled latents to pretransform dtype
249
+ sampled = sampled.to(next(model.pretransform.parameters()).dtype)
250
+ sampled = model.pretransform.decode(sampled)
251
+
252
+ return sampled
253
+
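+ # Illustrative usage sketch (not part of the original code): how generate_diffusion_cond is
+ # typically called with the conditioning keys this repo's Gradio interface passes
+ # ("video_prompt", "text_prompt", "audio_prompt", "seconds_start", "seconds_total").
+ # The names model, video_tensor, sample_rate and sample_size are placeholders, and the
+ # audio prompt here is just 10 seconds of silence.
+ #
+ #   conditioning = [{
+ #       "video_prompt": [video_tensor.unsqueeze(0)],
+ #       "text_prompt": "rain on a tin roof",
+ #       "audio_prompt": torch.zeros(1, 2, sample_rate * 10),
+ #       "seconds_start": 0,
+ #       "seconds_total": 10,
+ #   }]
+ #   audio = generate_diffusion_cond(model, steps=250, cfg_scale=6.0, conditioning=conditioning,
+ #                                   sample_size=sample_size, sample_rate=sample_rate, device="cuda")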
254
+ # builds a softmask given the parameters
255
+ # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
256
+ # and anything between is a mixture of old/new
257
+ # ideally 0.5 is half/half mixture but i haven't figured this out yet
258
+ def build_mask(sample_size, mask_args):
259
+ maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
260
+ maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
261
+ softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
262
+ softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
263
+ marination = mask_args["marination"]
264
+ # use hann windows for softening the transition (i don't know if this is correct)
265
+ hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
266
+ hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
267
+ # build the mask.
268
+ mask = torch.zeros((sample_size))
269
+ mask[maskstart:maskend] = 1
270
+ mask[maskstart:maskstart+softnessL] = hannL
271
+ mask[maskend-softnessR:maskend] = hannR
272
+ # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
273
+ if marination > 0:
274
+ mask = mask * (1-marination)
275
+ return mask
stable_audio_tools/inference/sampling.py ADDED
@@ -0,0 +1,235 @@
1
+ import torch
2
+ import math
3
+ from tqdm import trange, tqdm
4
+
5
+ import k_diffusion as K
6
+
7
+ # Define the noise schedule and sampling loop
8
+ def get_alphas_sigmas(t):
9
+ """Returns the scaling factors for the clean image (alpha) and for the
10
+ noise (sigma), given a timestep."""
11
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
+
13
+ def alpha_sigma_to_t(alpha, sigma):
14
+ """Returns a timestep, given the scaling factors for the clean image and for
15
+ the noise."""
16
+ return torch.atan2(sigma, alpha) / math.pi * 2
17
+
18
+ def t_to_alpha_sigma(t):
19
+ """Returns the scaling factors for the clean image and for the noise, given
20
+ a timestep."""
21
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
+
23
+
24
+ @torch.no_grad()
25
+ def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
26
+ """Draws samples from a model given starting noise. Euler method"""
27
+
28
+ # Make tensor of ones to broadcast the single t values
29
+ ts = x.new_ones([x.shape[0]])
30
+
31
+ # Create the noise schedule
32
+ t = torch.linspace(sigma_max, 0, steps + 1)
33
+
34
+ #alphas, sigmas = 1-t, t
35
+
36
+ for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
37
+ # Broadcast the current timestep to the correct shape
38
+ t_curr_tensor = t_curr * torch.ones(
39
+ (x.shape[0],), dtype=x.dtype, device=x.device
40
+ )
41
+ dt = t_prev - t_curr # we solve backwards in our formulation
42
+ x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
43
+
44
+ # If we are on the last timestep, output the denoised image
45
+ return x
46
+
47
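+ # v-diffusion bookkeeping used below: with x_t = alpha * x0 + sigma * eps and the model
+ # predicting v = alpha * eps - sigma * x0, the denoised estimate is pred = alpha * x_t - sigma * v
+ # and the noise estimate is eps = sigma * x_t + alpha * v.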
+ @torch.no_grad()
48
+ def sample(model, x, steps, eta, **extra_args):
49
+ """Draws samples from a model given starting noise. v-diffusion"""
50
+ ts = x.new_ones([x.shape[0]])
51
+
52
+ # Create the noise schedule
53
+ t = torch.linspace(1, 0, steps + 1)[:-1]
54
+
55
+ alphas, sigmas = get_alphas_sigmas(t)
56
+
57
+ # The sampling loop
58
+ for i in trange(steps):
59
+
60
+ # Get the model output (v, the predicted velocity)
61
+ with torch.cuda.amp.autocast():
62
+ v = model(x, ts * t[i], **extra_args).float()
63
+
64
+ # Predict the noise and the denoised image
65
+ pred = x * alphas[i] - v * sigmas[i]
66
+ eps = x * sigmas[i] + v * alphas[i]
67
+
68
+ # If we are not on the last timestep, compute the noisy image for the
69
+ # next timestep.
70
+ if i < steps - 1:
71
+ # If eta > 0, adjust the scaling factor for the predicted noise
72
+ # downward according to the amount of additional noise to add
73
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
74
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
75
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
76
+
77
+ # Recombine the predicted noise and predicted denoised image in the
78
+ # correct proportions for the next step
79
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
80
+
81
+ # Add the correct amount of fresh noise
82
+ if eta:
83
+ x += torch.randn_like(x) * ddim_sigma
84
+
85
+ # If we are on the last timestep, output the denoised image
86
+ return pred
87
+
88
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
89
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
90
+ def get_bmask(i, steps, mask):
91
+ strength = (i+1)/(steps)
92
+ # convert to binary mask
93
+ bmask = torch.where(mask<=strength,1,0)
94
+ return bmask
95
+
96
+ def make_cond_model_fn(model, cond_fn):
97
+ def cond_model_fn(x, sigma, **kwargs):
98
+ with torch.enable_grad():
99
+ x = x.detach().requires_grad_()
100
+ denoised = model(x, sigma, **kwargs)
101
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
102
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
103
+ return cond_denoised
104
+ return cond_model_fn
105
+
106
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
107
+ # init_data is init_audio as latents (if this is latent diffusion)
108
+ # For sampling, set both init_data and mask to None
109
+ # For variations, set init_data
110
+ # For inpainting, set both init_data & mask
111
+ def sample_k(
112
+ model_fn,
113
+ noise,
114
+ init_data=None,
115
+ mask=None,
116
+ steps=100,
117
+ sampler_type="dpmpp-2m-sde",
118
+ sigma_min=0.5,
119
+ sigma_max=50,
120
+ rho=1.0, device="cuda",
121
+ callback=None,
122
+ cond_fn=None,
123
+ **extra_args
124
+ ):
125
+
126
+ denoiser = K.external.VDenoiser(model_fn)
127
+
128
+ if cond_fn is not None:
129
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
130
+
131
+ # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
132
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
133
+ # Scale the initial noise by sigma
134
+ noise = noise * sigmas[0]
135
+
136
+ wrapped_callback = callback
137
+
138
+
139
+ if mask is None and init_data is not None:
140
+ # VARIATION (no inpainting)
141
+ # set the initial latent to the init_data, and noise it with initial sigma
142
+
143
+ x = init_data + noise
144
+
145
+ elif mask is not None and init_data is not None:
146
+ # INPAINTING
147
+ bmask = get_bmask(0, steps, mask)
148
+ # initial noising
149
+ input_noised = init_data + noise
150
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
151
+ x = input_noised * bmask + noise * (1-bmask)
152
+ # define the inpainting callback function (Note: side effects, it mutates x)
153
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
154
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
155
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
156
+ def inpainting_callback(args):
157
+ i = args["i"]
158
+ x = args["x"]
159
+ sigma = args["sigma"]
160
+ #denoised = args["denoised"]
161
+ # noise the init_data input with this step's appropriate amount of noise
162
+ input_noised = init_data + torch.randn_like(init_data) * sigma
163
+ # shrinking hard mask
164
+ bmask = get_bmask(i, steps, mask)
165
+ # mix input_noise with x, using binary mask
166
+ new_x = input_noised * bmask + x * (1-bmask)
167
+ # mutate x
168
+ x[:,:,:] = new_x[:,:,:]
169
+ # wrap together the inpainting callback and the user-submitted callback.
170
+ if callback is None:
171
+ wrapped_callback = inpainting_callback
172
+ else:
173
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
174
+ else:
175
+ # SAMPLING
176
+ # set the initial latent to noise
177
+ x = noise
178
+ # x = noise
179
+
180
+ with torch.cuda.amp.autocast():
181
+ if sampler_type == "k-heun":
182
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
183
+ elif sampler_type == "k-lms":
184
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
185
+ elif sampler_type == "k-dpmpp-2s-ancestral":
186
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
187
+ elif sampler_type == "k-dpm-2":
188
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
189
+ elif sampler_type == "k-dpm-fast":
190
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
191
+ elif sampler_type == "k-dpm-adaptive":
192
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
193
+ elif sampler_type == "dpmpp-2m-sde":
194
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
195
+ elif sampler_type == "dpmpp-3m-sde":
196
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
197
+
198
+ # Uses discrete Euler sampling for rectified flow models
199
+ # init_data is init_audio as latents (if this is latent diffusion)
200
+ # For sampling, set both init_data and mask to None
201
+ # For variations, set init_data
202
+ # For inpainting, set both init_data & mask
203
+ def sample_rf(
204
+ model_fn,
205
+ noise,
206
+ init_data=None,
207
+ steps=100,
208
+ sigma_max=1,
209
+ device="cuda",
210
+ callback=None,
211
+ cond_fn=None,
212
+ **extra_args
213
+ ):
214
+
215
+ if sigma_max > 1:
216
+ sigma_max = 1
217
+
218
+ if cond_fn is not None:
219
+ model_fn = make_cond_model_fn(model_fn, cond_fn)  # wrap the model function directly; no separate denoiser exists in the rectified-flow sampler
220
+
221
+ wrapped_callback = callback
222
+
223
+ if init_data is not None:
224
+ # VARIATION (no inpainting)
225
+ # Interpolate the init data and the noise for init audio
226
+ x = init_data * (1 - sigma_max) + noise * sigma_max
227
+ else:
228
+ # SAMPLING
229
+ # set the initial latent to noise
230
+ x = noise
231
+
232
+ with torch.cuda.amp.autocast():
233
+ # TODO: Add callback support
234
+ #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
235
+ return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
stable_audio_tools/inference/utils.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..data.utils import PadCrop
2
+
3
+ from torchaudio import transforms as T
4
+
5
+ def set_audio_channels(audio, target_channels):
6
+ if target_channels == 1:
7
+ # Convert to mono
8
+ audio = audio.mean(1, keepdim=True)
9
+ elif target_channels == 2:
10
+ # Convert to stereo
11
+ if audio.shape[1] == 1:
12
+ audio = audio.repeat(1, 2, 1)
13
+ elif audio.shape[1] > 2:
14
+ audio = audio[:, :2, :]
15
+ return audio
16
+
17
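+ # Resample to the model's rate, pad/crop to target_length, add a batch dimension if needed, and match the requested channel count so the audio is ready for the pretransform encoder.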
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
18
+
19
+ audio = audio.to(device)
20
+
21
+ if in_sr != target_sr:
22
+ resample_tf = T.Resample(in_sr, target_sr).to(device)
23
+ audio = resample_tf(audio)
24
+
25
+ audio = PadCrop(target_length, randomize=False)(audio)
26
+
27
+ # Add batch dimension
28
+ if audio.dim() == 1:
29
+ audio = audio.unsqueeze(0).unsqueeze(0)
30
+ elif audio.dim() == 2:
31
+ audio = audio.unsqueeze(0)
32
+
33
+ audio = set_audio_channels(audio, target_channels)
34
+
35
+ return audio
stable_audio_tools/interface/__init__.py ADDED
File without changes
stable_audio_tools/interface/gradio.py ADDED
@@ -0,0 +1,495 @@
1
+ import gc
2
+ import platform
3
+ import os
4
+ import subprocess as sp
5
+ import gradio as gr
6
+ import json
7
+ import torch
8
+ import torchaudio
9
+
10
+ from aeiou.viz import audio_spectrogram_image
11
+ from einops import rearrange
12
+ from safetensors.torch import load_file
13
+ from torch.nn import functional as F
14
+ from torchaudio import transforms as T
15
+
16
+ from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
17
+ from ..models.factory import create_model_from_config
18
+ from ..models.pretrained import get_pretrained_model
19
+ from ..models.utils import load_ckpt_state_dict
20
+ from ..inference.utils import prepare_audio
21
+ from ..training.utils import copy_state_dict
22
+ from ..data.utils import read_video, merge_video_audio
23
+
24
+
25
+ import os
26
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
+
28
+ import warnings
29
+ warnings.filterwarnings("ignore", category=UserWarning)
30
+
31
+
32
+ device = torch.device("cpu")
33
+
34
+ os.environ['TMPDIR'] = './tmp'
35
+
36
+ current_model_name = None
37
+ current_model = None
38
+ current_sample_rate = None
39
+ current_sample_size = None
40
+
41
+
42
+
43
+ def load_model(model_name, model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
44
+ global model_configurations
45
+
46
+ if pretrained_name is not None:
47
+ print(f"Loading pretrained model {pretrained_name}")
48
+ model, model_config = get_pretrained_model(pretrained_name)
+ sample_rate = model_config["sample_rate"]
+ sample_size = model_config["sample_size"]
49
+ elif model_config is not None and model_ckpt_path is not None:
50
+ print(f"Creating model from config")
51
+ model = create_model_from_config(model_config)
52
+ print(f"Loading model checkpoint from {model_ckpt_path}")
53
+ copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
54
+ sample_rate = model_config["sample_rate"]
55
+ sample_size = model_config["sample_size"]
56
+ if pretransform_ckpt_path is not None:
57
+ print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
58
+ model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
59
+ print(f"Done loading pretransform")
60
+ model.to(device).eval().requires_grad_(False)
61
+ if model_half:
62
+ model.to(torch.float16)
63
+ print(f"Done loading model")
64
+ return model, model_config, sample_rate, sample_size
65
+
66
+ def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
67
+ if audio_path is None:
68
+ return torch.zeros((2, int(sample_rate * seconds_total)))
69
+ audio_tensor, sr = torchaudio.load(audio_path)
70
+ start_index = int(sample_rate * seconds_start)
71
+ target_length = int(sample_rate * seconds_total)
72
+ end_index = start_index + target_length
73
+ audio_tensor = audio_tensor[:, start_index:end_index]
74
+ if audio_tensor.shape[1] < target_length:
75
+ pad_length = target_length - audio_tensor.shape[1]
76
+ audio_tensor = F.pad(audio_tensor, (pad_length, 0))
77
+ return audio_tensor
78
+
79
+ def generate_cond(
80
+ prompt,
81
+ negative_prompt=None,
82
+ video_file=None,
83
+ video_path=None,
84
+ audio_prompt_file=None,
85
+ audio_prompt_path=None,
86
+ seconds_start=0,
87
+ seconds_total=10,
88
+ cfg_scale=6.0,
89
+ steps=250,
90
+ preview_every=None,
91
+ seed=-1,
92
+ sampler_type="dpmpp-3m-sde",
93
+ sigma_min=0.03,
94
+ sigma_max=1000,
95
+ cfg_rescale=0.0,
96
+ use_init=False,
97
+ init_audio=None,
98
+ init_noise_level=1.0,
99
+ mask_cropfrom=None,
100
+ mask_pastefrom=None,
101
+ mask_pasteto=None,
102
+ mask_maskstart=None,
103
+ mask_maskend=None,
104
+ mask_softnessL=None,
105
+ mask_softnessR=None,
106
+ mask_marination=None,
107
+ batch_size=1
108
+ ):
109
+ if torch.cuda.is_available():
110
+ torch.cuda.empty_cache()
111
+ gc.collect()
112
+ print(f"Prompt: {prompt}")
113
+ preview_images = []
114
+ if preview_every == 0:
115
+ preview_every = None
116
+
117
+ try:
118
+ has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
119
+ except Exception:
120
+ has_mps = False
121
+ if has_mps:
122
+ device = torch.device("mps")
123
+ elif torch.cuda.is_available():
124
+ device = torch.device("cuda")
125
+ else:
126
+ device = torch.device("cpu")
127
+ model_name = 'default'
128
+ cfg = model_configurations[model_name]
129
+ model_config_path = cfg.get("model_config")
130
+ ckpt_path = cfg.get("ckpt_path")
131
+ pretrained_name = cfg.get("pretrained_name")
132
+ pretransform_ckpt_path = cfg.get("pretransform_ckpt_path")
133
+ model_type = cfg.get("model_type", "diffusion_cond")
134
+ if model_config_path:
135
+ with open(model_config_path) as f:
136
+ model_config = json.load(f)
137
+ else:
138
+ model_config = None
139
+ target_fps = model_config.get("video_fps", 5) if model_config is not None else 5
140
+ global current_model_name, current_model, current_sample_rate, current_sample_size
141
+ if current_model is None or model_name != current_model_name:
142
+ current_model, model_config, sample_rate, sample_size = load_model(
143
+ model_name=model_name,
144
+ model_config=model_config,
145
+ model_ckpt_path=ckpt_path,
146
+ pretrained_name=pretrained_name,
147
+ pretransform_ckpt_path=pretransform_ckpt_path,
148
+ device=device,
149
+ model_half=False
150
+ )
151
+ current_model_name = model_name
152
+ model = current_model
153
+ current_sample_rate = sample_rate
154
+ current_sample_size = sample_size
155
+ else:
156
+ model = current_model
157
+ sample_rate = current_sample_rate
158
+ sample_size = current_sample_size
159
+ if video_file is not None:
160
+ video_path = video_file.name
161
+ elif video_path:
162
+ video_path = video_path.strip()
163
+ else:
164
+ video_path = None
165
+
166
+ if audio_prompt_file is not None:
167
+ print(f'audio_prompt_file: {audio_prompt_file}')
168
+ audio_path = audio_prompt_file.name
169
+ elif audio_prompt_path:
170
+ audio_path = audio_prompt_path.strip()
171
+ else:
172
+ audio_path = None
173
+
174
+ Video_tensors = read_video(video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
175
+ audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
176
+
177
+ audio_tensor = audio_tensor.to(device)
178
+ seconds_input = sample_size / sample_rate
179
+ print(f'video_path: {video_path}')
180
+
181
+ if not prompt:
182
+ prompt = ""
183
+
184
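+ # One conditioning dict per batch element, bundling the video frames, text prompt, audio prompt and timing metadata.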
+ conditioning = [{
185
+ "video_prompt": [Video_tensors.unsqueeze(0)],
186
+ "text_prompt": prompt,
187
+ "audio_prompt": audio_tensor.unsqueeze(0),
188
+ "seconds_start": seconds_start,
189
+ "seconds_total": seconds_input
190
+ }] * batch_size
191
+ if negative_prompt:
192
+ negative_conditioning = [{
193
+ "video_prompt": [Video_tensors.unsqueeze(0)],
194
+ "text_prompt": negative_prompt,
195
+ "audio_prompt": audio_tensor.unsqueeze(0),
196
+ "seconds_start": seconds_start,
197
+ "seconds_total": seconds_total
198
+ }] * batch_size
199
+ else:
200
+ negative_conditioning = None
201
+ try:
202
+ device = next(model.parameters()).device
203
+ except Exception as e:
204
+ device = next(current_model.parameters()).device
205
+ seed = int(seed)
206
+ if not use_init:
207
+ init_audio = None
208
+ input_sample_size = sample_size
209
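+ # Gradio passes init audio as (sample_rate, numpy array); the code assumes 16-bit samples (scales by 1/32767), fixes the channel layout, and resamples to the model rate if needed.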
+ if init_audio is not None:
210
+ in_sr, init_audio = init_audio
211
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
212
+ if init_audio.dim() == 1:
213
+ init_audio = init_audio.unsqueeze(0)
214
+ elif init_audio.dim() == 2:
215
+ init_audio = init_audio.transpose(0, 1)
216
+ if in_sr != sample_rate:
217
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
218
+ init_audio = resample_tf(init_audio)
219
+ audio_length = init_audio.shape[-1]
220
+ if audio_length > sample_size:
221
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
222
+ init_audio = (sample_rate, init_audio)
223
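+ # Decode the partially denoised latents every preview_every steps and keep a spectrogram image as a progress preview.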
+ def progress_callback(callback_info):
224
+ nonlocal preview_images
225
+ denoised = callback_info["denoised"]
226
+ current_step = callback_info["i"]
227
+ sigma = callback_info["sigma"]
228
+ if (current_step - 1) % preview_every == 0:
229
+ if model.pretransform is not None:
230
+ denoised = model.pretransform.decode(denoised)
231
+ denoised = rearrange(denoised, "b d n -> d (b n)")
232
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
233
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
234
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
235
+ if mask_cropfrom is not None:
236
+ mask_args = {
237
+ "cropfrom": mask_cropfrom,
238
+ "pastefrom": mask_pastefrom,
239
+ "pasteto": mask_pasteto,
240
+ "maskstart": mask_maskstart,
241
+ "maskend": mask_maskend,
242
+ "softnessL": mask_softnessL,
243
+ "softnessR": mask_softnessR,
244
+ "marination": mask_marination,
245
+ }
246
+ else:
247
+ mask_args = None
248
+ if model_type == "diffusion_cond":
249
+ audio = generate_diffusion_cond(
250
+ model,
251
+ conditioning=conditioning,
252
+ negative_conditioning=negative_conditioning,
253
+ steps=steps,
254
+ cfg_scale=cfg_scale,
255
+ batch_size=batch_size,
256
+ sample_size=input_sample_size,
257
+ sample_rate=sample_rate,
258
+ seed=seed,
259
+ device=device,
260
+ sampler_type=sampler_type,
261
+ sigma_min=sigma_min,
262
+ sigma_max=sigma_max,
263
+ init_audio=init_audio,
264
+ init_noise_level=init_noise_level,
265
+ mask_args=mask_args,
266
+ callback=progress_callback if preview_every is not None else None,
267
+ scale_phi=cfg_rescale
268
+ )
269
+ elif model_type == "diffusion_uncond":
270
+ audio = generate_diffusion_uncond(
271
+ model,
272
+ steps=steps,
273
+ batch_size=batch_size,
274
+ sample_size=input_sample_size,
275
+ seed=seed,
276
+ device=device,
277
+ sampler_type=sampler_type,
278
+ sigma_min=sigma_min,
279
+ sigma_max=sigma_max,
280
+ init_audio=init_audio,
281
+ init_noise_level=init_noise_level,
282
+ callback=progress_callback if preview_every is not None else None
283
+ )
284
+ else:
285
+ raise ValueError(f"Unsupported model type: {model_type}")
286
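+ # Post-process the generated batch: lay items out end-to-end in time, peak-normalize, and convert to 16-bit PCM.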
+ audio = rearrange(audio, "b d n -> d (b n)")
287
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
288
+ file_name = os.path.basename(video_path) if video_path else "output"
289
+ output_dir = f"demo_result"
290
+ if not os.path.exists(output_dir):
291
+ os.makedirs(output_dir)
292
+ output_video_path = f"{output_dir}/{file_name}"
293
+ torchaudio.save(f"{output_dir}/output.wav", audio, sample_rate)
294
+ if not os.path.exists(output_dir):
295
+ os.makedirs(output_dir)
296
+ if video_path:
297
+ merge_video_audio(video_path, f"{output_dir}/output.wav", output_video_path, seconds_start, seconds_total)
298
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
299
+ del video_path
300
+ torch.cuda.empty_cache()
301
+ gc.collect()
302
+ return (output_video_path, f"{output_dir}/output.wav")
303
+
304
+ def toggle_custom_model(selected_model):
305
+ return gr.Row.update(visible=(selected_model == "Custom Model"))
306
+
307
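+ # Assemble the Gradio generation tab: prompt and video inputs, sampler controls, bundled examples, and the output video/audio players.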
+ def create_sampling_ui(model_config_map, inpainting=False):
308
+ with gr.Blocks() as demo:
309
+ gr.Markdown(
310
+ """
311
+ # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
312
+ **[Project Page](https://zeyuet.github.io/AudioX/) · [Hugging Face](https://huggingface.co/Zeyue7/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
313
+ """
314
+ )
315
+
316
+ with gr.Tab("Generation"):
317
+
318
+ with gr.Row():
319
+ with gr.Column():
320
+ prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt")
321
+ negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt", visible=False)
322
+ video_path = gr.Textbox(label="Video Path", placeholder="Enter video file path")
323
+ video_file = gr.File(label="Upload Video File")
324
+ audio_prompt_file = gr.File(label="Upload Audio Prompt File", visible=False)
325
+ audio_prompt_path = gr.Textbox(label="Audio Prompt Path", placeholder="Enter audio file path", visible=False)
326
+ with gr.Row():
327
+ with gr.Column(scale=6):
328
+ with gr.Accordion("Video Params", open=False):
329
+ seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Video Seconds Start")
330
+ seconds_total_slider = gr.Slider(minimum=0, maximum=10, step=1, value=10, label="Seconds Total", interactive=False)
331
+ with gr.Row():
332
+ with gr.Column(scale=4):
333
+ with gr.Accordion("Sampler Params", open=False):
334
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
335
+ preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
336
+ cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale")
337
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
338
+ sampler_type_dropdown = gr.Dropdown(
339
+ ["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"],
340
+ label="Sampler Type",
341
+ value="dpmpp-3m-sde"
342
+ )
343
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min")
344
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max")
345
+ cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG Rescale Amount")
346
+ with gr.Row():
347
+ with gr.Column(scale=4):
348
+ with gr.Accordion("Init Audio", open=False, visible=False):
349
+ init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
350
+ init_audio_input = gr.Audio(label="Init Audio")
351
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init Noise Level")
352
+ gr.Markdown("## Examples")
353
+ with gr.Accordion("Click to show examples", open=False):
354
+ with gr.Row():
355
+ gr.Markdown("**📝 Task: Text-to-Audio**")
356
+ with gr.Column(scale=1.2):
357
+ gr.Markdown("Prompt: *Typing on a keyboard*")
358
+ ex1 = gr.Button("Load Example")
359
+ with gr.Column(scale=1.2):
360
+ gr.Markdown("Prompt: *Ocean waves crashing*")
361
+ ex2 = gr.Button("Load Example")
362
+ with gr.Column(scale=1.2):
363
+ gr.Markdown("Prompt: *Footsteps in snow*")
364
+ ex3 = gr.Button("Load Example")
365
+ with gr.Row():
366
+ gr.Markdown("**🎶 Task: Text-to-Music**")
367
+ with gr.Column(scale=1.2):
368
+ gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
369
+ ex4 = gr.Button("Load Example")
370
+ with gr.Column(scale=1.2):
371
+ gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
372
+ ex5 = gr.Button("Load Example")
373
+ with gr.Column(scale=1.2):
374
+ gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
375
+ ex6 = gr.Button("Load Example")
376
+ with gr.Row():
377
+ gr.Markdown("**🎬 Task: Video-to-Audio**\nPrompt: *Generate general audio for the video*")
378
+ with gr.Column(scale=1.2):
379
+ gr.Video("example/V2A_sample-1.mp4")
380
+ ex7 = gr.Button("Load Example")
381
+ with gr.Column(scale=1.2):
382
+ gr.Video("example/V2A_sample-2.mp4")
383
+ ex8 = gr.Button("Load Example")
384
+ with gr.Column(scale=1.2):
385
+ gr.Video("example/V2A_sample-3.mp4")
386
+ ex9 = gr.Button("Load Example")
387
+ with gr.Row():
388
+ gr.Markdown("**🎵 Task: Video-to-Music**\nPrompt: *Generate music for the video*")
389
+ with gr.Column(scale=1.2):
390
+ gr.Video("example/V2M_sample-1.mp4")
391
+ ex10 = gr.Button("Load Example")
392
+ with gr.Column(scale=1.2):
393
+ gr.Video("example/V2M_sample-2.mp4")
394
+ ex11 = gr.Button("Load Example")
395
+ with gr.Column(scale=1.2):
396
+ gr.Video("example/V2M_sample-3.mp4")
397
+ ex12 = gr.Button("Load Example")
398
+ with gr.Row():
399
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
400
+ with gr.Row():
401
+ with gr.Column(scale=6):
402
+ video_output = gr.Video(label="Output Video", interactive=False)
403
+ audio_output = gr.Audio(label="Output Audio", interactive=False)
404
+ send_to_init_button = gr.Button("Send to Init Audio", scale=1, visible=False)
405
+ send_to_init_button.click(
406
+ fn=lambda audio: audio,
407
+ inputs=[audio_output],
408
+ outputs=[init_audio_input]
409
+ )
410
+ inputs = [
411
+ prompt,
412
+ negative_prompt,
413
+ video_file,
414
+ video_path,
415
+ audio_prompt_file,
416
+ audio_prompt_path,
417
+ seconds_start_slider,
418
+ seconds_total_slider,
419
+ cfg_scale_slider,
420
+ steps_slider,
421
+ preview_every_slider,
422
+ seed_textbox,
423
+ sampler_type_dropdown,
424
+ sigma_min_slider,
425
+ sigma_max_slider,
426
+ cfg_rescale_slider,
427
+ init_audio_checkbox,
428
+ init_audio_input,
429
+ init_noise_level_slider
430
+ ]
431
+ generate_button.click(
432
+ fn=generate_cond,
433
+ inputs=inputs,
434
+ outputs=[
435
+ video_output,
436
+ audio_output
437
+ ],
438
+ api_name="generate"
439
+ )
440
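+ # Each example button fills the whole input list (prompt, video path, seed, sampler settings, ...) with a preset configuration.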
+ ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
441
+ ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
442
+ ex3.click(lambda: ["Footsteps in snow", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
443
+ ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
444
+ ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
445
+ ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
446
+ ex7.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3737819478", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
447
+ ex8.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "1900718499", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
448
+ ex9.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "2289822202", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
449
+ ex10.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3498087420", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
450
+ ex11.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "3753837734", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
451
+ ex12.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "3510832996", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
452
+ return demo
453
+
454
+ def create_txt2audio_ui(model_config_map):
455
+ with gr.Blocks(css=".gradio-container { max-width: 1120px; margin: auto; }") as ui:
456
+ with gr.Tab("Generation"):
457
+ create_sampling_ui(model_config_map)
458
+ return ui
459
+
463
+ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
464
+ global model_configurations
465
+ global device
466
+
467
+ try:
468
+ has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
469
+ except Exception:
470
+ has_mps = False
471
+
472
+ if has_mps:
473
+ device = torch.device("mps")
474
+ elif torch.cuda.is_available():
475
+ device = torch.device("cuda")
476
+ else:
477
+ device = torch.device("cpu")
478
+
479
+ print("Using device:", device)
480
+
481
+ model_configurations = {
482
+ "default": {
483
+ "model_config": "./model/config.json",
484
+ "ckpt_path": "./model/model.ckpt"
485
+ }
486
+ }
487
+ ui = create_txt2audio_ui(model_configurations)
488
+ return ui
489
+
490
+ if __name__ == "__main__":
491
+ ui = create_ui(
+ model_config_path='./model/config.json'
+ )
+ ui.launch(share=True)
stable_audio_tools/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_model_from_config, create_model_from_config_path
stable_audio_tools/models/adp.py ADDED
@@ -0,0 +1,1588 @@
1
+ # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
2
+ # License can be found in LICENSES/LICENSE_ADP.txt
3
+
4
+ import math
5
+ from inspect import isfunction
6
+ from math import ceil, floor, log, pi, log2
7
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
8
+ from packaging import version
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange, reduce, repeat
13
+ from einops.layers.torch import Rearrange
14
+ from einops_exts import rearrange_many
15
+ from torch import Tensor, einsum
16
+ from torch.backends.cuda import sdp_kernel
17
+ from torch.nn import functional as F
18
+ from dac.nn.layers import Snake1d
19
+
20
+ """
21
+ Utils
22
+ """
23
+
24
+
25
+ class ConditionedSequential(nn.Module):
26
+ def __init__(self, *modules):
27
+ super().__init__()
28
+ self.module_list = nn.ModuleList(*modules)
29
+
30
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
31
+ for module in self.module_list:
32
+ x = module(x, mapping)
33
+ return x
34
+
35
+ T = TypeVar("T")
36
+
37
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+ def exists(val: Optional[T]) -> T:
43
+ return val is not None
44
+
45
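+ # Round x to the nearest power of two.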
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
52
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
53
+ for key in d.keys():
54
+ no_prefix = int(not key.startswith(prefix))
55
+ return_dicts[no_prefix][key] = d[key]
56
+ return return_dicts
57
+
58
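+ # Split a kwargs dict into entries whose keys start with prefix (stripped unless keep_prefix=True) and all remaining entries.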
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
59
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
60
+ if keep_prefix:
61
+ return kwargs_with_prefix, kwargs
62
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
63
+ return kwargs_no_prefix, kwargs
64
+
65
+ """
66
+ Convolutional Blocks
67
+ """
68
+ import typing as tp
69
+
70
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
71
+ # License available in LICENSES/LICENSE_META.txt
72
+
73
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
74
+ padding_total: int = 0) -> int:
75
+ """See `pad_for_conv1d`."""
76
+ length = x.shape[-1]
77
+ n_frames = (length - kernel_size + padding_total) / stride + 1
78
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
79
+ return ideal_length - length
80
+
81
+
82
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
83
+ """Pad for a convolution to make sure that the last window is full.
84
+ Extra padding is added at the end. This is required to ensure that we can rebuild
85
+ an output of the same length, as otherwise, even with padding, some time steps
86
+ might get removed.
87
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
88
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
89
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
90
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
91
+ 1 2 3 4 # once padding is removed, we are missing one time step!
92
+ """
93
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
94
+ return F.pad(x, (0, extra_padding))
95
+
96
+
97
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
98
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
99
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
100
+ """
101
+ length = x.shape[-1]
102
+ padding_left, padding_right = paddings
103
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
104
+ if mode == 'reflect':
105
+ max_pad = max(padding_left, padding_right)
106
+ extra_pad = 0
107
+ if length <= max_pad:
108
+ extra_pad = max_pad - length + 1
109
+ x = F.pad(x, (0, extra_pad))
110
+ padded = F.pad(x, paddings, mode, value)
111
+ end = padded.shape[-1] - extra_pad
112
+ return padded[..., :end]
113
+ else:
114
+ return F.pad(x, paddings, mode, value)
115
+
116
+
117
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
118
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
119
+ padding_left, padding_right = paddings
120
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
121
+ assert (padding_left + padding_right) <= x.shape[-1]
122
+ end = x.shape[-1] - padding_right
123
+ return x[..., padding_left: end]
124
+
125
+
126
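+ # Conv1d that applies its own padding at call time: left-only padding when causal=True, otherwise split (nearly) evenly between both sides.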
+ class Conv1d(nn.Conv1d):
127
+ def __init__(self, *args, **kwargs):
128
+ super().__init__(*args, **kwargs)
129
+
130
+ def forward(self, x: Tensor, causal=False) -> Tensor:
131
+ kernel_size = self.kernel_size[0]
132
+ stride = self.stride[0]
133
+ dilation = self.dilation[0]
134
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
135
+ padding_total = kernel_size - stride
136
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
137
+ if causal:
138
+ # Left padding for causal
139
+ x = pad1d(x, (padding_total, extra_padding))
140
+ else:
141
+ # Asymmetric padding required for odd strides
142
+ padding_right = padding_total // 2
143
+ padding_left = padding_total - padding_right
144
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
145
+ return super().forward(x)
146
+
147
+ class ConvTranspose1d(nn.ConvTranspose1d):
148
+ def __init__(self, *args, **kwargs):
149
+ super().__init__(*args, **kwargs)
150
+
151
+ def forward(self, x: Tensor, causal=False) -> Tensor:
152
+ kernel_size = self.kernel_size[0]
153
+ stride = self.stride[0]
154
+ padding_total = kernel_size - stride
155
+
156
+ y = super().forward(x)
157
+
158
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
159
+ # removed at the very end, when keeping only the right length for the output,
160
+ # as removing it here would require also passing the length at the matching layer
161
+ # in the encoder.
162
+ if causal:
163
+ padding_right = ceil(padding_total)
164
+ padding_left = padding_total - padding_right
165
+ y = unpad1d(y, (padding_left, padding_right))
166
+ else:
167
+ # Asymmetric padding required for odd strides
168
+ padding_right = padding_total // 2
169
+ padding_left = padding_total - padding_right
170
+ y = unpad1d(y, (padding_left, padding_right))
171
+ return y
172
+
173
+
174
+ def Downsample1d(
175
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
176
+ ) -> nn.Module:
177
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
178
+
179
+ return Conv1d(
180
+ in_channels=in_channels,
181
+ out_channels=out_channels,
182
+ kernel_size=factor * kernel_multiplier + 1,
183
+ stride=factor
184
+ )
185
+
186
+
187
+ def Upsample1d(
188
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
189
+ ) -> nn.Module:
190
+
191
+ if factor == 1:
192
+ return Conv1d(
193
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3
194
+ )
195
+
196
+ if use_nearest:
197
+ return nn.Sequential(
198
+ nn.Upsample(scale_factor=factor, mode="nearest"),
199
+ Conv1d(
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ kernel_size=3
203
+ ),
204
+ )
205
+ else:
206
+ return ConvTranspose1d(
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ kernel_size=factor * 2,
210
+ stride=factor
211
+ )
212
+
213
+
214
+ class ConvBlock1d(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ out_channels: int,
219
+ *,
220
+ kernel_size: int = 3,
221
+ stride: int = 1,
222
+ dilation: int = 1,
223
+ num_groups: int = 8,
224
+ use_norm: bool = True,
225
+ use_snake: bool = False
226
+ ) -> None:
227
+ super().__init__()
228
+
229
+ self.groupnorm = (
230
+ nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
231
+ if use_norm
232
+ else nn.Identity()
233
+ )
234
+
235
+ if use_snake:
236
+ self.activation = Snake1d(in_channels)
237
+ else:
238
+ self.activation = nn.SiLU()
239
+
240
+ self.project = Conv1d(
241
+ in_channels=in_channels,
242
+ out_channels=out_channels,
243
+ kernel_size=kernel_size,
244
+ stride=stride,
245
+ dilation=dilation,
246
+ )
247
+
248
+ def forward(
249
+ self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
250
+ ) -> Tensor:
251
+ x = self.groupnorm(x)
252
+ if exists(scale_shift):
253
+ scale, shift = scale_shift
254
+ x = x * (scale + 1) + shift
255
+ x = self.activation(x)
256
+ return self.project(x, causal=causal)
257
+
258
+
259
+ class MappingToScaleShift(nn.Module):
260
+ def __init__(
261
+ self,
262
+ features: int,
263
+ channels: int,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.to_scale_shift = nn.Sequential(
268
+ nn.SiLU(),
269
+ nn.Linear(in_features=features, out_features=channels * 2),
270
+ )
271
+
272
+ def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
273
+ scale_shift = self.to_scale_shift(mapping)
274
+ scale_shift = rearrange(scale_shift, "b c -> b c 1")
275
+ scale, shift = scale_shift.chunk(2, dim=1)
276
+ return scale, shift
277
+
278
+
279
+ class ResnetBlock1d(nn.Module):
280
+ def __init__(
281
+ self,
282
+ in_channels: int,
283
+ out_channels: int,
284
+ *,
285
+ kernel_size: int = 3,
286
+ stride: int = 1,
287
+ dilation: int = 1,
288
+ use_norm: bool = True,
289
+ use_snake: bool = False,
290
+ num_groups: int = 8,
291
+ context_mapping_features: Optional[int] = None,
292
+ ) -> None:
293
+ super().__init__()
294
+
295
+ self.use_mapping = exists(context_mapping_features)
296
+
297
+ self.block1 = ConvBlock1d(
298
+ in_channels=in_channels,
299
+ out_channels=out_channels,
300
+ kernel_size=kernel_size,
301
+ stride=stride,
302
+ dilation=dilation,
303
+ use_norm=use_norm,
304
+ num_groups=num_groups,
305
+ use_snake=use_snake
306
+ )
307
+
308
+ if self.use_mapping:
309
+ assert exists(context_mapping_features)
310
+ self.to_scale_shift = MappingToScaleShift(
311
+ features=context_mapping_features, channels=out_channels
312
+ )
313
+
314
+ self.block2 = ConvBlock1d(
315
+ in_channels=out_channels,
316
+ out_channels=out_channels,
317
+ use_norm=use_norm,
318
+ num_groups=num_groups,
319
+ use_snake=use_snake
320
+ )
321
+
322
+ self.to_out = (
323
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
324
+ if in_channels != out_channels
325
+ else nn.Identity()
326
+ )
327
+
328
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
329
+ assert_message = "context mapping required if context_mapping_features > 0"
330
+ assert not (self.use_mapping ^ exists(mapping)), assert_message
331
+
332
+ h = self.block1(x, causal=causal)
333
+
334
+ scale_shift = None
335
+ if self.use_mapping:
336
+ scale_shift = self.to_scale_shift(mapping)
337
+
338
+ h = self.block2(h, scale_shift=scale_shift, causal=causal)
339
+
340
+ return h + self.to_out(x)
341
+
342
+
343
+ class Patcher(nn.Module):
344
+ def __init__(
345
+ self,
346
+ in_channels: int,
347
+ out_channels: int,
348
+ patch_size: int,
349
+ context_mapping_features: Optional[int] = None,
350
+ use_snake: bool = False,
351
+ ):
352
+ super().__init__()
353
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
354
+ assert out_channels % patch_size == 0, assert_message
355
+ self.patch_size = patch_size
356
+
357
+ self.block = ResnetBlock1d(
358
+ in_channels=in_channels,
359
+ out_channels=out_channels // patch_size,
360
+ num_groups=1,
361
+ context_mapping_features=context_mapping_features,
362
+ use_snake=use_snake
363
+ )
364
+
365
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
366
+ x = self.block(x, mapping, causal=causal)
367
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
368
+ return x
369
+
370
+
371
+ class Unpatcher(nn.Module):
372
+ def __init__(
373
+ self,
374
+ in_channels: int,
375
+ out_channels: int,
376
+ patch_size: int,
377
+ context_mapping_features: Optional[int] = None,
378
+ use_snake: bool = False
379
+ ):
380
+ super().__init__()
381
+ assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
382
+ assert in_channels % patch_size == 0, assert_message
383
+ self.patch_size = patch_size
384
+
385
+ self.block = ResnetBlock1d(
386
+ in_channels=in_channels // patch_size,
387
+ out_channels=out_channels,
388
+ num_groups=1,
389
+ context_mapping_features=context_mapping_features,
390
+ use_snake=use_snake
391
+ )
392
+
393
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
394
+ x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
395
+ x = self.block(x, mapping, causal=causal)
396
+ return x
397
+
398
+
399
+ """
400
+ Attention Components
401
+ """
402
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
403
+ mid_features = features * multiplier
404
+ return nn.Sequential(
405
+ nn.Linear(in_features=features, out_features=mid_features),
406
+ nn.GELU(),
407
+ nn.Linear(in_features=mid_features, out_features=features),
408
+ )
409
+
410
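+ # Broadcast the boolean mask over heads to (b, 1, n, m) and fill masked-out positions with a large negative value before the softmax.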
+ def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
411
+ b, ndim = sim.shape[0], mask.ndim
412
+ if ndim == 3:
413
+ mask = rearrange(mask, "b n m -> b 1 n m")
414
+ if ndim == 2:
415
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
416
+ max_neg_value = -torch.finfo(sim.dtype).max
417
+ sim = sim.masked_fill(~mask, max_neg_value)
418
+ return sim
419
+
420
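+ # Build a boolean mask that lets each query position attend only to itself and earlier key positions.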
+ def causal_mask(q: Tensor, k: Tensor) -> Tensor:
421
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
422
+ mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
423
+ mask = repeat(mask, "n m -> b n m", b=b)
424
+ return mask
425
+
426
+ class AttentionBase(nn.Module):
427
+ def __init__(
428
+ self,
429
+ features: int,
430
+ *,
431
+ head_features: int,
432
+ num_heads: int,
433
+ out_features: Optional[int] = None,
434
+ ):
435
+ super().__init__()
436
+ self.scale = head_features**-0.5
437
+ self.num_heads = num_heads
438
+ mid_features = head_features * num_heads
439
+ out_features = default(out_features, features)
440
+
441
+ self.to_out = nn.Linear(
442
+ in_features=mid_features, out_features=out_features
443
+ )
444
+
445
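+ # Use PyTorch 2.0's fused scaled_dot_product_attention when CUDA is available; otherwise the manual einsum path below is used.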
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
446
+
447
+ if not self.use_flash:
448
+ return
449
+
450
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
451
+
452
+ if device_properties.major == 8 and device_properties.minor == 0:
453
+ # Use flash attention for A100 GPUs
454
+ self.sdp_kernel_config = (True, False, False)
455
+ else:
456
+ # Don't use flash attention for other GPUs
457
+ self.sdp_kernel_config = (False, True, True)
458
+
459
+ def forward(
460
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
461
+ ) -> Tensor:
462
+ # Split heads
463
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
464
+
465
+ if not self.use_flash:
466
+ if is_causal and mask is None:
467
+ # Mask out future tokens for causal attention
468
+ mask = causal_mask(q, k)
469
+
470
+ # Compute similarity matrix and add eventual mask
471
+ sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
472
+ sim = add_mask(sim, mask) if exists(mask) else sim
473
+
474
+ # Get attention matrix with softmax
475
+ attn = sim.softmax(dim=-1, dtype=torch.float32)
476
+
477
+ # Compute values
478
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
479
+ else:
480
+ with sdp_kernel(*self.sdp_kernel_config):
481
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
482
+
483
+ out = rearrange(out, "b h n d -> b n (h d)")
484
+ return self.to_out(out)
485
+
486
+ class Attention(nn.Module):
487
+ def __init__(
488
+ self,
489
+ features: int,
490
+ *,
491
+ head_features: int,
492
+ num_heads: int,
493
+ out_features: Optional[int] = None,
494
+ context_features: Optional[int] = None,
495
+ causal: bool = False,
496
+ ):
497
+ super().__init__()
498
+ self.context_features = context_features
499
+ self.causal = causal
500
+ mid_features = head_features * num_heads
501
+ context_features = default(context_features, features)
502
+
503
+ self.norm = nn.LayerNorm(features)
504
+ self.norm_context = nn.LayerNorm(context_features)
505
+ self.to_q = nn.Linear(
506
+ in_features=features, out_features=mid_features, bias=False
507
+ )
508
+ self.to_kv = nn.Linear(
509
+ in_features=context_features, out_features=mid_features * 2, bias=False
510
+ )
511
+ self.attention = AttentionBase(
512
+ features,
513
+ num_heads=num_heads,
514
+ head_features=head_features,
515
+ out_features=out_features,
516
+ )
517
+
518
+ def forward(
519
+ self,
520
+ x: Tensor, # [b, n, c]
521
+ context: Optional[Tensor] = None, # [b, m, d]
522
+ context_mask: Optional[Tensor] = None, # [b, m], false is masked,
523
+ causal: Optional[bool] = False,
524
+ ) -> Tensor:
525
+ assert_message = "You must provide a context when using context_features"
526
+ assert not self.context_features or exists(context), assert_message
527
+ # Use context if provided
528
+ context = default(context, x)
529
+ # Normalize then compute q from input and k,v from context
530
+ x, context = self.norm(x), self.norm_context(context)
531
+
532
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
533
+
534
+ if exists(context_mask):
535
+ # Mask out cross-attention for padding tokens
536
+ mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
537
+ k, v = k * mask, v * mask
538
+
539
+ # Compute and return attention
540
+ return self.attention(q, k, v, is_causal=self.causal or causal)
541
+
542
+
551
+ """
552
+ Transformer Blocks
553
+ """
554
+
555
+
556
+ class TransformerBlock(nn.Module):
557
+ def __init__(
558
+ self,
559
+ features: int,
560
+ num_heads: int,
561
+ head_features: int,
562
+ multiplier: int,
563
+ context_features: Optional[int] = None,
564
+ ):
565
+ super().__init__()
566
+
567
+ self.use_cross_attention = exists(context_features) and context_features > 0
568
+
569
+ self.attention = Attention(
570
+ features=features,
571
+ num_heads=num_heads,
572
+ head_features=head_features
573
+ )
574
+
575
+ if self.use_cross_attention:
576
+ self.cross_attention = Attention(
577
+ features=features,
578
+ num_heads=num_heads,
579
+ head_features=head_features,
580
+ context_features=context_features
581
+ )
582
+
583
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
584
+
585
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
586
+ x = self.attention(x, causal=causal) + x
587
+ if self.use_cross_attention:
588
+ x = self.cross_attention(x, context=context, context_mask=context_mask) + x
589
+ x = self.feed_forward(x) + x
590
+ return x
591
+
592
+
593
+ """
594
+ Transformers
595
+ """
596
+
597
+
598
+ class Transformer1d(nn.Module):
599
+ def __init__(
600
+ self,
601
+ num_layers: int,
602
+ channels: int,
603
+ num_heads: int,
604
+ head_features: int,
605
+ multiplier: int,
606
+ context_features: Optional[int] = None,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.to_in = nn.Sequential(
611
+ nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
612
+ Conv1d(
613
+ in_channels=channels,
614
+ out_channels=channels,
615
+ kernel_size=1,
616
+ ),
617
+ Rearrange("b c t -> b t c"),
618
+ )
619
+
620
+ self.blocks = nn.ModuleList(
621
+ [
622
+ TransformerBlock(
623
+ features=channels,
624
+ head_features=head_features,
625
+ num_heads=num_heads,
626
+ multiplier=multiplier,
627
+ context_features=context_features,
628
+ )
629
+ for i in range(num_layers)
630
+ ]
631
+ )
632
+
633
+ self.to_out = nn.Sequential(
634
+ Rearrange("b t c -> b c t"),
635
+ Conv1d(
636
+ in_channels=channels,
637
+ out_channels=channels,
638
+ kernel_size=1,
639
+ ),
640
+ )
641
+
642
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
643
+ x = self.to_in(x)
644
+ for block in self.blocks:
645
+ x = block(x, context=context, context_mask=context_mask, causal=causal)
646
+ x = self.to_out(x)
647
+ return x
648
+
649
+
650
+ """
651
+ Time Embeddings
652
+ """
653
+
654
+
655
+ class SinusoidalEmbedding(nn.Module):
656
+ def __init__(self, dim: int):
657
+ super().__init__()
658
+ self.dim = dim
659
+
660
+ def forward(self, x: Tensor) -> Tensor:
661
+ device, half_dim = x.device, self.dim // 2
662
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
663
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
664
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
665
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
666
+
667
+
668
+ class LearnedPositionalEmbedding(nn.Module):
669
+ """Used for continuous time"""
670
+
671
+ def __init__(self, dim: int):
672
+ super().__init__()
673
+ assert (dim % 2) == 0
674
+ half_dim = dim // 2
675
+ self.weights = nn.Parameter(torch.randn(half_dim))
676
+
677
+ def forward(self, x: Tensor) -> Tensor:
678
+ x = rearrange(x, "b -> b 1")
679
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
680
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
681
+ fouriered = torch.cat((x, fouriered), dim=-1)
682
+ return fouriered
683
+
684
+
685
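+ # Continuous-time embedding: learned Fourier features (plus the raw input) projected to out_features.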
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
686
+ return nn.Sequential(
687
+ LearnedPositionalEmbedding(dim),
688
+ nn.Linear(in_features=dim + 1, out_features=out_features),
689
+ )
690
+
691
+
692
+ """
693
+ Encoder/Decoder Components
694
+ """
695
+
696
+
697
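+ # UNet encoder stage: strided downsampling, a stack of ResNet blocks, and an optional transformer; returns skip activations when use_skip is set.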
+ class DownsampleBlock1d(nn.Module):
698
+ def __init__(
699
+ self,
700
+ in_channels: int,
701
+ out_channels: int,
702
+ *,
703
+ factor: int,
704
+ num_groups: int,
705
+ num_layers: int,
706
+ kernel_multiplier: int = 2,
707
+ use_pre_downsample: bool = True,
708
+ use_skip: bool = False,
709
+ use_snake: bool = False,
710
+ extract_channels: int = 0,
711
+ context_channels: int = 0,
712
+ num_transformer_blocks: int = 0,
713
+ attention_heads: Optional[int] = None,
714
+ attention_features: Optional[int] = None,
715
+ attention_multiplier: Optional[int] = None,
716
+ context_mapping_features: Optional[int] = None,
717
+ context_embedding_features: Optional[int] = None,
718
+ ):
719
+ super().__init__()
720
+ self.use_pre_downsample = use_pre_downsample
721
+ self.use_skip = use_skip
722
+ self.use_transformer = num_transformer_blocks > 0
723
+ self.use_extract = extract_channels > 0
724
+ self.use_context = context_channels > 0
725
+
726
+ channels = out_channels if use_pre_downsample else in_channels
727
+
728
+ self.downsample = Downsample1d(
729
+ in_channels=in_channels,
730
+ out_channels=out_channels,
731
+ factor=factor,
732
+ kernel_multiplier=kernel_multiplier,
733
+ )
734
+
735
+ self.blocks = nn.ModuleList(
736
+ [
737
+ ResnetBlock1d(
738
+ in_channels=channels + context_channels if i == 0 else channels,
739
+ out_channels=channels,
740
+ num_groups=num_groups,
741
+ context_mapping_features=context_mapping_features,
742
+ use_snake=use_snake
743
+ )
744
+ for i in range(num_layers)
745
+ ]
746
+ )
747
+
748
+ if self.use_transformer:
749
+ assert (
750
+ (exists(attention_heads) or exists(attention_features))
751
+ and exists(attention_multiplier)
752
+ )
753
+
754
+ if attention_features is None and attention_heads is not None:
755
+ attention_features = channels // attention_heads
756
+
757
+ if attention_heads is None and attention_features is not None:
758
+ attention_heads = channels // attention_features
759
+
760
+ self.transformer = Transformer1d(
761
+ num_layers=num_transformer_blocks,
762
+ channels=channels,
763
+ num_heads=attention_heads,
764
+ head_features=attention_features,
765
+ multiplier=attention_multiplier,
766
+ context_features=context_embedding_features
767
+ )
768
+
769
+ if self.use_extract:
770
+ num_extract_groups = min(num_groups, extract_channels)
771
+ self.to_extracted = ResnetBlock1d(
772
+ in_channels=out_channels,
773
+ out_channels=extract_channels,
774
+ num_groups=num_extract_groups,
775
+ use_snake=use_snake
776
+ )
777
+
778
+ def forward(
779
+ self,
780
+ x: Tensor,
781
+ *,
782
+ mapping: Optional[Tensor] = None,
783
+ channels: Optional[Tensor] = None,
784
+ embedding: Optional[Tensor] = None,
785
+ embedding_mask: Optional[Tensor] = None,
786
+ causal: Optional[bool] = False
787
+ ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
788
+
789
+ if self.use_pre_downsample:
790
+ x = self.downsample(x)
791
+
792
+ if self.use_context and exists(channels):
793
+ x = torch.cat([x, channels], dim=1)
794
+
795
+ skips = []
796
+ for block in self.blocks:
797
+ x = block(x, mapping=mapping, causal=causal)
798
+ skips += [x] if self.use_skip else []
799
+
800
+ if self.use_transformer:
801
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
802
+ skips += [x] if self.use_skip else []
803
+
804
+ if not self.use_pre_downsample:
805
+ x = self.downsample(x)
806
+
807
+ if self.use_extract:
808
+ extracted = self.to_extracted(x)
809
+ return x, extracted
810
+
811
+ return (x, skips) if self.use_skip else x
812
+
813
+
814
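+ # UNet decoder stage: ResNet blocks that concatenate encoder skips, an optional transformer, then upsampling back to the previous resolution.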
+ class UpsampleBlock1d(nn.Module):
815
+ def __init__(
816
+ self,
817
+ in_channels: int,
818
+ out_channels: int,
819
+ *,
820
+ factor: int,
821
+ num_layers: int,
822
+ num_groups: int,
823
+ use_nearest: bool = False,
824
+ use_pre_upsample: bool = False,
825
+ use_skip: bool = False,
826
+ use_snake: bool = False,
827
+ skip_channels: int = 0,
828
+ use_skip_scale: bool = False,
829
+ extract_channels: int = 0,
830
+ num_transformer_blocks: int = 0,
831
+ attention_heads: Optional[int] = None,
832
+ attention_features: Optional[int] = None,
833
+ attention_multiplier: Optional[int] = None,
834
+ context_mapping_features: Optional[int] = None,
835
+ context_embedding_features: Optional[int] = None,
836
+ ):
837
+ super().__init__()
838
+
839
+ self.use_extract = extract_channels > 0
840
+ self.use_pre_upsample = use_pre_upsample
841
+ self.use_transformer = num_transformer_blocks > 0
842
+ self.use_skip = use_skip
843
+ self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
844
+
845
+ channels = out_channels if use_pre_upsample else in_channels
846
+
847
+ self.blocks = nn.ModuleList(
848
+ [
849
+ ResnetBlock1d(
850
+ in_channels=channels + skip_channels,
851
+ out_channels=channels,
852
+ num_groups=num_groups,
853
+ context_mapping_features=context_mapping_features,
854
+ use_snake=use_snake
855
+ )
856
+ for _ in range(num_layers)
857
+ ]
858
+ )
859
+
860
+ if self.use_transformer:
861
+ assert (
862
+ (exists(attention_heads) or exists(attention_features))
863
+ and exists(attention_multiplier)
864
+ )
865
+
866
+ if attention_features is None and attention_heads is not None:
867
+ attention_features = channels // attention_heads
868
+
869
+ if attention_heads is None and attention_features is not None:
870
+ attention_heads = channels // attention_features
871
+
872
+ self.transformer = Transformer1d(
873
+ num_layers=num_transformer_blocks,
874
+ channels=channels,
875
+ num_heads=attention_heads,
876
+ head_features=attention_features,
877
+ multiplier=attention_multiplier,
878
+ context_features=context_embedding_features,
879
+ )
880
+
881
+ self.upsample = Upsample1d(
882
+ in_channels=in_channels,
883
+ out_channels=out_channels,
884
+ factor=factor,
885
+ use_nearest=use_nearest,
886
+ )
887
+
888
+ if self.use_extract:
889
+ num_extract_groups = min(num_groups, extract_channels)
890
+ self.to_extracted = ResnetBlock1d(
891
+ in_channels=out_channels,
892
+ out_channels=extract_channels,
893
+ num_groups=num_extract_groups,
894
+ use_snake=use_snake
895
+ )
896
+
897
+ def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
898
+ return torch.cat([x, skip * self.skip_scale], dim=1)
899
+
900
+ def forward(
901
+ self,
902
+ x: Tensor,
903
+ *,
904
+ skips: Optional[List[Tensor]] = None,
905
+ mapping: Optional[Tensor] = None,
906
+ embedding: Optional[Tensor] = None,
907
+ embedding_mask: Optional[Tensor] = None,
908
+ causal: Optional[bool] = False
909
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
910
+
911
+ if self.use_pre_upsample:
912
+ x = self.upsample(x)
913
+
914
+ for block in self.blocks:
915
+ x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
916
+ x = block(x, mapping=mapping, causal=causal)
917
+
918
+ if self.use_transformer:
919
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
920
+
921
+ if not self.use_pre_upsample:
922
+ x = self.upsample(x)
923
+
924
+ if self.use_extract:
925
+ extracted = self.to_extracted(x)
926
+ return x, extracted
927
+
928
+ return x
929
+
930
+
931
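+ # Lowest-resolution bottleneck: ResNet block, optional transformer, then another ResNet block.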
+ class BottleneckBlock1d(nn.Module):
932
+ def __init__(
933
+ self,
934
+ channels: int,
935
+ *,
936
+ num_groups: int,
937
+ num_transformer_blocks: int = 0,
938
+ attention_heads: Optional[int] = None,
939
+ attention_features: Optional[int] = None,
940
+ attention_multiplier: Optional[int] = None,
941
+ context_mapping_features: Optional[int] = None,
942
+ context_embedding_features: Optional[int] = None,
943
+ use_snake: bool = False,
944
+ ):
945
+ super().__init__()
946
+ self.use_transformer = num_transformer_blocks > 0
947
+
948
+ self.pre_block = ResnetBlock1d(
949
+ in_channels=channels,
950
+ out_channels=channels,
951
+ num_groups=num_groups,
952
+ context_mapping_features=context_mapping_features,
953
+ use_snake=use_snake
954
+ )
955
+
956
+ if self.use_transformer:
957
+ assert (
958
+ (exists(attention_heads) or exists(attention_features))
959
+ and exists(attention_multiplier)
960
+ )
961
+
962
+ if attention_features is None and attention_heads is not None:
963
+ attention_features = channels // attention_heads
964
+
965
+ if attention_heads is None and attention_features is not None:
966
+ attention_heads = channels // attention_features
967
+
968
+ self.transformer = Transformer1d(
969
+ num_layers=num_transformer_blocks,
970
+ channels=channels,
971
+ num_heads=attention_heads,
972
+ head_features=attention_features,
973
+ multiplier=attention_multiplier,
974
+ context_features=context_embedding_features,
975
+ )
976
+
977
+ self.post_block = ResnetBlock1d(
978
+ in_channels=channels,
979
+ out_channels=channels,
980
+ num_groups=num_groups,
981
+ context_mapping_features=context_mapping_features,
982
+ use_snake=use_snake
983
+ )
984
+
985
+ def forward(
986
+ self,
987
+ x: Tensor,
988
+ *,
989
+ mapping: Optional[Tensor] = None,
990
+ embedding: Optional[Tensor] = None,
991
+ embedding_mask: Optional[Tensor] = None,
992
+ causal: Optional[bool] = False
993
+ ) -> Tensor:
994
+ x = self.pre_block(x, mapping=mapping, causal=causal)
995
+ if self.use_transformer:
996
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
997
+ x = self.post_block(x, mapping=mapping, causal=causal)
998
+ return x
999
+
1000
+
1001
+ """
1002
+ UNet
1003
+ """
1004
+
1005
+
1006
+ class UNet1d(nn.Module):
1007
+ def __init__(
1008
+ self,
1009
+ in_channels: int,
1010
+ channels: int,
1011
+ multipliers: Sequence[int],
1012
+ factors: Sequence[int],
1013
+ num_blocks: Sequence[int],
1014
+ attentions: Sequence[int],
1015
+ patch_size: int = 1,
1016
+ resnet_groups: int = 8,
1017
+ use_context_time: bool = True,
1018
+ kernel_multiplier_downsample: int = 2,
1019
+ use_nearest_upsample: bool = False,
1020
+ use_skip_scale: bool = True,
1021
+ use_snake: bool = False,
1022
+ use_stft: bool = False,
1023
+ use_stft_context: bool = False,
1024
+ out_channels: Optional[int] = None,
1025
+ context_features: Optional[int] = None,
1026
+ context_features_multiplier: int = 4,
1027
+ context_channels: Optional[Sequence[int]] = None,
1028
+ context_embedding_features: Optional[int] = None,
1029
+ **kwargs,
1030
+ ):
1031
+ super().__init__()
1032
+ out_channels = default(out_channels, in_channels)
1033
+ context_channels = list(default(context_channels, []))
1034
+ num_layers = len(multipliers) - 1
1035
+ use_context_features = exists(context_features)
1036
+ use_context_channels = len(context_channels) > 0
1037
+ context_mapping_features = None
1038
+
1039
+ attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
1040
+
1041
+ self.num_layers = num_layers
1042
+ self.use_context_time = use_context_time
1043
+ self.use_context_features = use_context_features
1044
+ self.use_context_channels = use_context_channels
1045
+ self.use_stft = use_stft
1046
+ self.use_stft_context = use_stft_context
1047
+
1048
+ self.context_features = context_features
1049
+ context_channels_pad_length = num_layers + 1 - len(context_channels)
1050
+ context_channels = context_channels + [0] * context_channels_pad_length
1051
+ self.context_channels = context_channels
1052
+ self.context_embedding_features = context_embedding_features
1053
+
1054
+ if use_context_channels:
1055
+ has_context = [c > 0 for c in context_channels]
1056
+ self.has_context = has_context
1057
+ self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
1058
+
1059
+ assert (
1060
+ len(factors) == num_layers
1061
+ and len(attentions) >= num_layers
1062
+ and len(num_blocks) == num_layers
1063
+ )
1064
+
1065
+ if use_context_time or use_context_features:
1066
+ context_mapping_features = channels * context_features_multiplier
1067
+
1068
+ self.to_mapping = nn.Sequential(
1069
+ nn.Linear(context_mapping_features, context_mapping_features),
1070
+ nn.GELU(),
1071
+ nn.Linear(context_mapping_features, context_mapping_features),
1072
+ nn.GELU(),
1073
+ )
1074
+
1075
+ if use_context_time:
1076
+ assert exists(context_mapping_features)
1077
+ self.to_time = nn.Sequential(
1078
+ TimePositionalEmbedding(
1079
+ dim=channels, out_features=context_mapping_features
1080
+ ),
1081
+ nn.GELU(),
1082
+ )
1083
+
1084
+ if use_context_features:
1085
+ assert exists(context_features) and exists(context_mapping_features)
1086
+ self.to_features = nn.Sequential(
1087
+ nn.Linear(
1088
+ in_features=context_features, out_features=context_mapping_features
1089
+ ),
1090
+ nn.GELU(),
1091
+ )
1092
+
1093
+ if use_stft:
1094
+ stft_kwargs, kwargs = groupby("stft_", kwargs)
1095
+ assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1096
+ stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097
+ in_channels *= stft_channels
1098
+ out_channels *= stft_channels
1099
+ context_channels[0] *= stft_channels if use_stft_context else 1
1100
+ assert exists(in_channels) and exists(out_channels)
1101
+ self.stft = STFT(**stft_kwargs)
1102
+
1103
+ assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1104
+
1105
+ self.to_in = Patcher(
1106
+ in_channels=in_channels + context_channels[0],
1107
+ out_channels=channels * multipliers[0],
1108
+ patch_size=patch_size,
1109
+ context_mapping_features=context_mapping_features,
1110
+ use_snake=use_snake
1111
+ )
1112
+
1113
+ self.downsamples = nn.ModuleList(
1114
+ [
1115
+ DownsampleBlock1d(
1116
+ in_channels=channels * multipliers[i],
1117
+ out_channels=channels * multipliers[i + 1],
1118
+ context_mapping_features=context_mapping_features,
1119
+ context_channels=context_channels[i + 1],
1120
+ context_embedding_features=context_embedding_features,
1121
+ num_layers=num_blocks[i],
1122
+ factor=factors[i],
1123
+ kernel_multiplier=kernel_multiplier_downsample,
1124
+ num_groups=resnet_groups,
1125
+ use_pre_downsample=True,
1126
+ use_skip=True,
1127
+ use_snake=use_snake,
1128
+ num_transformer_blocks=attentions[i],
1129
+ **attention_kwargs,
1130
+ )
1131
+ for i in range(num_layers)
1132
+ ]
1133
+ )
1134
+
1135
+ self.bottleneck = BottleneckBlock1d(
1136
+ channels=channels * multipliers[-1],
1137
+ context_mapping_features=context_mapping_features,
1138
+ context_embedding_features=context_embedding_features,
1139
+ num_groups=resnet_groups,
1140
+ num_transformer_blocks=attentions[-1],
1141
+ use_snake=use_snake,
1142
+ **attention_kwargs,
1143
+ )
1144
+
1145
+ self.upsamples = nn.ModuleList(
1146
+ [
1147
+ UpsampleBlock1d(
1148
+ in_channels=channels * multipliers[i + 1],
1149
+ out_channels=channels * multipliers[i],
1150
+ context_mapping_features=context_mapping_features,
1151
+ context_embedding_features=context_embedding_features,
1152
+ num_layers=num_blocks[i] + (1 if attentions[i] else 0),
1153
+ factor=factors[i],
1154
+ use_nearest=use_nearest_upsample,
1155
+ num_groups=resnet_groups,
1156
+ use_skip_scale=use_skip_scale,
1157
+ use_pre_upsample=False,
1158
+ use_skip=True,
1159
+ use_snake=use_snake,
1160
+ skip_channels=channels * multipliers[i + 1],
1161
+ num_transformer_blocks=attentions[i],
1162
+ **attention_kwargs,
1163
+ )
1164
+ for i in reversed(range(num_layers))
1165
+ ]
1166
+ )
1167
+
1168
+ self.to_out = Unpatcher(
1169
+ in_channels=channels * multipliers[0],
1170
+ out_channels=out_channels,
1171
+ patch_size=patch_size,
1172
+ context_mapping_features=context_mapping_features,
1173
+ use_snake=use_snake
1174
+ )
1175
+
1176
+ def get_channels(
1177
+ self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
1178
+ ) -> Optional[Tensor]:
1179
+ """Gets context channels at `layer` and checks that shape is correct"""
1180
+ use_context_channels = self.use_context_channels and self.has_context[layer]
1181
+ if not use_context_channels:
1182
+ return None
1183
+ assert exists(channels_list), "Missing context"
1184
+ # Get channels index (skipping zero channel contexts)
1185
+ channels_id = self.channels_ids[layer]
1186
+ # Get channels
1187
+ channels = channels_list[channels_id]
1188
+ message = f"Missing context for layer {layer} at index {channels_id}"
1189
+ assert exists(channels), message
1190
+ # Check channels
1191
+ num_channels = self.context_channels[layer]
1192
+ message = f"Expected context with {num_channels} channels at idx {channels_id}"
1193
+ assert channels.shape[1] == num_channels, message
1194
+ # STFT channels if requested
1195
+ channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1196
+ return channels
1197
+
1198
+ def get_mapping(
1199
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
1200
+ ) -> Optional[Tensor]:
1201
+ """Combines context time features and features into mapping"""
1202
+ items, mapping = [], None
1203
+ # Compute time features
1204
+ if self.use_context_time:
1205
+ assert_message = "use_context_time=True but no time features provided"
1206
+ assert exists(time), assert_message
1207
+ items += [self.to_time(time)]
1208
+ # Compute features
1209
+ if self.use_context_features:
1210
+ assert_message = "context_features exists but no features provided"
1211
+ assert exists(features), assert_message
1212
+ items += [self.to_features(features)]
1213
+ # Compute joint mapping
1214
+ if self.use_context_time or self.use_context_features:
1215
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
1216
+ mapping = self.to_mapping(mapping)
1217
+ return mapping
1218
+
1219
+ def forward(
1220
+ self,
1221
+ x: Tensor,
1222
+ time: Optional[Tensor] = None,
1223
+ *,
1224
+ features: Optional[Tensor] = None,
1225
+ channels_list: Optional[Sequence[Tensor]] = None,
1226
+ embedding: Optional[Tensor] = None,
1227
+ embedding_mask: Optional[Tensor] = None,
1228
+ causal: Optional[bool] = False,
1229
+ ) -> Tensor:
1230
+ channels = self.get_channels(channels_list, layer=0)
1231
+ # Apply stft if required
1232
+ x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
1233
+ # Concat context channels at layer 0 if provided
1234
+ x = torch.cat([x, channels], dim=1) if exists(channels) else x
1235
+ # Compute mapping from time and features
1236
+ mapping = self.get_mapping(time, features)
1237
+ x = self.to_in(x, mapping, causal=causal)
1238
+ skips_list = [x]
1239
+
1240
+ for i, downsample in enumerate(self.downsamples):
1241
+ channels = self.get_channels(channels_list, layer=i + 1)
1242
+ x, skips = downsample(
1243
+ x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
1244
+ )
1245
+ skips_list += [skips]
1246
+
1247
+ x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1248
+
1249
+ for i, upsample in enumerate(self.upsamples):
1250
+ skips = skips_list.pop()
1251
+ x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1252
+
1253
+ x += skips_list.pop()
1254
+ x = self.to_out(x, mapping, causal=causal)
1255
+ x = self.stft.decode1d(x) if self.use_stft else x
1256
+
1257
+ return x
1258
+
1259
+
1260
+ """ Conditioning Modules """
1261
+
1262
+
1263
+ class FixedEmbedding(nn.Module):
1264
+ def __init__(self, max_length: int, features: int):
1265
+ super().__init__()
1266
+ self.max_length = max_length
1267
+ self.embedding = nn.Embedding(max_length, features)
1268
+
1269
+ def forward(self, x: Tensor) -> Tensor:
1270
+ batch_size, length, device = *x.shape[0:2], x.device
1271
+ assert_message = "Input sequence length must be <= max_length"
1272
+ assert length <= self.max_length, assert_message
1273
+ position = torch.arange(length, device=device)
1274
+ fixed_embedding = self.embedding(position)
1275
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1276
+ return fixed_embedding
1277
+
1278
+
1279
+ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
1280
+ if proba == 1:
1281
+ return torch.ones(shape, device=device, dtype=torch.bool)
1282
+ elif proba == 0:
1283
+ return torch.zeros(shape, device=device, dtype=torch.bool)
1284
+ else:
1285
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
1286
+
1287
+
1288
+ class UNetCFG1d(UNet1d):
1289
+
1290
+ """UNet1d with Classifier-Free Guidance"""
1291
+
1292
+ def __init__(
1293
+ self,
1294
+ context_embedding_max_length: int,
1295
+ context_embedding_features: int,
1296
+ use_xattn_time: bool = False,
1297
+ **kwargs,
1298
+ ):
1299
+ super().__init__(
1300
+ context_embedding_features=context_embedding_features, **kwargs
1301
+ )
1302
+
1303
+ self.use_xattn_time = use_xattn_time
1304
+
1305
+ if use_xattn_time:
1306
+ assert exists(context_embedding_features)
1307
+ self.to_time_embedding = nn.Sequential(
1308
+ TimePositionalEmbedding(
1309
+ dim=kwargs["channels"], out_features=context_embedding_features
1310
+ ),
1311
+ nn.GELU(),
1312
+ )
1313
+
1314
+ context_embedding_max_length += 1 # Add one for time embedding
1315
+
1316
+ self.fixed_embedding = FixedEmbedding(
1317
+ max_length=context_embedding_max_length, features=context_embedding_features
1318
+ )
1319
+
1320
+ def forward( # type: ignore
1321
+ self,
1322
+ x: Tensor,
1323
+ time: Tensor,
1324
+ *,
1325
+ embedding: Tensor,
1326
+ embedding_mask: Optional[Tensor] = None,
1327
+ embedding_scale: float = 1.0,
1328
+ embedding_mask_proba: float = 0.0,
1329
+ batch_cfg: bool = False,
1330
+ rescale_cfg: bool = False,
1331
+ scale_phi: float = 0.4,
1332
+ negative_embedding: Optional[Tensor] = None,
1333
+ negative_embedding_mask: Optional[Tensor] = None,
1334
+ **kwargs,
1335
+ ) -> Tensor:
1336
+ b, device = embedding.shape[0], embedding.device
1337
+
1338
+ if self.use_xattn_time:
1339
+ embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
1340
+
1341
+ if embedding_mask is not None:
1342
+ embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
1343
+
1344
+ fixed_embedding = self.fixed_embedding(embedding)
1345
+
1346
+ if embedding_mask_proba > 0.0:
1347
+ # Randomly mask embedding
1348
+ batch_mask = rand_bool(
1349
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
1350
+ )
1351
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
1352
+
1353
+ if embedding_scale != 1.0:
1354
+ if batch_cfg:
1355
+ batch_x = torch.cat([x, x], dim=0)
1356
+ batch_time = torch.cat([time, time], dim=0)
1357
+
1358
+ if negative_embedding is not None:
1359
+ if negative_embedding_mask is not None:
1360
+ negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
1361
+
1362
+ negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
1363
+
1364
+ batch_embed = torch.cat([embedding, negative_embedding], dim=0)
1365
+
1366
+ else:
1367
+ batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
1368
+
1369
+ batch_mask = None
1370
+ if embedding_mask is not None:
1371
+ batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
1372
+
1373
+ batch_features = None
1374
+ features = kwargs.pop("features", None)
1375
+ if self.use_context_features:
1376
+ batch_features = torch.cat([features, features], dim=0)
1377
+
1378
+ batch_channels = None
1379
+ channels_list = kwargs.pop("channels_list", None)
1380
+ if self.use_context_channels:
1381
+ batch_channels = []
1382
+ for channels in channels_list:
1383
+ batch_channels += [torch.cat([channels, channels], dim=0)]
1384
+
1385
+ # Compute both normal and fixed embedding outputs
1386
+ batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
1387
+ out, out_masked = batch_out.chunk(2, dim=0)
1388
+
1389
+ else:
1390
+ # Compute both normal and fixed embedding outputs
1391
+ out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1392
+ out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
1393
+
1394
+ out_cfg = out_masked + (out - out_masked) * embedding_scale
1395
+
1396
+ if rescale_cfg:
1397
+
1398
+ out_std = out.std(dim=1, keepdim=True)
1399
+ out_cfg_std = out_cfg.std(dim=1, keepdim=True)
1400
+
1401
+ return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
1402
+
1403
+ else:
1404
+
1405
+ return out_cfg
1406
+
1407
+ else:
1408
+ return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1409
+
1410
+
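# Editorial sketch (illustrative tensors only) of the classifier-free guidance
# combination implemented in UNetCFG1d.forward above: `out` is the conditional
# output, `out_masked` the output obtained with the fixed (unconditional)
# embedding. Both tensors and the scale values below are assumptions.
import torch
out, out_masked = torch.randn(2, 2, 1024), torch.randn(2, 2, 1024)
embedding_scale, scale_phi = 3.0, 0.4
out_cfg = out_masked + (out - out_masked) * embedding_scale
# Optional rescaling keeps the guided output's per-channel std close to the
# conditional output's std, mitigating over-saturation at high guidance scales.
out_std, out_cfg_std = out.std(dim=1, keepdim=True), out_cfg.std(dim=1, keepdim=True)
out_rescaled = scale_phi * (out_cfg * (out_std / out_cfg_std)) + (1 - scale_phi) * out_cfg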
1411
+ class UNetNCCA1d(UNet1d):
1412
+
1413
+ """UNet1d with Noise Channel Conditioning Augmentation"""
1414
+
1415
+ def __init__(self, context_features: int, **kwargs):
1416
+ super().__init__(context_features=context_features, **kwargs)
1417
+ self.embedder = NumberEmbedder(features=context_features)
1418
+
1419
+ def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
1420
+ x = x if torch.is_tensor(x) else torch.tensor(x)
1421
+ return x.expand(shape)
1422
+
1423
+ def forward( # type: ignore
1424
+ self,
1425
+ x: Tensor,
1426
+ time: Tensor,
1427
+ *,
1428
+ channels_list: Sequence[Tensor],
1429
+ channels_augmentation: Union[
1430
+ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
1431
+ ] = False,
1432
+ channels_scale: Union[
1433
+ float, Sequence[float], Sequence[Sequence[float]], Tensor
1434
+ ] = 0,
1435
+ **kwargs,
1436
+ ) -> Tensor:
1437
+ b, n = x.shape[0], len(channels_list)
1438
+ channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
1439
+ channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
1440
+
1441
+ # Augmentation (for each channel list item)
1442
+ for i in range(n):
1443
+ scale = channels_scale[:, i] * channels_augmentation[:, i]
1444
+ scale = rearrange(scale, "b -> b 1 1")
1445
+ item = channels_list[i]
1446
+ channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
1447
+
1448
+ # Scale embedding (sum reduction if more than one channel list item)
1449
+ channels_scale_emb = self.embedder(channels_scale)
1450
+ channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
1451
+
1452
+ return super().forward(
1453
+ x=x,
1454
+ time=time,
1455
+ channels_list=channels_list,
1456
+ features=channels_scale_emb,
1457
+ **kwargs,
1458
+ )
1459
+
1460
+
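# Editorial sketch (hypothetical tensors) of the per-channel noise augmentation
# used by UNetNCCA1d above: each conditioning signal is linearly mixed with
# Gaussian noise by `scale` (0 keeps the clean signal, 1 replaces it with noise).
import torch
item = torch.randn(4, 2, 1024)                              # (batch, channels, time)
scale = torch.tensor([0.0, 0.25, 0.5, 1.0]).view(-1, 1, 1)  # per-example noise level
augmented = torch.randn_like(item) * scale + item * (1 - scale)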
1461
+ class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1462
+ def __init__(self, *args, **kwargs):
1463
+ super().__init__(*args, **kwargs)
1464
+
1465
+ def forward(self, *args, **kwargs): # type: ignore
1466
+ return UNetCFG1d.forward(self, *args, **kwargs)
1467
+
1468
+
1469
+ def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1470
+ if type == "base":
1471
+ return UNet1d(**kwargs)
1472
+ elif type == "all":
1473
+ return UNetAll1d(**kwargs)
1474
+ elif type == "cfg":
1475
+ return UNetCFG1d(**kwargs)
1476
+ elif type == "ncca":
1477
+ return UNetNCCA1d(**kwargs)
1478
+ else:
1479
+ raise ValueError(f"Unknown XUNet1d type: {type}")
1480
+
1481
+ class NumberEmbedder(nn.Module):
1482
+ def __init__(
1483
+ self,
1484
+ features: int,
1485
+ dim: int = 256,
1486
+ ):
1487
+ super().__init__()
1488
+ self.features = features
1489
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1490
+
1491
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1492
+ if not torch.is_tensor(x):
1493
+ device = next(self.embedding.parameters()).device
1494
+ x = torch.tensor(x, device=device)
1495
+ assert isinstance(x, Tensor)
1496
+ shape = x.shape
1497
+ x = rearrange(x, "... -> (...)")
1498
+ embedding = self.embedding(x)
1499
+ x = embedding.view(*shape, self.features)
1500
+ return x # type: ignore
1501
+
1502
+
1503
+ """
1504
+ Audio Transforms
1505
+ """
1506
+
1507
+
1508
+ class STFT(nn.Module):
1509
+ """Helper for torch stft and istft"""
1510
+
1511
+ def __init__(
1512
+ self,
1513
+ num_fft: int = 1023,
1514
+ hop_length: int = 256,
1515
+ window_length: Optional[int] = None,
1516
+ length: Optional[int] = None,
1517
+ use_complex: bool = False,
1518
+ ):
1519
+ super().__init__()
1520
+ self.num_fft = num_fft
1521
+ self.hop_length = default(hop_length, floor(num_fft // 4))
1522
+ self.window_length = default(window_length, num_fft)
1523
+ self.length = length
1524
+ self.register_buffer("window", torch.hann_window(self.window_length))
1525
+ self.use_complex = use_complex
1526
+
1527
+ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
1528
+ b = wave.shape[0]
1529
+ wave = rearrange(wave, "b c t -> (b c) t")
1530
+
1531
+ stft = torch.stft(
1532
+ wave,
1533
+ n_fft=self.num_fft,
1534
+ hop_length=self.hop_length,
1535
+ win_length=self.window_length,
1536
+ window=self.window, # type: ignore
1537
+ return_complex=True,
1538
+ normalized=True,
1539
+ )
1540
+
1541
+ if self.use_complex:
1542
+ # Returns real and imaginary
1543
+ stft_a, stft_b = stft.real, stft.imag
1544
+ else:
1545
+ # Returns magnitude and phase matrices
1546
+ magnitude, phase = torch.abs(stft), torch.angle(stft)
1547
+ stft_a, stft_b = magnitude, phase
1548
+
1549
+ return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
1550
+
1551
+ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1552
+ b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
1553
+ length = closest_power_2(l * self.hop_length)
1554
+
1555
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1556
+
1557
+ if self.use_complex:
1558
+ real, imag = stft_a, stft_b
1559
+ else:
1560
+ magnitude, phase = stft_a, stft_b
1561
+ real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1562
+
1563
+ stft = torch.stack([real, imag], dim=-1)
1564
+
1565
+ wave = torch.istft(
1566
+ stft,
1567
+ n_fft=self.num_fft,
1568
+ hop_length=self.hop_length,
1569
+ win_length=self.window_length,
1570
+ window=self.window, # type: ignore
1571
+ length=default(self.length, length),
1572
+ normalized=True,
1573
+ )
1574
+
1575
+ return rearrange(wave, "(b c) t -> b c t", b=b)
1576
+
1577
+ def encode1d(
1578
+ self, wave: Tensor, stacked: bool = True
1579
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1580
+ stft_a, stft_b = self.encode(wave)
1581
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1582
+ return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
1583
+
1584
+ def decode1d(self, stft_pair: Tensor) -> Tensor:
1585
+ f = self.num_fft // 2 + 1
1586
+ stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1587
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1588
+ return self.decode(stft_a, stft_b)
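# Editorial round-trip sketch for the STFT helper above: encode1d stacks
# magnitude and phase along the channel axis, giving (num_fft // 2 + 1) * 2
# frequency channels per audio channel, and decode1d inverts it. Shapes and
# parameter values here are illustrative assumptions.
import torch
stft = STFT(num_fft=1023, hop_length=256)
wave = torch.randn(1, 2, 32768)     # (batch, channels, samples)
spec = stft.encode1d(wave)          # (1, 2 * 512 * 2, frames)
recon = stft.decode1d(spec)         # back to (1, 2, 32768)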
stable_audio_tools/models/autoencoders.py ADDED
@@ -0,0 +1,794 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torchaudio import transforms as T
8
+ from alias_free_torch import Activation1d
9
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
10
+ from typing import Literal, Dict, Any
11
+
12
+ from ..inference.sampling import sample
13
+ from ..inference.utils import prepare_audio
14
+ from .blocks import SnakeBeta
15
+ from .bottleneck import Bottleneck, DiscreteBottleneck
16
+ from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
17
+ from .factory import create_pretransform_from_config, create_bottleneck_from_config
18
+ from .pretransforms import Pretransform
19
+
20
+ def checkpoint(function, *args, **kwargs):
21
+ kwargs.setdefault("use_reentrant", False)
22
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
23
+
24
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
25
+ if activation == "elu":
26
+ act = nn.ELU()
27
+ elif activation == "snake":
28
+ act = SnakeBeta(channels)
29
+ elif activation == "none":
30
+ act = nn.Identity()
31
+ else:
32
+ raise ValueError(f"Unknown activation {activation}")
33
+
34
+ if antialias:
35
+ act = Activation1d(act)
36
+
37
+ return act
38
+
39
+ class ResidualUnit(nn.Module):
40
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
41
+ super().__init__()
42
+
43
+ self.dilation = dilation
44
+
45
+ padding = (dilation * (7-1)) // 2
46
+
47
+ self.layers = nn.Sequential(
48
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
49
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
50
+ kernel_size=7, dilation=dilation, padding=padding),
51
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
52
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
53
+ kernel_size=1)
54
+ )
55
+
56
+ def forward(self, x):
57
+ res = x
58
+
59
+ #x = checkpoint(self.layers, x)
60
+ x = self.layers(x)
61
+
62
+ return x + res
63
+
64
+ class EncoderBlock(nn.Module):
65
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
66
+ super().__init__()
67
+
68
+ self.layers = nn.Sequential(
69
+ ResidualUnit(in_channels=in_channels,
70
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
71
+ ResidualUnit(in_channels=in_channels,
72
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
73
+ ResidualUnit(in_channels=in_channels,
74
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
75
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
76
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
77
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
78
+ )
79
+
80
+ def forward(self, x):
81
+ return self.layers(x)
82
+
83
+ class DecoderBlock(nn.Module):
84
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
85
+ super().__init__()
86
+
87
+ if use_nearest_upsample:
88
+ upsample_layer = nn.Sequential(
89
+ nn.Upsample(scale_factor=stride, mode="nearest"),
90
+ WNConv1d(in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ kernel_size=2*stride,
93
+ stride=1,
94
+ bias=False,
95
+ padding='same')
96
+ )
97
+ else:
98
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
99
+ out_channels=out_channels,
100
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
101
+
102
+ self.layers = nn.Sequential(
103
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
104
+ upsample_layer,
105
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
106
+ dilation=1, use_snake=use_snake),
107
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
108
+ dilation=3, use_snake=use_snake),
109
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
110
+ dilation=9, use_snake=use_snake),
111
+ )
112
+
113
+ def forward(self, x):
114
+ return self.layers(x)
115
+
116
+ class OobleckEncoder(nn.Module):
117
+ def __init__(self,
118
+ in_channels=2,
119
+ channels=128,
120
+ latent_dim=32,
121
+ c_mults = [1, 2, 4, 8],
122
+ strides = [2, 4, 8, 8],
123
+ use_snake=False,
124
+ antialias_activation=False
125
+ ):
126
+ super().__init__()
127
+
128
+ c_mults = [1] + c_mults
129
+
130
+ self.depth = len(c_mults)
131
+
132
+ layers = [
133
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
134
+ ]
135
+
136
+ for i in range(self.depth-1):
137
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
138
+
139
+ layers += [
140
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
141
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
142
+ ]
143
+
144
+ self.layers = nn.Sequential(*layers)
145
+
146
+ def forward(self, x):
147
+ return self.layers(x)
148
+
149
+
150
+ class OobleckDecoder(nn.Module):
151
+ def __init__(self,
152
+ out_channels=2,
153
+ channels=128,
154
+ latent_dim=32,
155
+ c_mults = [1, 2, 4, 8],
156
+ strides = [2, 4, 8, 8],
157
+ use_snake=False,
158
+ antialias_activation=False,
159
+ use_nearest_upsample=False,
160
+ final_tanh=True):
161
+ super().__init__()
162
+
163
+ c_mults = [1] + c_mults
164
+
165
+ self.depth = len(c_mults)
166
+
167
+ layers = [
168
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
169
+ ]
170
+
171
+ for i in range(self.depth-1, 0, -1):
172
+ layers += [DecoderBlock(
173
+ in_channels=c_mults[i]*channels,
174
+ out_channels=c_mults[i-1]*channels,
175
+ stride=strides[i-1],
176
+ use_snake=use_snake,
177
+ antialias_activation=antialias_activation,
178
+ use_nearest_upsample=use_nearest_upsample
179
+ )
180
+ ]
181
+
182
+ layers += [
183
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
184
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
185
+ nn.Tanh() if final_tanh else nn.Identity()
186
+ ]
187
+
188
+ self.layers = nn.Sequential(*layers)
189
+
190
+ def forward(self, x):
191
+ return self.layers(x)
192
+
193
+
194
+ class DACEncoderWrapper(nn.Module):
195
+ def __init__(self, in_channels=1, **kwargs):
196
+ super().__init__()
197
+
198
+ from dac.model.dac import Encoder as DACEncoder
199
+
200
+ latent_dim = kwargs.pop("latent_dim", None)
201
+
202
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
203
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
204
+ self.latent_dim = latent_dim
205
+
206
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
207
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
208
+
209
+ if in_channels != 1:
210
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
211
+
212
+ def forward(self, x):
213
+ x = self.encoder(x)
214
+ x = self.proj_out(x)
215
+ return x
216
+
217
+ class DACDecoderWrapper(nn.Module):
218
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
219
+ super().__init__()
220
+
221
+ from dac.model.dac import Decoder as DACDecoder
222
+
223
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
224
+
225
+ self.latent_dim = latent_dim
226
+
227
+ def forward(self, x):
228
+ return self.decoder(x)
229
+
230
+ class AudioAutoencoder(nn.Module):
231
+ def __init__(
232
+ self,
233
+ encoder,
234
+ decoder,
235
+ latent_dim,
236
+ downsampling_ratio,
237
+ sample_rate,
238
+ io_channels=2,
239
+ bottleneck: Bottleneck = None,
240
+ pretransform: Pretransform = None,
241
+ in_channels = None,
242
+ out_channels = None,
243
+ soft_clip = False
244
+ ):
245
+ super().__init__()
246
+
247
+ self.downsampling_ratio = downsampling_ratio
248
+ self.sample_rate = sample_rate
249
+
250
+ self.latent_dim = latent_dim
251
+ self.io_channels = io_channels
252
+ self.in_channels = io_channels
253
+ self.out_channels = io_channels
254
+
255
+ self.min_length = self.downsampling_ratio
256
+
257
+ if in_channels is not None:
258
+ self.in_channels = in_channels
259
+
260
+ if out_channels is not None:
261
+ self.out_channels = out_channels
262
+
263
+ self.bottleneck = bottleneck
264
+
265
+ self.encoder = encoder
266
+
267
+ self.decoder = decoder
268
+
269
+ self.pretransform = pretransform
270
+
271
+ self.soft_clip = soft_clip
272
+
273
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
274
+
275
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
276
+
277
+ info = {}
278
+
279
+ if self.pretransform is not None and not skip_pretransform:
280
+ if self.pretransform.enable_grad:
281
+ if iterate_batch:
282
+ audios = []
283
+ for i in range(audio.shape[0]):
284
+ audios.append(self.pretransform.encode(audio[i:i+1]))
285
+ audio = torch.cat(audios, dim=0)
286
+ else:
287
+ audio = self.pretransform.encode(audio)
288
+ else:
289
+ with torch.no_grad():
290
+ if iterate_batch:
291
+ audios = []
292
+ for i in range(audio.shape[0]):
293
+ audios.append(self.pretransform.encode(audio[i:i+1]))
294
+ audio = torch.cat(audios, dim=0)
295
+ else:
296
+ audio = self.pretransform.encode(audio)
297
+
298
+ if self.encoder is not None:
299
+ if iterate_batch:
300
+ latents = []
301
+ for i in range(audio.shape[0]):
302
+ latents.append(self.encoder(audio[i:i+1]))
303
+ latents = torch.cat(latents, dim=0)
304
+ else:
305
+ latents = self.encoder(audio)
306
+ else:
307
+ latents = audio
308
+
309
+ if self.bottleneck is not None:
310
+ # TODO: Add iterate batch logic, needs to merge the info dicts
311
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
312
+
313
+ info.update(bottleneck_info)
314
+
315
+ if return_info:
316
+ return latents, info
317
+
318
+ return latents
319
+
320
+ def decode(self, latents, iterate_batch=False, **kwargs):
321
+
322
+ if self.bottleneck is not None:
323
+ if iterate_batch:
324
+ decoded = []
325
+ for i in range(latents.shape[0]):
326
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
327
+ latents = torch.cat(decoded, dim=0)
328
+ else:
329
+ latents = self.bottleneck.decode(latents)
330
+
331
+ if iterate_batch:
332
+ decoded = []
333
+ for i in range(latents.shape[0]):
334
+ decoded.append(self.decoder(latents[i:i+1]))
335
+ decoded = torch.cat(decoded, dim=0)
336
+ else:
337
+ decoded = self.decoder(latents, **kwargs)
338
+
339
+ if self.pretransform is not None:
340
+ if self.pretransform.enable_grad:
341
+ if iterate_batch:
342
+ decodeds = []
343
+ for i in range(decoded.shape[0]):
344
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
345
+ decoded = torch.cat(decodeds, dim=0)
346
+ else:
347
+ decoded = self.pretransform.decode(decoded)
348
+ else:
349
+ with torch.no_grad():
350
+ if iterate_batch:
351
+ decodeds = []
352
+ for i in range(latents.shape[0]):
353
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
354
+ decoded = torch.cat(decodeds, dim=0)
355
+ else:
356
+ decoded = self.pretransform.decode(decoded)
357
+
358
+ if self.soft_clip:
359
+ decoded = torch.tanh(decoded)
360
+
361
+ return decoded
362
+
363
+ def decode_tokens(self, tokens, **kwargs):
364
+ '''
365
+ Decode discrete tokens to audio
366
+ Only works with discrete autoencoders
367
+ '''
368
+
369
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
370
+
371
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
372
+
373
+ return self.decode(latents, **kwargs)
374
+
375
+
376
+ def preprocess_audio_for_encoder(self, audio, in_sr):
377
+ '''
378
+ Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
379
+ If the model is mono, stereo audio will be converted to mono.
380
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
381
+ Audio will be resampled to the model's sample rate.
382
+ The output will have batch size 1 and be shape (1 x Channels x Length)
383
+ '''
384
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
385
+
386
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
387
+ '''
388
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
389
+ The audio in that list can be of different lengths and channels.
390
+ in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
391
+ All audio will be resampled to the model's sample rate.
392
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
393
+ If the model is mono, all audio will be converted to mono.
394
+ The output will be a tensor of shape (Batch x Channels x Length)
395
+ '''
396
+ batch_size = len(audio_list)
397
+ if isinstance(in_sr_list, int):
398
+ in_sr_list = [in_sr_list]*batch_size
399
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
400
+ new_audio = []
401
+ max_length = 0
402
+ # resample & find the max length
403
+ for i in range(batch_size):
404
+ audio = audio_list[i]
405
+ in_sr = in_sr_list[i]
406
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
407
+ # batchsize 1 was given by accident. Just squeeze it.
408
+ audio = audio.squeeze(0)
409
+ elif len(audio.shape) == 1:
410
+ # Mono signal, channel dimension is missing, unsqueeze it in
411
+ audio = audio.unsqueeze(0)
412
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
413
+ # Resample audio
414
+ if in_sr != self.sample_rate:
415
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
416
+ audio = resample_tf(audio)
417
+ new_audio.append(audio)
418
+ if audio.shape[-1] > max_length:
419
+ max_length = audio.shape[-1]
420
+ # Pad every audio to the same length, multiple of model's downsampling ratio
421
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
422
+ for i in range(batch_size):
423
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
424
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
425
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
426
+ # convert to tensor
427
+ return torch.stack(new_audio)
428
+
429
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
430
+ '''
431
+ Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
432
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
433
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
434
+ and therefore you could likely use the same values with decode_audio.
435
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
436
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
437
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
438
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
439
+ Smaller chunk_size uses less memory, but more compute.
440
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
441
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
442
+ '''
443
+ if not chunked:
444
+ # default behavior. Encode the entire audio in parallel
445
+ return self.encode(audio, **kwargs)
446
+ else:
447
+ # CHUNKED ENCODING
448
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
449
+ samples_per_latent = self.downsampling_ratio
450
+ total_size = audio.shape[2] # in samples
451
+ batch_size = audio.shape[0]
452
+ chunk_size *= samples_per_latent # converting metric in latents to samples
453
+ overlap *= samples_per_latent # converting metric in latents to samples
454
+ hop_size = chunk_size - overlap
455
+ chunks = []
456
+ for i in range(0, total_size - chunk_size + 1, hop_size):
457
+ chunk = audio[:,:,i:i+chunk_size]
458
+ chunks.append(chunk)
459
+ if i+chunk_size != total_size:
460
+ # Final chunk
461
+ chunk = audio[:,:,-chunk_size:]
462
+ chunks.append(chunk)
463
+ chunks = torch.stack(chunks)
464
+ num_chunks = chunks.shape[0]
465
+ # Note: y_size might be a different value from the latent length used in diffusion training
466
+ # because we can encode audio of varying lengths
467
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
468
+ y_size = total_size // samples_per_latent
469
+ # Create an empty latent, we will populate it with chunks as we encode them
470
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
471
+ for i in range(num_chunks):
472
+ x_chunk = chunks[i,:]
473
+ # encode the chunk
474
+ y_chunk = self.encode(x_chunk)
475
+ # figure out where to put the audio along the time domain
476
+ if i == num_chunks-1:
477
+ # final chunk always goes at the end
478
+ t_end = y_size
479
+ t_start = t_end - y_chunk.shape[2]
480
+ else:
481
+ t_start = i * hop_size // samples_per_latent
482
+ t_end = t_start + chunk_size // samples_per_latent
483
+ # remove the edges of the overlaps
484
+ ol = overlap//samples_per_latent//2
485
+ chunk_start = 0
486
+ chunk_end = y_chunk.shape[2]
487
+ if i > 0:
488
+ # no overlap for the start of the first chunk
489
+ t_start += ol
490
+ chunk_start += ol
491
+ if i < num_chunks-1:
492
+ # no overlap for the end of the last chunk
493
+ t_end -= ol
494
+ chunk_end -= ol
495
+ # paste the chunked audio into our y_final output audio
496
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
497
+ return y_final
498
+
499
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
500
+ '''
501
+ Decode latents to audio.
502
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
503
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
504
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
505
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
506
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
507
+ Smaller chunk_size uses less memory, but more compute.
508
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
509
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
510
+ '''
511
+ if not chunked:
512
+ # default behavior. Decode the entire latent in parallel
513
+ return self.decode(latents, **kwargs)
514
+ else:
515
+ # chunked decoding
516
+ hop_size = chunk_size - overlap
517
+ total_size = latents.shape[2]
518
+ batch_size = latents.shape[0]
519
+ chunks = []
520
+ for i in range(0, total_size - chunk_size + 1, hop_size):
521
+ chunk = latents[:,:,i:i+chunk_size]
522
+ chunks.append(chunk)
523
+ if i+chunk_size != total_size:
524
+ # Final chunk
525
+ chunk = latents[:,:,-chunk_size:]
526
+ chunks.append(chunk)
527
+ chunks = torch.stack(chunks)
528
+ num_chunks = chunks.shape[0]
529
+ # samples_per_latent is just the downsampling ratio
530
+ samples_per_latent = self.downsampling_ratio
531
+ # Create an empty waveform; we will populate it with chunks as we decode them
532
+ y_size = total_size * samples_per_latent
533
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
534
+ for i in range(num_chunks):
535
+ x_chunk = chunks[i,:]
536
+ # decode the chunk
537
+ y_chunk = self.decode(x_chunk)
538
+ # figure out where to put the audio along the time domain
539
+ if i == num_chunks-1:
540
+ # final chunk always goes at the end
541
+ t_end = y_size
542
+ t_start = t_end - y_chunk.shape[2]
543
+ else:
544
+ t_start = i * hop_size * samples_per_latent
545
+ t_end = t_start + chunk_size * samples_per_latent
546
+ # remove the edges of the overlaps
547
+ ol = (overlap//2) * samples_per_latent
548
+ chunk_start = 0
549
+ chunk_end = y_chunk.shape[2]
550
+ if i > 0:
551
+ # no overlap for the start of the first chunk
552
+ t_start += ol
553
+ chunk_start += ol
554
+ if i < num_chunks-1:
555
+ # no overlap for the end of the last chunk
556
+ t_end -= ol
557
+ chunk_end -= ol
558
+ # paste the chunked audio into our y_final output audio
559
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
560
+ return y_final
561
+
562
+
563
+ class DiffusionAutoencoder(AudioAutoencoder):
564
+ def __init__(
565
+ self,
566
+ diffusion: ConditionedDiffusionModel,
567
+ diffusion_downsampling_ratio,
568
+ *args,
569
+ **kwargs
570
+ ):
571
+ super().__init__(*args, **kwargs)
572
+
573
+ self.diffusion = diffusion
574
+
575
+ self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
576
+
577
+ if self.encoder is not None:
578
+ # Shrink the initial encoder parameters to avoid saturated latents
579
+ with torch.no_grad():
580
+ for param in self.encoder.parameters():
581
+ param *= 0.5
582
+
583
+ def decode(self, latents, steps=100):
584
+
585
+ upsampled_length = latents.shape[2] * self.downsampling_ratio
586
+
587
+ if self.bottleneck is not None:
588
+ latents = self.bottleneck.decode(latents)
589
+
590
+ if self.decoder is not None:
591
+ latents = self.decoder(latents)  # run the wrapped decoder module (calling self.decode here would recurse)
592
+
593
+ # Upsample latents to match diffusion length
594
+ if latents.shape[2] != upsampled_length:
595
+ latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
596
+
597
+ noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
598
+ decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
599
+
600
+ if self.pretransform is not None:
601
+ if self.pretransform.enable_grad:
602
+ decoded = self.pretransform.decode(decoded)
603
+ else:
604
+ with torch.no_grad():
605
+ decoded = self.pretransform.decode(decoded)
606
+
607
+ return decoded
608
+
609
+ # AE factories
610
+
611
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
612
+ encoder_type = encoder_config.get("type", None)
613
+ assert encoder_type is not None, "Encoder type must be specified"
614
+
615
+ if encoder_type == "oobleck":
616
+ encoder = OobleckEncoder(
617
+ **encoder_config["config"]
618
+ )
619
+
620
+ elif encoder_type == "seanet":
621
+ from encodec.modules import SEANetEncoder
622
+ seanet_encoder_config = encoder_config["config"]
623
+
624
+ #SEANet encoder expects strides in reverse order
625
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
626
+ encoder = SEANetEncoder(
627
+ **seanet_encoder_config
628
+ )
629
+ elif encoder_type == "dac":
630
+ dac_config = encoder_config["config"]
631
+
632
+ encoder = DACEncoderWrapper(**dac_config)
633
+ elif encoder_type == "local_attn":
634
+ from .local_attention import TransformerEncoder1D
635
+
636
+ local_attn_config = encoder_config["config"]
637
+
638
+ encoder = TransformerEncoder1D(
639
+ **local_attn_config
640
+ )
641
+ else:
642
+ raise ValueError(f"Unknown encoder type {encoder_type}")
643
+
644
+ requires_grad = encoder_config.get("requires_grad", True)
645
+ if not requires_grad:
646
+ for param in encoder.parameters():
647
+ param.requires_grad = False
648
+
649
+ return encoder
650
+
651
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
652
+ decoder_type = decoder_config.get("type", None)
653
+ assert decoder_type is not None, "Decoder type must be specified"
654
+
655
+ if decoder_type == "oobleck":
656
+ decoder = OobleckDecoder(
657
+ **decoder_config["config"]
658
+ )
659
+ elif decoder_type == "seanet":
660
+ from encodec.modules import SEANetDecoder
661
+
662
+ decoder = SEANetDecoder(
663
+ **decoder_config["config"]
664
+ )
665
+ elif decoder_type == "dac":
666
+ dac_config = decoder_config["config"]
667
+
668
+ decoder = DACDecoderWrapper(**dac_config)
669
+ elif decoder_type == "local_attn":
670
+ from .local_attention import TransformerDecoder1D
671
+
672
+ local_attn_config = decoder_config["config"]
673
+
674
+ decoder = TransformerDecoder1D(
675
+ **local_attn_config
676
+ )
677
+ else:
678
+ raise ValueError(f"Unknown decoder type {decoder_type}")
679
+
680
+ requires_grad = decoder_config.get("requires_grad", True)
681
+ if not requires_grad:
682
+ for param in decoder.parameters():
683
+ param.requires_grad = False
684
+
685
+ return decoder
686
+
687
+ def create_autoencoder_from_config(config: Dict[str, Any]):
688
+
689
+ ae_config = config["model"]
690
+
691
+ encoder = create_encoder_from_config(ae_config["encoder"])
692
+ decoder = create_decoder_from_config(ae_config["decoder"])
693
+
694
+ bottleneck = ae_config.get("bottleneck", None)
695
+
696
+ latent_dim = ae_config.get("latent_dim", None)
697
+ assert latent_dim is not None, "latent_dim must be specified in model config"
698
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
699
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
700
+ io_channels = ae_config.get("io_channels", None)
701
+ assert io_channels is not None, "io_channels must be specified in model config"
702
+ sample_rate = config.get("sample_rate", None)
703
+ assert sample_rate is not None, "sample_rate must be specified in model config"
704
+
705
+ in_channels = ae_config.get("in_channels", None)
706
+ out_channels = ae_config.get("out_channels", None)
707
+
708
+ pretransform = ae_config.get("pretransform", None)
709
+
710
+ if pretransform is not None:
711
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
712
+
713
+ if bottleneck is not None:
714
+ bottleneck = create_bottleneck_from_config(bottleneck)
715
+
716
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
717
+
718
+ return AudioAutoencoder(
719
+ encoder,
720
+ decoder,
721
+ io_channels=io_channels,
722
+ latent_dim=latent_dim,
723
+ downsampling_ratio=downsampling_ratio,
724
+ sample_rate=sample_rate,
725
+ bottleneck=bottleneck,
726
+ pretransform=pretransform,
727
+ in_channels=in_channels,
728
+ out_channels=out_channels,
729
+ soft_clip=soft_clip
730
+ )
731
+
732
+ def create_diffAE_from_config(config: Dict[str, Any]):
733
+
734
+ diffae_config = config["model"]
735
+
736
+ if "encoder" in diffae_config:
737
+ encoder = create_encoder_from_config(diffae_config["encoder"])
738
+ else:
739
+ encoder = None
740
+
741
+ if "decoder" in diffae_config:
742
+ decoder = create_decoder_from_config(diffae_config["decoder"])
743
+ else:
744
+ decoder = None
745
+
746
+ diffusion_model_type = diffae_config["diffusion"]["type"]
747
+
748
+ if diffusion_model_type == "DAU1d":
749
+ diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
750
+ elif diffusion_model_type == "adp_1d":
751
+ diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
752
+ elif diffusion_model_type == "dit":
753
+ diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
754
+
755
+ latent_dim = diffae_config.get("latent_dim", None)
756
+ assert latent_dim is not None, "latent_dim must be specified in model config"
757
+ downsampling_ratio = diffae_config.get("downsampling_ratio", None)
758
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
759
+ io_channels = diffae_config.get("io_channels", None)
760
+ assert io_channels is not None, "io_channels must be specified in model config"
761
+ sample_rate = config.get("sample_rate", None)
762
+ assert sample_rate is not None, "sample_rate must be specified in model config"
763
+
764
+ bottleneck = diffae_config.get("bottleneck", None)
765
+
766
+ pretransform = diffae_config.get("pretransform", None)
767
+
768
+ if pretransform is not None:
769
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
770
+
771
+ if bottleneck is not None:
772
+ bottleneck = create_bottleneck_from_config(bottleneck)
773
+
774
+ diffusion_downsampling_ratio = None,
775
+
776
+ if diffusion_model_type == "DAU1d":
777
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
778
+ elif diffusion_model_type == "adp_1d":
779
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
780
+ elif diffusion_model_type == "dit":
781
+ diffusion_downsampling_ratio = 1
782
+
783
+ return DiffusionAutoencoder(
784
+ encoder=encoder,
785
+ decoder=decoder,
786
+ diffusion=diffusion,
787
+ io_channels=io_channels,
788
+ sample_rate=sample_rate,
789
+ latent_dim=latent_dim,
790
+ downsampling_ratio=downsampling_ratio,
791
+ diffusion_downsampling_ratio=diffusion_downsampling_ratio,
792
+ bottleneck=bottleneck,
793
+ pretransform=pretransform
794
+ )
stable_audio_tools/models/blocks.py ADDED
@@ -0,0 +1,339 @@
1
+ from functools import reduce
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from torch.backends.cuda import sdp_kernel
9
+ from packaging import version
10
+
11
+ from dac.nn.layers import Snake1d
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, main, skip=None):
15
+ super().__init__()
16
+ self.main = nn.Sequential(*main)
17
+ self.skip = skip if skip else nn.Identity()
18
+
19
+ def forward(self, input):
20
+ return self.main(input) + self.skip(input)
21
+
22
+ class ResConvBlock(ResidualBlock):
23
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
24
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
25
+ super().__init__([
26
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
27
+ nn.GroupNorm(1, c_mid),
28
+ Snake1d(c_mid) if use_snake else nn.GELU(),
29
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
30
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
31
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
32
+ ], skip)
33
+
34
+ class SelfAttention1d(nn.Module):
35
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
36
+ super().__init__()
37
+ assert c_in % n_head == 0
38
+ self.norm = nn.GroupNorm(1, c_in)
39
+ self.n_head = n_head
40
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
41
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
42
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
43
+
44
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
45
+
46
+ if not self.use_flash:
47
+ return
48
+
49
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
50
+
51
+ if device_properties.major == 8 and device_properties.minor == 0:
52
+ # Use flash attention for A100 GPUs
53
+ self.sdp_kernel_config = (True, False, False)
54
+ else:
55
+ # Don't use flash attention for other GPUs
56
+ self.sdp_kernel_config = (False, True, True)
57
+
58
+ def forward(self, input):
59
+ n, c, s = input.shape
60
+ qkv = self.qkv_proj(self.norm(input))
61
+ qkv = qkv.view(
62
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
63
+ q, k, v = qkv.chunk(3, dim=1)
64
+ scale = k.shape[3]**-0.25
65
+
66
+ if self.use_flash:
67
+ with sdp_kernel(*self.sdp_kernel_config):
68
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
69
+ else:
70
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
71
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
72
+
73
+
74
+ return input + self.dropout(self.out_proj(y))
75
+
76
+ class SkipBlock(nn.Module):
77
+ def __init__(self, *main):
78
+ super().__init__()
79
+ self.main = nn.Sequential(*main)
80
+
81
+ def forward(self, input):
82
+ return torch.cat([self.main(input), input], dim=1)
83
+
84
+ class FourierFeatures(nn.Module):
85
+ def __init__(self, in_features, out_features, std=1.):
86
+ super().__init__()
87
+ assert out_features % 2 == 0
88
+ self.weight = nn.Parameter(torch.randn(
89
+ [out_features // 2, in_features]) * std)
90
+
91
+ def forward(self, input):
92
+ f = 2 * math.pi * input @ self.weight.T
93
+ return torch.cat([f.cos(), f.sin()], dim=-1)
94
+
95
+ def expand_to_planes(input, shape):
96
+ return input[..., None].repeat([1, 1, shape[2]])
97
+
98
+ _kernels = {
99
+ 'linear':
100
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
101
+ 'cubic':
102
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
103
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
104
+ 'lanczos3':
105
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
106
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
107
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
108
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
109
+ }
110
+
111
+ class Downsample1d(nn.Module):
112
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
113
+ super().__init__()
114
+ self.pad_mode = pad_mode
115
+ kernel_1d = torch.tensor(_kernels[kernel])
116
+ self.pad = kernel_1d.shape[0] // 2 - 1
117
+ self.register_buffer('kernel', kernel_1d)
118
+ self.channels_last = channels_last
119
+
120
+ def forward(self, x):
121
+ if self.channels_last:
122
+ x = x.permute(0, 2, 1)
123
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
124
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
125
+ indices = torch.arange(x.shape[1], device=x.device)
126
+ weight[indices, indices] = self.kernel.to(weight)
127
+ x = F.conv1d(x, weight, stride=2)
128
+ if self.channels_last:
129
+ x = x.permute(0, 2, 1)
130
+ return x
131
+
132
+
133
+ class Upsample1d(nn.Module):
134
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
135
+ super().__init__()
136
+ self.pad_mode = pad_mode
137
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
138
+ self.pad = kernel_1d.shape[0] // 2 - 1
139
+ self.register_buffer('kernel', kernel_1d)
140
+ self.channels_last = channels_last
141
+
142
+ def forward(self, x):
143
+ if self.channels_last:
144
+ x = x.permute(0, 2, 1)
145
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
146
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
147
+ indices = torch.arange(x.shape[1], device=x.device)
148
+ weight[indices, indices] = self.kernel.to(weight)
149
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
150
+ if self.channels_last:
151
+ x = x.permute(0, 2, 1)
152
+ return x
153
+
154
+ def Downsample1d_2(
155
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
156
+ ) -> nn.Module:
157
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
158
+
159
+ return nn.Conv1d(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ kernel_size=factor * kernel_multiplier + 1,
163
+ stride=factor,
164
+ padding=factor * (kernel_multiplier // 2),
165
+ )
166
+
167
+
168
+ def Upsample1d_2(
169
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
170
+ ) -> nn.Module:
171
+
172
+ if factor == 1:
173
+ return nn.Conv1d(
174
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
175
+ )
176
+
177
+ if use_nearest:
178
+ return nn.Sequential(
179
+ nn.Upsample(scale_factor=factor, mode="nearest"),
180
+ nn.Conv1d(
181
+ in_channels=in_channels,
182
+ out_channels=out_channels,
183
+ kernel_size=3,
184
+ padding=1,
185
+ ),
186
+ )
187
+ else:
188
+ return nn.ConvTranspose1d(
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ kernel_size=factor * 2,
192
+ stride=factor,
193
+ padding=factor // 2 + factor % 2,
194
+ output_padding=factor % 2,
195
+ )
196
+
197
+ def zero_init(layer):
198
+ nn.init.zeros_(layer.weight)
199
+ if layer.bias is not None:
200
+ nn.init.zeros_(layer.bias)
201
+ return layer
202
+
203
+ def rms_norm(x, scale, eps):
204
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
205
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
206
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
207
+ return x * scale.to(x.dtype)
208
+
209
+ #rms_norm = torch.compile(rms_norm)
210
+
211
+ class AdaRMSNorm(nn.Module):
212
+ def __init__(self, features, cond_features, eps=1e-6):
213
+ super().__init__()
214
+ self.eps = eps
215
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
216
+
217
+ def extra_repr(self):
218
+ return f"eps={self.eps},"
219
+
220
+ def forward(self, x, cond):
221
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
222
+
223
+ def normalize(x, eps=1e-4):
224
+ dim = list(range(1, x.ndim))
225
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
226
+ alpha = np.sqrt(n.numel() / x.numel())
227
+ return x / torch.add(eps, n, alpha=alpha)
228
+
229
+ class ForcedWNConv1d(nn.Module):
230
+ def __init__(self, in_channels, out_channels, kernel_size=1):
231
+ super().__init__()
232
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
233
+
234
+ def forward(self, x):
235
+ if self.training:
236
+ with torch.no_grad():
237
+ self.weight.copy_(normalize(self.weight))
238
+
239
+ fan_in = self.weight[0].numel()
240
+
241
+ w = normalize(self.weight) / math.sqrt(fan_in)
242
+
243
+ return F.conv1d(x, w, padding='same')
244
+
245
+ # Kernels
246
+
247
+ use_compile = True
248
+
249
+ def compile(function, *args, **kwargs):
250
+ if not use_compile:
251
+ return function
252
+ try:
253
+ return torch.compile(function, *args, **kwargs)
254
+ except RuntimeError:
255
+ return function
256
+
257
+
258
+ @compile
259
+ def linear_geglu(x, weight, bias=None):
260
+ x = x @ weight.mT
261
+ if bias is not None:
262
+ x = x + bias
263
+ x, gate = x.chunk(2, dim=-1)
264
+ return x * F.gelu(gate)
265
+
266
+
267
+ @compile
268
+ def rms_norm(x, scale, eps):
269
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
270
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
271
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
272
+ return x * scale.to(x.dtype)
273
+
274
+ # Layers
275
+
276
+ class LinearGEGLU(nn.Linear):
277
+ def __init__(self, in_features, out_features, bias=True):
278
+ super().__init__(in_features, out_features * 2, bias=bias)
279
+ self.out_features = out_features
280
+
281
+ def forward(self, x):
282
+ return linear_geglu(x, self.weight, self.bias)
283
+
284
+
285
+ class RMSNorm(nn.Module):
286
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
287
+ super().__init__()
288
+ self.eps = eps
289
+
290
+ if fix_scale:
291
+ self.register_buffer("scale", torch.ones(shape))
292
+ else:
293
+ self.scale = nn.Parameter(torch.ones(shape))
294
+
295
+ def extra_repr(self):
296
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
297
+
298
+ def forward(self, x):
299
+ return rms_norm(x, self.scale, self.eps)
300
+
301
+ def snake_beta(x, alpha, beta):
302
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
303
+
304
+ # try:
305
+ # snake_beta = torch.compile(snake_beta)
306
+ # except RuntimeError:
307
+ # pass
308
+
309
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
310
+ # License available in LICENSES/LICENSE_NVIDIA.txt
311
+ class SnakeBeta(nn.Module):
312
+
313
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
314
+ super(SnakeBeta, self).__init__()
315
+ self.in_features = in_features
316
+
317
+ # initialize alpha
318
+ self.alpha_logscale = alpha_logscale
319
+ if self.alpha_logscale: # log scale alphas initialized to zeros
320
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
321
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
322
+ else: # linear scale alphas initialized to ones
323
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
324
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
325
+
326
+ self.alpha.requires_grad = alpha_trainable
327
+ self.beta.requires_grad = alpha_trainable
328
+
329
+ self.no_div_by_zero = 0.000000001
330
+
331
+ def forward(self, x):
332
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
333
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
334
+ if self.alpha_logscale:
335
+ alpha = torch.exp(alpha)
336
+ beta = torch.exp(beta)
337
+ x = snake_beta(x, alpha, beta)
338
+
339
+ return x
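+
+ # Usage sketch (illustrative): SnakeBeta is applied channel-wise to [B, C, T] tensors,
+ # typically between convolutions in the autoencoder blocks:
+ #   act = SnakeBeta(in_features=64)
+ #   y = act(x)   # x: [B, 64, T] -> y: [B, 64, T], with y = x + (1/beta) * sin(alpha * x)**2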
stable_audio_tools/models/bottleneck.py ADDED
@@ -0,0 +1,355 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from einops import rearrange
7
+ from vector_quantize_pytorch import ResidualVQ, FSQ
8
+ from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
9
+
10
+ class Bottleneck(nn.Module):
11
+ def __init__(self, is_discrete: bool = False):
12
+ super().__init__()
13
+
14
+ self.is_discrete = is_discrete
15
+
16
+ def encode(self, x, return_info=False, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, x):
20
+ raise NotImplementedError
21
+
22
+ class DiscreteBottleneck(Bottleneck):
23
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
24
+ super().__init__(is_discrete=True)
25
+
26
+ self.num_quantizers = num_quantizers
27
+ self.codebook_size = codebook_size
28
+ self.tokens_id = tokens_id
29
+
30
+ def decode_tokens(self, codes, **kwargs):
31
+ raise NotImplementedError
32
+
33
+ class TanhBottleneck(Bottleneck):
34
+ def __init__(self):
35
+ super().__init__(is_discrete=False)
36
+ self.tanh = nn.Tanh()
37
+
38
+ def encode(self, x, return_info=False):
39
+ info = {}
40
+
41
+ x = torch.tanh(x)
42
+
43
+ if return_info:
44
+ return x, info
45
+ else:
46
+ return x
47
+
48
+ def decode(self, x):
49
+ return x
50
+
51
+ def vae_sample(mean, scale):
52
+ stdev = nn.functional.softplus(scale) + 1e-4
53
+ var = stdev * stdev
54
+ logvar = torch.log(var)
55
+ latents = torch.randn_like(mean) * stdev + mean
56
+
57
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
58
+
59
+ return latents, kl
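+
+ # Usage sketch (illustrative): the encoder stacks mean and scale along the channel axis;
+ # vae_sample applies the reparameterization trick and returns the KL term against a unit
+ # Gaussian:
+ #   mean, scale = encoder_out.chunk(2, dim=1)   # encoder_out: [B, 2*C, T]
+ #   latents, kl = vae_sample(mean, scale)       # latents: [B, C, T], kl: scalar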
60
+
61
+ class VAEBottleneck(Bottleneck):
62
+ def __init__(self):
63
+ super().__init__(is_discrete=False)
64
+
65
+ def encode(self, x, return_info=False, **kwargs):
66
+ info = {}
67
+
68
+ mean, scale = x.chunk(2, dim=1)
69
+
70
+ x, kl = vae_sample(mean, scale)
71
+
72
+ info["kl"] = kl
73
+
74
+ if return_info:
75
+ return x, info
76
+ else:
77
+ return x
78
+
79
+ def decode(self, x):
80
+ return x
81
+
82
+ def compute_mean_kernel(x, y):
83
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
84
+ return torch.exp(-kernel_input).mean()
85
+
86
+ def compute_mmd(latents):
87
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
88
+ noise = torch.randn_like(latents_reshaped)
89
+
90
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
91
+ noise_kernel = compute_mean_kernel(noise, noise)
92
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
93
+
94
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
95
+ return mmd.mean()
96
+
97
+ class WassersteinBottleneck(Bottleneck):
98
+ def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
99
+ super().__init__(is_discrete=False)
100
+
101
+ self.noise_augment_dim = noise_augment_dim
102
+ self.bypass_mmd = bypass_mmd
103
+
104
+ def encode(self, x, return_info=False):
105
+ info = {}
106
+
107
+ if self.training and return_info:
108
+ if self.bypass_mmd:
109
+ mmd = torch.tensor(0.0)
110
+ else:
111
+ mmd = compute_mmd(x)
112
+
113
+ info["mmd"] = mmd
114
+
115
+ if return_info:
116
+ return x, info
117
+
118
+ return x
119
+
120
+ def decode(self, x):
121
+
122
+ if self.noise_augment_dim > 0:
123
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
124
+ x.shape[-1]).type_as(x)
125
+ x = torch.cat([x, noise], dim=1)
126
+
127
+ return x
128
+
129
+ class L2Bottleneck(Bottleneck):
130
+ def __init__(self):
131
+ super().__init__(is_discrete=False)
132
+
133
+ def encode(self, x, return_info=False):
134
+ info = {}
135
+
136
+ x = F.normalize(x, dim=1)
137
+
138
+ if return_info:
139
+ return x, info
140
+ else:
141
+ return x
142
+
143
+ def decode(self, x):
144
+ return F.normalize(x, dim=1)
145
+
146
+ class RVQBottleneck(DiscreteBottleneck):
147
+ def __init__(self, **quantizer_kwargs):
148
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
149
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
150
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
151
+
152
+ def encode(self, x, return_info=False, **kwargs):
153
+ info = {}
154
+
155
+ x = rearrange(x, "b c n -> b n c")
156
+ x, indices, loss = self.quantizer(x)
157
+ x = rearrange(x, "b n c -> b c n")
158
+
159
+ info["quantizer_indices"] = indices
160
+ info["quantizer_loss"] = loss.mean()
161
+
162
+ if return_info:
163
+ return x, info
164
+ else:
165
+ return x
166
+
167
+ def decode(self, x):
168
+ return x
169
+
170
+ def decode_tokens(self, codes, **kwargs):
171
+ latents = self.quantizer.get_outputs_from_indices(codes)
172
+
173
+ return self.decode(latents, **kwargs)
174
+
175
+ class RVQVAEBottleneck(DiscreteBottleneck):
176
+ def __init__(self, **quantizer_kwargs):
177
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
178
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
179
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
180
+
181
+ def encode(self, x, return_info=False):
182
+ info = {}
183
+
184
+ x, kl = vae_sample(*x.chunk(2, dim=1))
185
+
186
+ info["kl"] = kl
187
+
188
+ x = rearrange(x, "b c n -> b n c")
189
+ x, indices, loss = self.quantizer(x)
190
+ x = rearrange(x, "b n c -> b c n")
191
+
192
+ info["quantizer_indices"] = indices
193
+ info["quantizer_loss"] = loss.mean()
194
+
195
+ if return_info:
196
+ return x, info
197
+ else:
198
+ return x
199
+
200
+ def decode(self, x):
201
+ return x
202
+
203
+ def decode_tokens(self, codes, **kwargs):
204
+ latents = self.quantizer.get_outputs_from_indices(codes)
205
+
206
+ return self.decode(latents, **kwargs)
207
+
208
+ class DACRVQBottleneck(DiscreteBottleneck):
209
+ def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
210
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
211
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
212
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
213
+ self.quantize_on_decode = quantize_on_decode
214
+ self.noise_augment_dim = noise_augment_dim
215
+
216
+ def encode(self, x, return_info=False, **kwargs):
217
+ info = {}
218
+
219
+ info["pre_quantizer"] = x
220
+
221
+ if self.quantize_on_decode:
222
+ return (x, info) if return_info else x
223
+
224
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
225
+
226
+ output = {
227
+ "z": z,
228
+ "codes": codes,
229
+ "latents": latents,
230
+ "vq/commitment_loss": commitment_loss,
231
+ "vq/codebook_loss": codebook_loss,
232
+ }
233
+
234
+ output["vq/commitment_loss"] /= self.num_quantizers
235
+ output["vq/codebook_loss"] /= self.num_quantizers
236
+
237
+ info.update(output)
238
+
239
+ if return_info:
240
+ return output["z"], info
241
+
242
+ return output["z"]
243
+
244
+ def decode(self, x):
245
+
246
+ if self.quantize_on_decode:
247
+ x = self.quantizer(x)[0]
248
+
249
+ if self.noise_augment_dim > 0:
250
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
251
+ x.shape[-1]).type_as(x)
252
+ x = torch.cat([x, noise], dim=1)
253
+
254
+ return x
255
+
256
+ def decode_tokens(self, codes, **kwargs):
257
+ latents, _, _ = self.quantizer.from_codes(codes)
258
+
259
+ return self.decode(latents, **kwargs)
260
+
261
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
262
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
263
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
264
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
265
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
266
+ self.quantize_on_decode = quantize_on_decode
267
+
268
+ def encode(self, x, return_info=False, n_quantizers: int = None):
269
+ info = {}
270
+
271
+ mean, scale = x.chunk(2, dim=1)
272
+
273
+ x, kl = vae_sample(mean, scale)
274
+
275
+ info["pre_quantizer"] = x
276
+ info["kl"] = kl
277
+
278
+ if self.quantize_on_decode:
279
+ return (x, info) if return_info else x
280
+
281
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
282
+
283
+ output = {
284
+ "z": z,
285
+ "codes": codes,
286
+ "latents": latents,
287
+ "vq/commitment_loss": commitment_loss,
288
+ "vq/codebook_loss": codebook_loss,
289
+ }
290
+
291
+ output["vq/commitment_loss"] /= self.num_quantizers
292
+ output["vq/codebook_loss"] /= self.num_quantizers
293
+
294
+ info.update(output)
295
+
296
+ if return_info:
297
+ return output["z"], info
298
+
299
+ return output["z"]
300
+
301
+ def decode(self, x):
302
+
303
+ if self.quantize_on_decode:
304
+ x = self.quantizer(x)[0]
305
+
306
+ return x
307
+
308
+ def decode_tokens(self, codes, **kwargs):
309
+ latents, _, _ = self.quantizer.from_codes(codes)
310
+
311
+ return self.decode(latents, **kwargs)
312
+
313
+ class FSQBottleneck(DiscreteBottleneck):
314
+ def __init__(self, noise_augment_dim=0, **kwargs):
315
+ super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
316
+
317
+ self.noise_augment_dim = noise_augment_dim
318
+
319
+ self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
320
+
321
+ def encode(self, x, return_info=False):
322
+ info = {}
323
+
324
+ orig_dtype = x.dtype
325
+ x = x.float()
326
+
327
+ x = rearrange(x, "b c n -> b n c")
328
+ x, indices = self.quantizer(x)
329
+ x = rearrange(x, "b n c -> b c n")
330
+
331
+ x = x.to(orig_dtype)
332
+
333
+ # Reorder indices to match the expected format
334
+ indices = rearrange(indices, "b n q -> b q n")
335
+
336
+ info["quantizer_indices"] = indices
337
+
338
+ if return_info:
339
+ return x, info
340
+ else:
341
+ return x
342
+
343
+ def decode(self, x):
344
+
345
+ if self.noise_augment_dim > 0:
346
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
347
+ x.shape[-1]).type_as(x)
348
+ x = torch.cat([x, noise], dim=1)
349
+
350
+ return x
351
+
352
+ def decode_tokens(self, tokens, **kwargs):
353
+ latents = self.quantizer.indices_to_codes(tokens)
354
+
355
+ return self.decode(latents, **kwargs)
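+
+ # Usage sketch (illustrative): with FSQ the effective codebook size is the product of the
+ # per-dimension levels, e.g. levels=[8, 5, 5, 5] gives 1000 codes and expects 4 latent
+ # channels per codebook:
+ #   bottleneck = FSQBottleneck(levels=[8, 5, 5, 5])
+ #   z, info = bottleneck.encode(latents, return_info=True)   # latents: [B, 4, T]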
stable_audio_tools/models/codebook_patterns.py ADDED
@@ -0,0 +1,545 @@
1
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License
2
+ # License available in LICENSES/LICENSE_META.txt
3
+
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ import logging
8
+ import typing as tp
9
+
10
+ from abc import ABC, abstractmethod
11
+ import torch
12
+
13
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
14
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @dataclass
19
+ class Pattern:
20
+ """Base implementation of a pattern over a sequence with multiple codebooks.
21
+
22
+ The codebook pattern consists in a layout, defining for each sequence step
23
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
24
+ The first item of the pattern is always an empty list in order to properly insert a special token
25
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
26
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
27
+
28
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
29
+ ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
30
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
31
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
32
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
33
+ is returned along with a mask indicating valid tokens.
34
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
35
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
36
+ to fill and specify invalid positions if needed.
37
+ See the dedicated methods for more details.
38
+ """
39
+ # Pattern layout, for each sequence step, we have a list of coordinates
40
+ # corresponding to the original codebook timestep and position.
41
+ # The first list is always an empty list in order to properly insert
42
+ # a special token to start with.
43
+ layout: PatternLayout
44
+ timesteps: int
45
+ n_q: int
46
+
47
+ def __post_init__(self):
48
+ assert len(self.layout) > 0
49
+ self._validate_layout()
50
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
51
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
52
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
53
+
54
+ def _validate_layout(self):
55
+ """Runs checks on the layout to ensure a valid pattern is defined.
56
+ A pattern is considered invalid if:
57
+ - Multiple timesteps for a same codebook are defined in the same sequence step
58
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
59
+ (this would mean that we have future timesteps before past timesteps).
60
+ """
61
+ q_timesteps = {q: 0 for q in range(self.n_q)}
62
+ for s, seq_coords in enumerate(self.layout):
63
+ if len(seq_coords) > 0:
64
+ qs = set()
65
+ for coord in seq_coords:
66
+ qs.add(coord.q)
67
+ last_q_timestep = q_timesteps[coord.q]
68
+ assert coord.t >= last_q_timestep, \
69
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
70
+ q_timesteps[coord.q] = coord.t
71
+ # each sequence step contains at max 1 coordinate per codebook
72
+ assert len(qs) == len(seq_coords), \
73
+ f"Multiple entries for a same codebook are found at step {s}"
74
+
75
+ @property
76
+ def num_sequence_steps(self):
77
+ return len(self.layout) - 1
78
+
79
+ @property
80
+ def max_delay(self):
81
+ max_t_in_seq_coords = 0
82
+ for seq_coords in self.layout[1:]:
83
+ for coords in seq_coords:
84
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
85
+ return max_t_in_seq_coords - self.timesteps
86
+
87
+ @property
88
+ def valid_layout(self):
89
+ valid_step = len(self.layout) - self.max_delay
90
+ return self.layout[:valid_step]
91
+
92
+ def starts_with_special_token(self):
93
+ return self.layout[0] == []
94
+
95
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
96
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
97
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
98
+ and the actual codebook coordinates.
99
+ """
100
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
101
+ if q is not None:
102
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
103
+ coords = []
104
+ for s, seq_codes in enumerate(self.layout):
105
+ for code in seq_codes:
106
+ if code.t == t and (q is None or code.q == q):
107
+ coords.append((s, code))
108
+ return coords
109
+
110
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
111
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
112
+
113
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
114
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
115
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
116
+
117
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
118
+ device: tp.Union[torch.device, str] = 'cpu'):
119
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
120
+
121
+ Args:
122
+ timesteps (int): Maximum number of timesteps steps to consider.
123
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
124
+ device (torch.device or str): Device for created tensors.
125
+ Returns:
126
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128
+ """
129
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
132
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
133
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
134
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
135
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
136
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
137
+ # fill indexes with last sequence step value that will correspond to our special token
138
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
139
+ # which will correspond to the index: n_q * timesteps
140
+ indexes[:] = n_q * timesteps
141
+ # iterate over the pattern and fill scattered indexes and mask
142
+ for s, sequence_coords in enumerate(ref_layout):
143
+ for coords in sequence_coords:
144
+ if coords.t < timesteps:
145
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
146
+ mask[coords.q, s] = 1
147
+ indexes = torch.from_numpy(indexes).to(device)
148
+ mask = torch.from_numpy(mask).to(device)
149
+ return indexes, mask
150
+
151
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
152
+ """Build sequence corresponding to the pattern from the input tensor z.
153
+ The sequence is built using up to sequence_steps if specified, and non-pattern
154
+ coordinates are filled with the special token.
155
+
156
+ Args:
157
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
158
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
159
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
160
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
161
+ Returns:
162
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
163
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
164
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
165
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
166
+ """
167
+ B, K, T = z.shape
168
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
169
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
170
+ )
171
+ z = z.view(B, -1)
172
+ # we append the special token as the last index of our flattened z tensor
173
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
174
+ values = z[:, indexes.view(-1)]
175
+ values = values.view(B, K, indexes.shape[-1])
176
+ return values, indexes, mask
177
+
178
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
179
+ keep_only_valid_steps: bool = False,
180
+ is_model_output: bool = False,
181
+ device: tp.Union[torch.device, str] = 'cpu'):
182
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
183
+ from interleaving pattern.
184
+
185
+ Args:
186
+ sequence_steps (int): Sequence steps.
187
+ n_q (int): Number of codebooks.
188
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
189
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
190
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
191
+ device (torch.device or str): Device for created tensors.
192
+ Returns:
193
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
194
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
195
+ """
196
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198
+ timesteps = self.timesteps
199
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200
+ assert sequence_steps <= len(ref_layout), \
201
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
+
203
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
204
+ if is_model_output and self.starts_with_special_token():
205
+ ref_layout = ref_layout[1:]
206
+
207
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
208
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
209
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
210
+ # fill indexes with last sequence step value that will correspond to our special token
211
+ indexes[:] = n_q * sequence_steps
212
+ for s, sequence_codes in enumerate(ref_layout):
213
+ if s < sequence_steps:
214
+ for code in sequence_codes:
215
+ if code.t < timesteps:
216
+ indexes[code.q, code.t] = s + code.q * sequence_steps
217
+ mask[code.q, code.t] = 1
218
+ indexes = torch.from_numpy(indexes).to(device)
219
+ mask = torch.from_numpy(mask).to(device)
220
+ return indexes, mask
221
+
222
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
223
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
224
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
225
+ are filled with the special token.
226
+
227
+ Args:
228
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
229
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
230
+ Returns:
231
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
232
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
233
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
234
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
235
+ """
236
+ B, K, S = s.shape
237
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
238
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
239
+ )
240
+ s = s.view(B, -1)
241
+ # we append the special token as the last index of our flattened z tensor
242
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
243
+ values = s[:, indexes.view(-1)]
244
+ values = values.view(B, K, indexes.shape[-1])
245
+ return values, indexes, mask
246
+
247
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
248
+ """Revert model logits obtained on a sequence built from the pattern
249
+ back to a tensor matching the original sequence.
250
+
251
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
252
+ 1. It is designed to work with the extra cardinality dimension
253
+ 2. We return the logits for the first sequence item that matches the special_token and
254
+ which matching target in the original sequence is the first item of the sequence,
255
+ while we skip the last logits as there is no matching target
256
+ """
257
+ B, card, K, S = logits.shape
258
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
259
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
260
+ )
261
+ logits = logits.reshape(B, card, -1)
262
+ # we append the special token as the last index of our flattened z tensor
263
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
264
+ values = logits[:, :, indexes.view(-1)]
265
+ values = values.view(B, card, K, indexes.shape[-1])
266
+ return values, indexes, mask
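+
+ # Usage sketch (illustrative): round-tripping codes through a pattern.
+ #   pattern = DelayedPatternProvider(n_q=K).get_pattern(timesteps=T)
+ #   seq, _, seq_mask = pattern.build_pattern_sequence(codes, special_token)  # codes: [B, K, T] -> seq: [B, K, S]
+ #   recon, _, mask = pattern.revert_pattern_sequence(seq, special_token)     # recon: [B, K, T]
+ # Positions where the mask is False hold the special token rather than real codes.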
267
+
268
+
269
+ class CodebooksPatternProvider(ABC):
270
+ """Abstraction around providing pattern for interleaving codebooks.
271
+
272
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
273
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
274
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
275
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
276
+ can be used to construct a new sequence from the original codes respecting the specified
277
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
278
+ being a tuple with the original timestep and codebook to build the new sequence.
279
+ Note that all patterns must start with an empty list that is then used to insert a first
280
+ sequence step of special tokens in the newly generated sequence.
281
+
282
+ Args:
283
+ n_q (int): number of codebooks.
284
+ cached (bool): if True, patterns for a given length are cached. In general
285
+ that should be true for efficiency reason to avoid synchronization points.
286
+ """
287
+ def __init__(self, n_q: int, cached: bool = True):
288
+ assert n_q > 0
289
+ self.n_q = n_q
290
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
291
+
292
+ @abstractmethod
293
+ def get_pattern(self, timesteps: int) -> Pattern:
294
+ """Builds pattern with specific interleaving between codebooks.
295
+
296
+ Args:
297
+ timesteps (int): Total number of timesteps.
298
+ """
299
+ raise NotImplementedError()
300
+
301
+
302
+ class DelayedPatternProvider(CodebooksPatternProvider):
303
+ """Provider for delayed pattern across delayed codebooks.
304
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
305
+ from different timesteps.
306
+
307
+ Example:
308
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
309
+ [[1, 2, 3, 4],
310
+ [1, 2, 3, 4],
311
+ [1, 2, 3, 4]]
312
+ The resulting sequence obtained from the returned pattern is:
313
+ [[S, 1, 2, 3, 4],
314
+ [S, S, 1, 2, 3],
315
+ [S, S, S, 1, 2]]
316
+ (with S being a special token)
317
+
318
+ Args:
319
+ n_q (int): Number of codebooks.
320
+ delays (list of int, optional): Delay for each of the codebooks.
321
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
322
+ flatten_first (int): Flatten the first N timesteps.
323
+ empty_initial (int): Prepend with N empty list of coordinates.
324
+ """
325
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
326
+ flatten_first: int = 0, empty_initial: int = 0):
327
+ super().__init__(n_q)
328
+ if delays is None:
329
+ delays = list(range(n_q))
330
+ self.delays = delays
331
+ self.flatten_first = flatten_first
332
+ self.empty_initial = empty_initial
333
+ assert len(self.delays) == self.n_q
334
+ assert sorted(self.delays) == self.delays
335
+
336
+ def get_pattern(self, timesteps: int) -> Pattern:
337
+ omit_special_token = self.empty_initial < 0
338
+ out: PatternLayout = [] if omit_special_token else [[]]
339
+ max_delay = max(self.delays)
340
+ if self.empty_initial:
341
+ out += [[] for _ in range(self.empty_initial)]
342
+ if self.flatten_first:
343
+ for t in range(min(timesteps, self.flatten_first)):
344
+ for q in range(self.n_q):
345
+ out.append([LayoutCoord(t, q)])
346
+ for t in range(self.flatten_first, timesteps + max_delay):
347
+ v = []
348
+ for q, delay in enumerate(self.delays):
349
+ t_for_q = t - delay
350
+ if t_for_q >= self.flatten_first:
351
+ v.append(LayoutCoord(t_for_q, q))
352
+ out.append(v)
353
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
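+
+ # Example (illustrative, matching the docstring above): with n_q=3 and timesteps=4 the
+ # layout has 1 + timesteps + max_delay = 7 sequence steps; codebook q is shifted right by
+ # q steps, and build_pattern_sequence fills the remaining positions with the special token.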
354
+
355
+
356
+ class ParallelPatternProvider(DelayedPatternProvider):
357
+ """Provider for parallel pattern across codebooks.
358
+ This pattern provider is a special case of the delayed pattern with actually no delay,
359
+ hence delays=repeat(0, n_q).
360
+
361
+ Args:
362
+ n_q (int): Number of codebooks.
363
+ empty_initial (int): Prepend with N empty list of coordinates.
364
+ """
365
+ def __init__(self, n_q: int, empty_initial: int = 0):
366
+ super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
367
+
368
+
369
+ class UnrolledPatternProvider(CodebooksPatternProvider):
370
+ """Provider for unrolling codebooks pattern.
371
+ This pattern provider enables representing the codebook flattened completely or only to some extent
372
+ while also specifying a given delay between the flattened codebooks representation, allowing to
373
+ unroll the codebooks in the sequence.
374
+
375
+ Example:
376
+ 1. Flattening of the codebooks.
377
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
378
+ taking n_q = 3 and timesteps = 4:
379
+ [[1, 2, 3, 4],
380
+ [1, 2, 3, 4],
381
+ [1, 2, 3, 4]]
382
+ will result into:
383
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
384
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
385
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
386
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
387
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
388
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
389
+ [[1, 2, 3, 4],
390
+ [1, 2, 3, 4],
391
+ [1, 2, 3, 4]]
392
+ will result into:
393
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
394
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
395
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
396
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
397
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
398
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
399
+ and delays = [0, 3, 3]:
400
+ [[1, 2, 3, 4],
401
+ [1, 2, 3, 4],
402
+ [1, 2, 3, 4]]
403
+ will result into:
404
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
405
+ [S, S, S, 1, S, 2, S, 3, S, 4],
406
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
407
+
408
+ Args:
409
+ n_q (int): Number of codebooks.
410
+ flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
411
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
412
+ have n_q extra steps for each timestep.
413
+ delays (list of int, optional): Delay for each of the codebooks. If not defined,
414
+ no delay is added and therefore will default to [0] * ``n_q``.
415
+ Note that two codebooks that will be flattened to the same inner step
416
+ should have the same delay, otherwise the pattern is considered as invalid.
417
+ """
418
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
419
+
420
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
421
+ delays: tp.Optional[tp.List[int]] = None):
422
+ super().__init__(n_q)
423
+ if flattening is None:
424
+ flattening = list(range(n_q))
425
+ if delays is None:
426
+ delays = [0] * n_q
427
+ assert len(flattening) == n_q
428
+ assert len(delays) == n_q
429
+ assert sorted(flattening) == flattening
430
+ assert sorted(delays) == delays
431
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
432
+ self.max_delay = max(delays)
433
+
434
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
435
+ """Build a flattened codebooks representation as a dictionary of inner step
436
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
437
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
438
+ """
439
+ flattened_codebooks: dict = {}
440
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
441
+ if inner_step not in flattened_codebooks:
442
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
443
+ else:
444
+ flat_codebook = flattened_codebooks[inner_step]
445
+ assert flat_codebook.delay == delay, (
446
+ "Delay and flattening between codebooks is inconsistent: ",
447
+ "two codebooks flattened to the same position should have the same delay."
448
+ )
449
+ flat_codebook.codebooks.append(q)
450
+ flattened_codebooks[inner_step] = flat_codebook
451
+ return flattened_codebooks
452
+
453
+ @property
454
+ def _num_inner_steps(self):
455
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
456
+ """
457
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
458
+
459
+ def num_virtual_steps(self, timesteps: int) -> int:
460
+ return timesteps * self._num_inner_steps + 1
461
+
462
+ def get_pattern(self, timesteps: int) -> Pattern:
463
+ """Builds pattern for delay across codebooks.
464
+
465
+ Args:
466
+ timesteps (int): Total number of timesteps.
467
+ """
468
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
469
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
470
+ indexed_out: list = [(-1, [])]
471
+ max_timesteps = timesteps + self.max_delay
472
+ for t in range(max_timesteps):
473
+ # for each timestep, we unroll the flattened codebooks,
474
+ # emitting the sequence step with the corresponding delay
475
+ for step in range(self._num_inner_steps):
476
+ if step in self._flattened_codebooks:
477
+ # we have codebooks at this virtual step to emit
478
+ step_codebooks = self._flattened_codebooks[step]
479
+ t_for_q = t + step_codebooks.delay
480
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
481
+ if t_for_q < max_timesteps and t < max_timesteps:
482
+ indexed_out.append((t_for_q, coords))
483
+ else:
484
+ # there is no codebook in this virtual step so we emit an empty list
485
+ indexed_out.append((t, []))
486
+ out = [coords for _, coords in sorted(indexed_out)]
487
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
488
+
489
+
490
+ class CoarseFirstPattern(CodebooksPatternProvider):
491
+ """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
492
+ potentially with delays.
493
+
494
+ ..Warning:: You must always generate the full training duration at test time, for instance,
495
+ 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
496
+ location. This is due to the non causality of the remaining codebooks with respect to
497
+ the first ones.
498
+
499
+ Args:
500
+ n_q (int): Number of codebooks.
501
+ delays (list of int, optional): Delay for each of the codebooks.
502
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
503
+ """
504
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
505
+ super().__init__(n_q)
506
+ if delays is None:
507
+ delays = [0] * (n_q - 1)
508
+ self.delays = delays
509
+ assert len(self.delays) == self.n_q - 1
510
+ assert sorted(self.delays) == self.delays
511
+
512
+ def get_pattern(self, timesteps: int) -> Pattern:
513
+ out: PatternLayout = [[]]
514
+ for t in range(timesteps):
515
+ out.append([LayoutCoord(t, 0)])
516
+ max_delay = max(self.delays)
517
+ for t in range(timesteps + max_delay):
518
+ v = []
519
+ for q, delay in enumerate(self.delays):
520
+ t_for_q = t - delay
521
+ if t_for_q >= 0:
522
+ v.append(LayoutCoord(t_for_q, q + 1))
523
+ out.append(v)
524
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
525
+
526
+
527
+ class MusicLMPattern(CodebooksPatternProvider):
528
+ """Almost MusicLM style pattern. This is equivalent to full flattening
529
+ but in a different order.
530
+
531
+ Args:
532
+ n_q (int): Number of codebooks.
533
+ group_by (int): Number of codebooks to group together.
534
+ """
535
+ def __init__(self, n_q: int, group_by: int = 2):
536
+ super().__init__(n_q)
537
+ self.group_by = group_by
538
+
539
+ def get_pattern(self, timesteps: int) -> Pattern:
540
+ out: PatternLayout = [[]]
541
+ for offset in range(0, self.n_q, self.group_by):
542
+ for t in range(timesteps):
543
+ for q in range(offset, offset + self.group_by):
544
+ out.append([LayoutCoord(t, q)])
545
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
stable_audio_tools/models/conditioners.py ADDED
@@ -0,0 +1,710 @@
1
+ #Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
2
+
3
+ import torch
4
+ import logging, warnings
5
+ import string
6
+ import typing as tp
7
+ import gc
8
+
9
+ from .adp import NumberEmbedder
10
+ from ..inference.utils import set_audio_channels
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from .utils import load_ckpt_state_dict
14
+
15
+ from torch import nn
16
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
17
+ import einops
18
+ from .temptransformer import SA_Transformer
19
+ from torchvision import transforms
20
23
+
24
+
25
+ class Conditioner(nn.Module):
26
+ def __init__(
27
+ self,
28
+ dim: int,
29
+ output_dim: int,
30
+ project_out: bool = False
31
+ ):
32
+
33
+ super().__init__()
34
+
35
+ self.dim = dim
36
+ self.output_dim = output_dim
37
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
38
+
39
+ def forward(self, x: tp.Any) -> tp.Any:
40
+ raise NotImplementedError()
41
+
42
+ class IntConditioner(Conditioner):
43
+ def __init__(self,
44
+ output_dim: int,
45
+ min_val: int=0,
46
+ max_val: int=512
47
+ ):
48
+ super().__init__(output_dim, output_dim)
49
+
50
+ self.min_val = min_val
51
+ self.max_val = max_val
52
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
53
+
54
+ def forward(self, ints: tp.List[int], device=None) -> tp.Any:
55
+
56
+ #self.int_embedder.to(device)
57
+
58
+ ints = torch.tensor(ints).to(device)
59
+ ints = ints.clamp(self.min_val, self.max_val)
60
+
61
+ int_embeds = self.int_embedder(ints).unsqueeze(1)
62
+
63
+ return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
64
+
65
+ class NumberConditioner(Conditioner):
66
+ '''
67
+ Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
68
+ '''
69
+ def __init__(self,
70
+ output_dim: int,
71
+ min_val: float=0,
72
+ max_val: float=1
73
+ ):
74
+ super().__init__(output_dim, output_dim)
75
+
76
+ self.min_val = min_val
77
+ self.max_val = max_val
78
+
79
+ self.embedder = NumberEmbedder(features=output_dim)
80
+
81
+ def forward(self, floats: tp.List[float], device=None) -> tp.Any:
82
+
83
+ # Cast the inputs to floats
84
+ floats = [float(x) for x in floats]
85
+
86
+ floats = torch.tensor(floats).to(device)
87
+
88
+ floats = floats.clamp(self.min_val, self.max_val)
89
+
90
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
91
+
92
+ # Cast floats to same type as embedder
93
+ embedder_dtype = next(self.embedder.parameters()).dtype
94
+ normalized_floats = normalized_floats.to(embedder_dtype)
95
+
96
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
97
+
98
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
99
+
100
+ class CLAPTextConditioner(Conditioner):
101
+ def __init__(self,
102
+ output_dim: int,
103
+ clap_ckpt_path,
104
+ use_text_features = False,
105
+ feature_layer_ix: int = -1,
106
+ audio_model_type="HTSAT-base",
107
+ enable_fusion=True,
108
+ project_out: bool = False,
109
+ finetune: bool = False):
110
+ super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
111
+
112
+ self.use_text_features = use_text_features
113
+ self.feature_layer_ix = feature_layer_ix
114
+ self.finetune = finetune
115
+
116
+ # Suppress logging from transformers
117
+ previous_level = logging.root.manager.disable
118
+ logging.disable(logging.ERROR)
119
+ with warnings.catch_warnings():
120
+ warnings.simplefilter("ignore")
121
+ try:
122
+ import laion_clap
123
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
124
+
125
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
126
+
127
+ if self.finetune:
128
+ self.model = model
129
+ else:
130
+ self.__dict__["model"] = model
131
+
132
+ state_dict = clap_load_state_dict(clap_ckpt_path)
133
+ self.model.model.load_state_dict(state_dict, strict=False)
134
+
135
+ if self.finetune:
136
+ self.model.model.text_branch.requires_grad_(True)
137
+ self.model.model.text_branch.train()
138
+ else:
139
+ self.model.model.text_branch.requires_grad_(False)
140
+ self.model.model.text_branch.eval()
141
+
142
+ finally:
143
+ logging.disable(previous_level)
144
+
145
+ del self.model.model.audio_branch
146
+
147
+ gc.collect()
148
+ torch.cuda.empty_cache()
149
+
150
+ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
151
+ prompt_tokens = self.model.tokenizer(prompts)
152
+ attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
153
+ prompt_features = self.model.model.text_branch(
154
+ input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
155
+ attention_mask=attention_mask,
156
+ output_hidden_states=True
157
+ )["hidden_states"][layer_ix]
158
+
159
+ return prompt_features, attention_mask
160
+
161
+ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
162
+ self.model.to(device)
163
+
164
+ if self.use_text_features:
165
+ if len(texts) == 1:
166
+ text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
167
+ text_features = text_features[:1, ...]
168
+ text_attention_mask = text_attention_mask[:1, ...]
169
+ else:
170
+ text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
171
+ return [self.proj_out(text_features), text_attention_mask]
172
+
173
+ # Fix for CLAP bug when only one text is passed
174
+ if len(texts) == 1:
175
+ text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
176
+ else:
177
+ text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
178
+
179
+ text_embedding = text_embedding.unsqueeze(1).to(device)
180
+
181
+ return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
182
+
183
+ class CLAPAudioConditioner(Conditioner):
184
+ def __init__(self,
185
+ output_dim: int,
186
+ clap_ckpt_path,
187
+ audio_model_type="HTSAT-base",
188
+ enable_fusion=True,
189
+ project_out: bool = False):
190
+ super().__init__(512, output_dim, project_out=project_out)
+
+ # No finetune option is exposed for the audio branch; default to False so the
+ # self.finetune checks below keep the CLAP audio branch frozen.
+ self.finetune = False
191
+
192
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
193
+
194
+ # Suppress logging from transformers
195
+ previous_level = logging.root.manager.disable
196
+ logging.disable(logging.ERROR)
197
+ with warnings.catch_warnings():
198
+ warnings.simplefilter("ignore")
199
+ try:
200
+ import laion_clap
201
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
202
+
203
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
204
+
205
+ if self.finetune:
206
+ self.model = model
207
+ else:
208
+ self.__dict__["model"] = model
209
+
210
+ state_dict = clap_load_state_dict(clap_ckpt_path)
211
+ self.model.model.load_state_dict(state_dict, strict=False)
212
+
213
+ if self.finetune:
214
+ self.model.model.audio_branch.requires_grad_(True)
215
+ self.model.model.audio_branch.train()
216
+ else:
217
+ self.model.model.audio_branch.requires_grad_(False)
218
+ self.model.model.audio_branch.eval()
219
+
220
+ finally:
221
+ logging.disable(previous_level)
222
+
223
+ del self.model.model.text_branch
224
+
225
+ gc.collect()
226
+ torch.cuda.empty_cache()
227
+
228
+ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
229
+
230
+ self.model.to(device)
231
+
232
+ if isinstance(audios, list) or isinstance(audios, tuple):
233
+ audios = torch.cat(audios, dim=0)
234
+
235
+ # Convert to mono
236
+ mono_audios = audios.mean(dim=1)
237
+
238
+ with torch.cuda.amp.autocast(enabled=False):
239
+ audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
240
+
241
+ audio_embedding = audio_embedding.unsqueeze(1).to(device)
242
+
243
+ return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
244
+
245
+
246
+ class CLIPConditioner(Conditioner):
247
+ CLIP_MODELS = ["clip-vit-base-patch32"]
248
+
249
+ def __init__(
250
+ self,
251
+ output_dim: int,
252
+ clip_model_name: str = "clip-vit-base-patch32",
253
+ video_fps: int = 5,
254
+ out_features: int = 128,
255
+ enable_grad: bool = False,
256
+ in_features: int = 5000,
257
+ project_out: bool = False,
258
+ ):
259
+ assert clip_model_name in self.CLIP_MODELS, f"Unknown clip model name: {clip_model_name}"
260
+ super().__init__(dim = 768, output_dim=output_dim, project_out=project_out)
261
+
262
+ sa_depth=4
263
+ num_heads=16
264
+ dim_head=64
265
+ hidden_scale=4
266
+ duration = 10
267
+
268
+ self.clip_model_name=clip_model_name
269
+
270
+ if self.clip_model_name=='clip-vit-base-patch32':
271
+ out_features = 128
272
+ temporal_dim=768
273
+
274
+ self.empty_visual_feat = nn.Parameter(torch.zeros(1, out_features, temporal_dim), requires_grad=True)
275
+ nn.init.constant_(self.empty_visual_feat, 0)
276
+
277
+ in_features = 50*video_fps*duration
278
+
279
+ self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained('openai/clip-vit-base-patch32')
280
+ self.proj = nn.Linear(in_features=in_features, out_features=out_features)
281
+
282
+ self.in_features = in_features
283
+ self.out_features = out_features
284
+
285
+ self.Temp_transformer = SA_Transformer(temporal_dim, sa_depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.)
286
+ self.Temp_pos_embedding = nn.Parameter(torch.randn(1, duration*video_fps, temporal_dim))
287
+
288
+ clip_mean = [0.48145466, 0.4578275, 0.40821073]
289
+ clip_std = [0.26862954, 0.26130258, 0.27577711]
290
+ self.preprocess_CLIP = transforms.Compose([
291
+ transforms.Normalize(mean=clip_mean, std=clip_std)
292
+ ])
293
+
294
+ def process_video_with_custom_preprocessing(self, video_tensor):
295
+ video_tensor = video_tensor / 255.0
296
+ video_tensor = self.preprocess_CLIP(video_tensor)
297
+ return video_tensor
298
+
299
+ def init_first_from_ckpt(self, path):
300
+ model = torch.load(path, map_location="cpu")
301
+ if "state_dict" in list(model.keys()):
302
+ model = model["state_dict"]
303
+ # Remove: module prefix
304
+ new_model = {}
305
+ for key in model.keys():
306
+ new_key = key.replace("module.","")
307
+ new_model[new_key] = model[key]
308
+ missing, unexpected = self.visual_encoder_model.load_state_dict(new_model, strict=False)
309
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
310
+ if len(missing) > 0:
311
+ print(f"Missing Keys: {missing}")
312
+ if len(unexpected) > 0:
313
+ print(f"Unexpected Keys: {unexpected}")
314
+
315
+ def forward(self, Video_tensors: tp.List[torch.Tensor], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
316
+ visual_encoder_model = self.visual_encoder_model.eval().to(device)
317
+ proj = self.proj.to(device)
318
+
319
+ original_videos = torch.cat(Video_tensors, dim=0).to(device)
320
+ batch_size, time_length, _, _, _ = original_videos.size()
321
+ is_zero = torch.all(original_videos == 0, dim=(1,2,3,4))
322
+ Video_tensors = original_videos
323
+ Video_tensors = einops.rearrange(Video_tensors, 'b t c h w -> (b t) c h w')
324
+
325
+ video_cond_pixel_values = self.process_video_with_custom_preprocessing(video_tensor=Video_tensors.to(device)).to(device)
326
+ if self.clip_model_name=='clip-vit-base-patch32':
327
+ with torch.no_grad():
328
+ outputs = visual_encoder_model(pixel_values=video_cond_pixel_values)
329
+ video_hidden = outputs.last_hidden_state
330
+
331
+ video_hidden = einops.rearrange(video_hidden, '(b t) q h -> (b q) t h',b=batch_size,t=time_length)
332
+ video_hidden += self.Temp_pos_embedding
333
+ video_hidden = self.Temp_transformer(video_hidden)
334
+ video_hidden = einops.rearrange(video_hidden, '(b q) t h -> b (t q) h',b=batch_size,t=time_length)
335
+
336
+ video_hidden = proj(video_hidden.view(-1, self.in_features))
337
+ video_hidden = video_hidden.view(batch_size, self.out_features, -1)
338
+
339
+ empty_visual_feat = self.empty_visual_feat.expand(batch_size, -1, -1)
340
+ is_zero_expanded = is_zero.view(batch_size, 1, 1)
341
+ video_hidden = torch.where(is_zero_expanded, empty_visual_feat, video_hidden)
342
+
343
+ return video_hidden, torch.ones(video_hidden.shape[0], 1).to(device)
344
+
345
+
346
+
347
+ class T5Conditioner(Conditioner):
348
+
349
+ T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
350
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
351
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
352
+
353
+ T5_MODEL_DIMS = {
354
+ "t5-small": 512,
355
+ "t5-base": 768,
356
+ "t5-large": 1024,
357
+ "t5-3b": 1024,
358
+ "t5-11b": 1024,
359
+ "t5-xl": 2048,
360
+ "t5-xxl": 4096,
361
+ "google/flan-t5-small": 512,
362
+ "google/flan-t5-base": 768,
363
+ "google/flan-t5-large": 1024,
364
+ "google/flan-t5-3b": 1024,
365
+ "google/flan-t5-11b": 1024,
366
+ "google/flan-t5-xl": 2048,
367
+ "google/flan-t5-xxl": 4096,
368
+ }
369
+
370
+ def __init__(
371
+ self,
372
+ output_dim: int,
373
+ t5_model_name: str = "t5-base",
374
+ max_length: int = 128,
375
+ enable_grad: bool = False,
376
+ project_out: bool = False,
377
+ ):
378
+ assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
379
+ super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
380
+
381
+ from transformers import T5EncoderModel, AutoTokenizer
382
+
383
+ self.max_length = max_length
384
+ self.enable_grad = enable_grad
385
+ # Suppress logging from transformers
386
+ previous_level = logging.root.manager.disable
387
+ logging.disable(logging.ERROR)
388
+ with warnings.catch_warnings():
389
+ warnings.simplefilter("ignore")
390
+ try:
391
+ self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
392
+ model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
393
+ finally:
394
+ logging.disable(previous_level)
395
+
396
+ if self.enable_grad:
397
+ self.model = model
398
+ else:
399
+ self.__dict__["model"] = model
400
+
401
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
402
+
403
+ self.model.to(device)
404
+ self.proj_out.to(device)
405
+
406
+ encoded = self.tokenizer(
407
+ texts,
408
+ truncation=True,
409
+ max_length=self.max_length,
410
+ padding="max_length",
411
+ return_tensors="pt",
412
+ )
413
+
414
+ input_ids = encoded["input_ids"].to(device)
415
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
416
+
417
+ self.model.eval()
418
+
419
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
420
+ embeddings = self.model(
421
+ input_ids=input_ids, attention_mask=attention_mask
422
+ )["last_hidden_state"]
423
+
424
+ embeddings = self.proj_out(embeddings.float())
425
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
426
+
427
+ return embeddings, attention_mask
428
+
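A rough usage sketch for the class above; the prompts and the output dimension are illustrative, and a CUDA device is assumed because the encoder is cast to float16:

import torch

conditioner = T5Conditioner(output_dim=768, t5_model_name="t5-base", max_length=128)
embeddings, mask = conditioner(["a dog barking", "rain on a tin roof"], device="cuda")
# embeddings: (2, 128, 768) padded token embeddings, mask: (2, 128) boolean attention mask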
429
+ class PhonemeConditioner(Conditioner):
430
+ """
431
+ A conditioner that turns text into phonemes and embeds them using a lookup table
432
+ Only works for English text
433
+
434
+ Args:
435
+ output_dim: the dimension of the output embeddings
436
+ max_length: the maximum number of phonemes to embed
437
+ project_out: whether to add another linear projection to the output embeddings
438
+ """
439
+
440
+ def __init__(
441
+ self,
442
+ output_dim: int,
443
+ max_length: int = 1024,
444
+ project_out: bool = False,
445
+ ):
446
+ super().__init__(output_dim, output_dim, project_out=project_out)
447
+
448
+ from g2p_en import G2p
449
+ self.max_length = max_length
450
+ self.g2p = G2p()
451
+ # Reserving 0 for padding, 1 for ignored
452
+ self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
453
+
454
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
455
+
456
+ self.phoneme_embedder.to(device)
457
+ self.proj_out.to(device)
458
+
459
+ batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
460
+ phoneme_ignore = [" ", *string.punctuation]
461
+ # Replace ignored phonemes (spaces and punctuation) with a placeholder token
462
+ batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
463
+
464
+ # Convert to ids
465
+ phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
466
+
467
+ #Pad to match longest and make a mask tensor for the padding
468
+ longest = max([len(ids) for ids in phoneme_ids])
469
+ phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
470
+ phoneme_ids = torch.tensor(phoneme_ids).to(device)
471
+
472
+ # Convert to embeddings
473
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids)
474
+ phoneme_embeds = self.proj_out(phoneme_embeds)
475
+
476
+ return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
477
+
478
+
479
+
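A small sketch of the phoneme path above (it needs the g2p_en package); the output dimension is an arbitrary example:

conditioner = PhonemeConditioner(output_dim=512)
embeds, mask = conditioner(["hello world"], device="cpu")
# embeds: (1, n_phonemes, 512); id 0 is reserved for padding, id 1 for ignored phonemes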
480
+ class TokenizerLUTConditioner(Conditioner):
481
+ """
482
+ A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
483
+
484
+ Args:
485
+ tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
486
+ output_dim: the dimension of the output embeddings
487
+ max_length: the maximum length of the text to embed
488
+ project_out: whether to add another linear projection to the output embeddings
489
+ """
490
+
491
+ def __init__(
492
+ self,
493
+ tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
494
+ output_dim: int,
495
+ max_length: int = 1024,
496
+ project_out: bool = False,
497
+ ):
498
+ super().__init__(output_dim, output_dim, project_out=project_out)
499
+
500
+ from transformers import AutoTokenizer
501
+
502
+ # Suppress logging from transformers
503
+ previous_level = logging.root.manager.disable
504
+ logging.disable(logging.ERROR)
505
+ with warnings.catch_warnings():
506
+ warnings.simplefilter("ignore")
507
+ try:
508
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
509
+ finally:
510
+ logging.disable(previous_level)
511
+
512
+ self.max_length = max_length
513
+
514
+ self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
515
+
516
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
517
+ self.proj_out.to(device)
518
+
519
+ encoded = self.tokenizer(
520
+ texts,
521
+ truncation=True,
522
+ max_length=self.max_length,
523
+ padding="max_length",
524
+ return_tensors="pt",
525
+ )
526
+
527
+ input_ids = encoded["input_ids"].to(device)
528
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
529
+
530
+ embeddings = self.token_embedder(input_ids)
531
+
532
+ embeddings = self.proj_out(embeddings)
533
+
534
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
535
+
536
+ return embeddings, attention_mask
537
+
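A comparable sketch for the lookup-table variant above; the tokenizer name and sizes are examples only:

conditioner = TokenizerLUTConditioner(tokenizer_name="t5-base", output_dim=768, max_length=64)
embeds, mask = conditioner(["wind chimes in a storm"], device="cpu")
# embeds: (1, 64, 768) learned token embeddings, zeroed where the attention mask is False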
538
+ class PretransformConditioner(Conditioner):
539
+ """
540
+ A conditioner that uses a pretransform's encoder for conditioning
541
+
542
+ Args:
543
+ pretransform: an instantiated pretransform to use for conditioning
544
+ output_dim: the dimension of the output embeddings
545
+ """
546
+ def __init__(self, pretransform: Pretransform, output_dim: int):
547
+ super().__init__(pretransform.encoded_channels, output_dim)
548
+
549
+ self.pretransform = pretransform
550
+
551
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
552
+
553
+ self.pretransform.to(device)
554
+ self.proj_out.to(device)
555
+
556
+ if isinstance(audio, list) or isinstance(audio, tuple):
557
+ audio = torch.cat(audio, dim=0)
558
+
559
+ # Convert audio to pretransform input channels
560
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
561
+
562
+ latents = self.pretransform.encode(audio)
563
+ latents = self.proj_out(latents)
564
+
565
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
566
+
567
+
568
+ class AudioAutoencoderConditioner(Conditioner):
569
+ """
570
+ A conditioner that uses a pretransform's encoder for audio conditioning, substituting a learned empty-audio embedding for all-zero inputs
571
+
572
+ Args:
573
+ pretransform: an instantiated pretransform to use for conditioning
574
+ output_dim: the dimension of the output embeddings
575
+ """
576
+ def __init__(self, pretransform: Pretransform, output_dim: int):
577
+ super().__init__(pretransform.encoded_channels, output_dim)
578
+
579
+ self.pretransform = pretransform
580
+ self.empty_audio_feat = nn.Parameter(torch.zeros(1, 215, self.proj_out.out_features), requires_grad=True)
581
+ nn.init.constant_(self.empty_audio_feat, 0)
582
+
583
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
584
+
585
+ self.pretransform.to(device)
586
+ self.proj_out.to(device)
587
+
588
+ if isinstance(audio, list) or isinstance(audio, tuple):
589
+ original_audios = torch.cat(audio, dim=0).to(device)
590
+ is_zero = torch.all(original_audios == 0, dim=(1,2))
591
+ audio = original_audios
592
+
593
+ # Convert audio to pretransform input channels
594
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
595
+
596
+ latents = self.pretransform.encode(audio)
597
+ latents = latents.permute(0, 2, 1)
598
+ latents = self.proj_out(latents)
599
+
600
+ empty_audio_feat = self.empty_audio_feat.expand(latents.shape[0], -1, -1)
601
+ is_zero_expanded = is_zero.view(latents.shape[0], 1, 1)
602
+ latents = torch.where(is_zero_expanded, empty_audio_feat, latents)
603
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
604
+
605
+
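Both the video conditioner earlier in this file and the audio autoencoder conditioner above handle missing conditioning the same way: inputs that are entirely zero are detected per item and their features are swapped for a learned "empty" embedding via torch.where. A minimal sketch of that masking, with illustrative shapes:

import torch

latents = torch.randn(4, 215, 768)                    # per-item conditioning features
empty = torch.zeros(1, 215, 768).expand(4, -1, -1)    # learned empty embedding, broadcast over the batch
is_zero = torch.tensor([False, True, False, True])    # items whose raw input was all zeros
out = torch.where(is_zero.view(4, 1, 1), empty, latents)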
606
+ class MultiConditioner(nn.Module):
607
+ """
608
+ A module that applies multiple conditioners to an input dictionary based on the keys
609
+
610
+ Args:
611
+ conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
612
+ default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
613
+ """
614
+ def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
615
+ super().__init__()
616
+
617
+ self.conditioners = nn.ModuleDict(conditioners)
618
+ self.default_keys = default_keys
619
+
620
+ def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
621
+ output = {}
622
+
623
+ for key, conditioner in self.conditioners.items():
624
+ condition_key = key
625
+
626
+ conditioner_inputs = []
627
+
628
+ for x in batch_metadata:
629
+
630
+ if condition_key not in x:
631
+ if condition_key in self.default_keys:
632
+ condition_key = self.default_keys[condition_key]
633
+ else:
634
+ raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
635
+
636
+ if isinstance(x[condition_key], (list, tuple)) and len(x[condition_key]) == 1:
637
+ conditioner_input = x[condition_key][0]
638
+
639
+ else:
640
+ conditioner_input = x[condition_key]
641
+
642
+ conditioner_inputs.append(conditioner_input)
643
+
644
+ output[key] = conditioner(conditioner_inputs, device)
645
+
646
+ return output
647
+
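A hypothetical wiring example: each key of the conditioner dict is looked up in the per-item metadata, with default_keys providing a fallback key. The conditioner choices here are illustrative:

conditioner = MultiConditioner(
    {
        "prompt": TokenizerLUTConditioner(tokenizer_name="t5-base", output_dim=768),
        "phonemes": PhonemeConditioner(output_dim=768),
    },
    default_keys={"phonemes": "prompt"},  # reuse the "prompt" text when "phonemes" is absent
)
batch_metadata = [{"prompt": "footsteps on gravel"}, {"prompt": "a distant thunderstorm"}]
cond = conditioner(batch_metadata, device="cpu")
# cond["prompt"] and cond["phonemes"] are (embeddings, mask) pairs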
648
+ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
649
+ """
650
+ Create a MultiConditioner from a conditioning config dictionary
651
+
652
+ Args:
653
+ config: the conditioning config dictionary
654
+ device: the device to put the conditioners on
655
+ """
656
+ conditioners = {}
657
+ cond_dim = config["cond_dim"]
658
+
659
+ default_keys = config.get("default_keys", {})
660
+
661
+ for conditioner_info in config["configs"]:
662
+ id = conditioner_info["id"]
663
+
664
+ conditioner_type = conditioner_info["type"]
665
+
666
+ conditioner_config = {"output_dim": cond_dim}
667
+
668
+ conditioner_config.update(conditioner_info["config"])
669
+
670
+ if conditioner_type == "t5":
671
+ conditioners[id] = T5Conditioner(**conditioner_config)
672
+ elif conditioner_type == "clip":
673
+ conditioners[id] = CLIPConditioner(**conditioner_config)
674
+ elif conditioner_type == "clap_text":
675
+ conditioners[id] = CLAPTextConditioner(**conditioner_config)
676
+ elif conditioner_type == "clap_audio":
677
+ conditioners[id] = CLAPAudioConditioner(**conditioner_config)
678
+ elif conditioner_type == "int":
679
+ conditioners[id] = IntConditioner(**conditioner_config)
680
+ elif conditioner_type == "number":
681
+ conditioners[id] = NumberConditioner(**conditioner_config)
682
+ elif conditioner_type == "phoneme":
683
+ conditioners[id] = PhonemeConditioner(**conditioner_config)
684
+ elif conditioner_type == "lut":
685
+ conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
686
+ elif conditioner_type == "pretransform":
687
+ sample_rate = conditioner_config.pop("sample_rate", None)
688
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
689
+
690
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
691
+
692
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
693
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
694
+
695
+ conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
696
+
697
+ elif conditioner_type == "audio_autoencoder":
698
+ sample_rate = conditioner_config.pop("sample_rate", None)
699
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
700
+
701
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
702
+
703
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
704
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
705
+
706
+ conditioners[id] = AudioAutoencoderConditioner(pretransform, **conditioner_config)
707
+ else:
708
+ raise ValueError(f"Unknown conditioner type: {conditioner_type}")
709
+
710
+ return MultiConditioner(conditioners, default_keys=default_keys)
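The factory above expects a dict with cond_dim, optional default_keys, and a list of per-conditioner configs. A sketch of that shape, with illustrative ids and values rather than a shipped model config:

conditioning_config = {
    "cond_dim": 768,
    "default_keys": {},
    "configs": [
        {"id": "prompt", "type": "t5", "config": {"t5_model_name": "t5-base", "max_length": 128}},
        {"id": "prompt_phonemes", "type": "phoneme", "config": {"max_length": 1024}},
    ],
}
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)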
stable_audio_tools/models/diffusion.py ADDED
@@ -0,0 +1,704 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from functools import partial
5
+ import numpy as np
6
+ import typing as tp
7
+
8
+ from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
9
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
10
+ from .dit import DiffusionTransformer
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from ..inference.generation import generate_diffusion_cond
14
+
15
+ from .adp import UNetCFG1d, UNet1d
16
+
17
+ from time import time
18
+
19
+ class Profiler:
20
+
21
+ def __init__(self):
22
+ self.ticks = [[time(), None]]
23
+
24
+ def tick(self, msg):
25
+ self.ticks.append([time(), msg])
26
+
27
+ def __repr__(self):
28
+ rep = 80 * "=" + "\n"
29
+ for i in range(1, len(self.ticks)):
30
+ msg = self.ticks[i][1]
31
+ elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
32
+ rep += msg + f": {elapsed*1000:.2f}ms\n"
33
+ rep += 80 * "=" + "\n\n\n"
34
+ return rep
35
+
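Quick usage sketch for the Profiler helper:

p = Profiler()
# ... run a model stage ...
p.tick("denoiser forward")
# ... run another stage ...
p.tick("decode")
print(p)  # prints per-stage wall-clock times in milliseconds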
36
+ class DiffusionModel(nn.Module):
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+
40
+ def forward(self, x, t, **kwargs):
41
+ raise NotImplementedError()
42
+
43
+ class DiffusionModelWrapper(nn.Module):
44
+ def __init__(
45
+ self,
46
+ model: DiffusionModel,
47
+ io_channels,
48
+ sample_size,
49
+ sample_rate,
50
+ min_input_length,
51
+ pretransform: tp.Optional[Pretransform] = None,
52
+ ):
53
+ super().__init__()
54
+ self.io_channels = io_channels
55
+ self.sample_size = sample_size
56
+ self.sample_rate = sample_rate
57
+ self.min_input_length = min_input_length
58
+
59
+ self.model = model
60
+
61
+ if pretransform is not None:
62
+ self.pretransform = pretransform
63
+ else:
64
+ self.pretransform = None
65
+
66
+ def forward(self, x, t, **kwargs):
67
+ return self.model(x, t, **kwargs)
68
+
69
+ class ConditionedDiffusionModel(nn.Module):
70
+ def __init__(self,
71
+ *args,
72
+ supports_cross_attention: bool = False,
73
+ supports_input_concat: bool = False,
74
+ supports_global_cond: bool = False,
75
+ supports_prepend_cond: bool = False,
76
+ **kwargs):
77
+ super().__init__(*args, **kwargs)
78
+ self.supports_cross_attention = supports_cross_attention
79
+ self.supports_input_concat = supports_input_concat
80
+ self.supports_global_cond = supports_global_cond
81
+ self.supports_prepend_cond = supports_prepend_cond
82
+
83
+ def forward(self,
84
+ x: torch.Tensor,
85
+ t: torch.Tensor,
86
+ cross_attn_cond: torch.Tensor = None,
87
+ cross_attn_mask: torch.Tensor = None,
88
+ input_concat_cond: torch.Tensor = None,
89
+ global_embed: torch.Tensor = None,
90
+ prepend_cond: torch.Tensor = None,
91
+ prepend_cond_mask: torch.Tensor = None,
92
+ cfg_scale: float = 1.0,
93
+ cfg_dropout_prob: float = 0.0,
94
+ batch_cfg: bool = False,
95
+ rescale_cfg: bool = False,
96
+ **kwargs):
97
+ raise NotImplementedError()
98
+
99
+ class ConditionedDiffusionModelWrapper(nn.Module):
100
+ """
101
+ A diffusion model that takes in conditioning
102
+ """
103
+ def __init__(
104
+ self,
105
+ model: ConditionedDiffusionModel,
106
+ conditioner: MultiConditioner,
107
+ io_channels,
108
+ sample_rate,
109
+ min_input_length: int,
110
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
111
+ pretransform: tp.Optional[Pretransform] = None,
112
+ cross_attn_cond_ids: tp.List[str] = [],
113
+ global_cond_ids: tp.List[str] = [],
114
+ input_concat_ids: tp.List[str] = [],
115
+ prepend_cond_ids: tp.List[str] = [],
116
+ ):
117
+ super().__init__()
118
+
119
+ self.model = model
120
+ self.conditioner = conditioner
121
+ self.io_channels = io_channels
122
+ self.sample_rate = sample_rate
123
+ self.diffusion_objective = diffusion_objective
124
+ self.pretransform = pretransform
125
+ self.cross_attn_cond_ids = cross_attn_cond_ids # ['prompt', 'seconds_start', 'seconds_total']
126
+ self.global_cond_ids = global_cond_ids # ['seconds_start', 'seconds_total']
127
+ self.input_concat_ids = input_concat_ids
128
+ self.prepend_cond_ids = prepend_cond_ids
129
+ self.min_input_length = min_input_length
130
+
131
+ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[torch.Tensor, tp.Any], negative=False):
132
+ cross_attention_input = None
133
+ cross_attention_masks = None
134
+ global_cond = None
135
+ input_concat_cond = None
136
+ prepend_cond = None
137
+ prepend_cond_mask = None
138
+
139
+ if len(self.cross_attn_cond_ids) > 0:
140
+ # Concatenate all cross-attention inputs over the sequence dimension
141
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
142
+ cross_attention_input = []
143
+ cross_attention_masks = []
144
+
145
+ for key in self.cross_attn_cond_ids:
146
+ cross_attn_in, cross_attn_mask = conditioning_tensors[key]
147
+
148
+ # Add sequence dimension if it's not there
149
+ if len(cross_attn_in.shape) == 2:
150
+ cross_attn_in = cross_attn_in.unsqueeze(1)
151
+ cross_attn_mask = cross_attn_mask.unsqueeze(1)
152
+
153
+ cross_attention_input.append(cross_attn_in)
154
+ cross_attention_masks.append(cross_attn_mask)
155
+
156
+ cross_attention_input = torch.cat(cross_attention_input, dim=1) # [1, 130, 768] (text feature:128)
157
+ cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
158
+
159
+ if len(self.global_cond_ids) > 0:
160
+ # Concatenate all global conditioning inputs over the channel dimension
161
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
162
+ global_conds = []
163
+ for key in self.global_cond_ids:
164
+
165
+ global_cond_input = conditioning_tensors[key][0]
166
+
167
+ global_conds.append(global_cond_input)
168
+
169
+ # Concatenate over the channel dimension
170
+ global_cond = torch.cat(global_conds, dim=-1)
171
+
172
+ if len(global_cond.shape) == 3:
173
+ global_cond = global_cond.squeeze(1)
174
+
175
+ if len(self.input_concat_ids) > 0: # False
176
+ # Concatenate all input concat conditioning inputs over the channel dimension
177
+ # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
178
+ input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
179
+
180
+ if len(self.prepend_cond_ids) > 0: # False
181
+ # Concatenate all prepend conditioning inputs over the sequence dimension
182
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
183
+ prepend_conds = []
184
+ prepend_cond_masks = []
185
+
186
+ for key in self.prepend_cond_ids:
187
+ prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
188
+ prepend_conds.append(prepend_cond_input)
189
+ prepend_cond_masks.append(prepend_cond_mask)
190
+
191
+ prepend_cond = torch.cat(prepend_conds, dim=1)
192
+ prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
193
+
194
+ if negative: # False
195
+ return {
196
+ "negative_cross_attn_cond": cross_attention_input,
197
+ "negative_cross_attn_mask": cross_attention_masks,
198
+ "negative_global_cond": global_cond,
199
+ "negative_input_concat_cond": input_concat_cond
200
+ }
201
+ else:
202
+ return {
203
+ "cross_attn_cond": cross_attention_input,
204
+ "cross_attn_mask": cross_attention_masks,
205
+ "global_cond": global_cond,
206
+ "input_concat_cond": input_concat_cond,
207
+ "prepend_cond": prepend_cond,
208
+ "prepend_cond_mask": prepend_cond_mask
209
+ }
210
+
211
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
212
+ return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
213
+
214
+ def generate(self, *args, **kwargs):
215
+ return generate_diffusion_cond(self, *args, **kwargs)
216
+
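The wrapper's get_conditioning_inputs routes each conditioner output into the cross-attention, global, input-concat or prepend slot according to the id lists passed at construction. A sketch of the conditioning dict it consumes, with illustrative shapes:

import torch

conditioning_tensors = {
    "prompt":        (torch.randn(1, 128, 768), torch.ones(1, 128, dtype=torch.bool)),  # (tokens, mask)
    "seconds_total": (torch.randn(1, 768),      torch.ones(1, 1)),                      # single global embedding
}
# With cross_attn_cond_ids=["prompt", "seconds_total"] and global_cond_ids=["seconds_total"],
# the prompt tokens and the seconds embedding are concatenated along the sequence axis for
# cross-attention, and the seconds embedding is also passed on its own as the global condition.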
217
+ class UNetCFG1DWrapper(ConditionedDiffusionModel):
218
+ def __init__(
219
+ self,
220
+ *args,
221
+ **kwargs
222
+ ):
223
+ super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
224
+
225
+ self.model = UNetCFG1d(*args, **kwargs)
226
+
227
+ with torch.no_grad():
228
+ for param in self.model.parameters():
229
+ param *= 0.5
230
+
231
+ def forward(self,
232
+ x,
233
+ t,
234
+ cross_attn_cond=None,
235
+ cross_attn_mask=None,
236
+ input_concat_cond=None,
237
+ global_cond=None,
238
+ cfg_scale=1.0,
239
+ cfg_dropout_prob: float = 0.0,
240
+ batch_cfg: bool = False,
241
+ rescale_cfg: bool = False,
242
+ negative_cross_attn_cond=None,
243
+ negative_cross_attn_mask=None,
244
+ negative_global_cond=None,
245
+ negative_input_concat_cond=None,
246
+ prepend_cond=None,
247
+ prepend_cond_mask=None,
248
+ **kwargs):
249
+ p = Profiler()
250
+
251
+ p.tick("start")
252
+
253
+ channels_list = None
254
+ if input_concat_cond is not None:
255
+ channels_list = [input_concat_cond]
256
+
257
+ outputs = self.model(
258
+ x,
259
+ t,
260
+ embedding=cross_attn_cond,
261
+ embedding_mask=cross_attn_mask,
262
+ features=global_cond,
263
+ channels_list=channels_list,
264
+ embedding_scale=cfg_scale,
265
+ embedding_mask_proba=cfg_dropout_prob,
266
+ batch_cfg=batch_cfg,
267
+ rescale_cfg=rescale_cfg,
268
+ negative_embedding=negative_cross_attn_cond,
269
+ negative_embedding_mask=negative_cross_attn_mask,
270
+ **kwargs)
271
+
272
+ p.tick("UNetCFG1D forward")
273
+
274
+ #print(f"Profiler: {p}")
275
+ return outputs
276
+
277
+ class UNet1DCondWrapper(ConditionedDiffusionModel):
278
+ def __init__(
279
+ self,
280
+ *args,
281
+ **kwargs
282
+ ):
283
+ super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
284
+
285
+ self.model = UNet1d(*args, **kwargs)
286
+
287
+ with torch.no_grad():
288
+ for param in self.model.parameters():
289
+ param *= 0.5
290
+
291
+ def forward(self,
292
+ x,
293
+ t,
294
+ input_concat_cond=None,
295
+ global_cond=None,
296
+ cross_attn_cond=None,
297
+ cross_attn_mask=None,
298
+ prepend_cond=None,
299
+ prepend_cond_mask=None,
300
+ cfg_scale=1.0,
301
+ cfg_dropout_prob: float = 0.0,
302
+ batch_cfg: bool = False,
303
+ rescale_cfg: bool = False,
304
+ negative_cross_attn_cond=None,
305
+ negative_cross_attn_mask=None,
306
+ negative_global_cond=None,
307
+ negative_input_concat_cond=None,
308
+ **kwargs):
309
+
310
+ channels_list = None
311
+ if input_concat_cond is not None:
312
+
313
+ # Interpolate input_concat_cond to the same length as x
314
+ if input_concat_cond.shape[2] != x.shape[2]:
315
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
316
+
317
+ channels_list = [input_concat_cond]
318
+
319
+ outputs = self.model(
320
+ x,
321
+ t,
322
+ features=global_cond,
323
+ channels_list=channels_list,
324
+ **kwargs)
325
+
326
+ return outputs
327
+
328
+ class UNet1DUncondWrapper(DiffusionModel):
329
+ def __init__(
330
+ self,
331
+ in_channels,
332
+ *args,
333
+ **kwargs
334
+ ):
335
+ super().__init__()
336
+
337
+ self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
338
+
339
+ self.io_channels = in_channels
340
+
341
+ with torch.no_grad():
342
+ for param in self.model.parameters():
343
+ param *= 0.5
344
+
345
+ def forward(self, x, t, **kwargs):
346
+ return self.model(x, t, **kwargs)
347
+
348
+ class DAU1DCondWrapper(ConditionedDiffusionModel):
349
+ def __init__(
350
+ self,
351
+ *args,
352
+ **kwargs
353
+ ):
354
+ super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
355
+
356
+ self.model = DiffusionAttnUnet1D(*args, **kwargs)
357
+
358
+ with torch.no_grad():
359
+ for param in self.model.parameters():
360
+ param *= 0.5
361
+
362
+ def forward(self,
363
+ x,
364
+ t,
365
+ input_concat_cond=None,
366
+ cross_attn_cond=None,
367
+ cross_attn_mask=None,
368
+ global_cond=None,
369
+ cfg_scale=1.0,
370
+ cfg_dropout_prob: float = 0.0,
371
+ batch_cfg: bool = False,
372
+ rescale_cfg: bool = False,
373
+ negative_cross_attn_cond=None,
374
+ negative_cross_attn_mask=None,
375
+ negative_global_cond=None,
376
+ negative_input_concat_cond=None,
377
+ prepend_cond=None,
378
+ **kwargs):
379
+
380
+ return self.model(x, t, cond = input_concat_cond)
381
+
382
+ class DiffusionAttnUnet1D(nn.Module):
383
+ def __init__(
384
+ self,
385
+ io_channels = 2,
386
+ depth=14,
387
+ n_attn_layers = 6,
388
+ channels = [128, 128, 256, 256] + [512] * 10,
389
+ cond_dim = 0,
390
+ cond_noise_aug = False,
391
+ kernel_size = 5,
392
+ learned_resample = False,
393
+ strides = [2] * 13,
394
+ conv_bias = True,
395
+ use_snake = False
396
+ ):
397
+ super().__init__()
398
+
399
+ self.cond_noise_aug = cond_noise_aug
400
+
401
+ self.io_channels = io_channels
402
+
403
+ if self.cond_noise_aug:
404
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
405
+
406
+ self.timestep_embed = FourierFeatures(1, 16)
407
+
408
+ attn_layer = depth - n_attn_layers
409
+
410
+ strides = [1] + strides
411
+
412
+ block = nn.Identity()
413
+
414
+ conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
415
+
416
+ for i in range(depth, 0, -1):
417
+ c = channels[i - 1]
418
+ stride = strides[i-1]
419
+ if stride > 2 and not learned_resample:
420
+ raise ValueError("Must have stride 2 without learned resampling")
421
+
422
+ if i > 1:
423
+ c_prev = channels[i - 2]
424
+ add_attn = i >= attn_layer and n_attn_layers > 0
425
+ block = SkipBlock(
426
+ Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
427
+ conv_block(c_prev, c, c),
428
+ SelfAttention1d(
429
+ c, c // 32) if add_attn else nn.Identity(),
430
+ conv_block(c, c, c),
431
+ SelfAttention1d(
432
+ c, c // 32) if add_attn else nn.Identity(),
433
+ conv_block(c, c, c),
434
+ SelfAttention1d(
435
+ c, c // 32) if add_attn else nn.Identity(),
436
+ block,
437
+ conv_block(c * 2 if i != depth else c, c, c),
438
+ SelfAttention1d(
439
+ c, c // 32) if add_attn else nn.Identity(),
440
+ conv_block(c, c, c),
441
+ SelfAttention1d(
442
+ c, c // 32) if add_attn else nn.Identity(),
443
+ conv_block(c, c, c_prev),
444
+ SelfAttention1d(c_prev, c_prev //
445
+ 32) if add_attn else nn.Identity(),
446
+ Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
447
+ )
448
+ else:
449
+ cond_embed_dim = 16 if not self.cond_noise_aug else 32
450
+ block = nn.Sequential(
451
+ conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
452
+ conv_block(c, c, c),
453
+ conv_block(c, c, c),
454
+ block,
455
+ conv_block(c * 2, c, c),
456
+ conv_block(c, c, c),
457
+ conv_block(c, c, io_channels, is_last=True),
458
+ )
459
+ self.net = block
460
+
461
+ with torch.no_grad():
462
+ for param in self.net.parameters():
463
+ param *= 0.5
464
+
465
+ def forward(self, x, t, cond=None, cond_aug_scale=None):
466
+
467
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
468
+
469
+ inputs = [x, timestep_embed]
470
+
471
+ if cond is not None:
472
+ if cond.shape[2] != x.shape[2]:
473
+ cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
474
+
475
+ if self.cond_noise_aug:
476
+ # Get a random number between 0 and 1, uniformly sampled
477
+ if cond_aug_scale is None:
478
+ aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
479
+ else:
480
+ aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
481
+
482
+ # Add noise to the conditioning signal
483
+ cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
484
+
485
+ # Get embedding for noise cond level, reusing timestamp_embed
486
+ aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
487
+
488
+ inputs.append(aug_level_embed)
489
+
490
+ inputs.append(cond)
491
+
492
+ outputs = self.net(torch.cat(inputs, dim=1))
493
+
494
+ return outputs
495
+
496
+ class DiTWrapper(ConditionedDiffusionModel):
497
+ def __init__(
498
+ self,
499
+ *args,
500
+ **kwargs
501
+ ):
502
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
503
+
504
+ self.model = DiffusionTransformer(*args, **kwargs)
505
+
506
+ with torch.no_grad():
507
+ for param in self.model.parameters():
508
+ param *= 0.5
509
+
510
+ def forward(self,
511
+ x,
512
+ t,
513
+ cross_attn_cond=None,
514
+ cross_attn_mask=None,
515
+ negative_cross_attn_cond=None,
516
+ negative_cross_attn_mask=None,
517
+ input_concat_cond=None,
518
+ negative_input_concat_cond=None,
519
+ global_cond=None,
520
+ negative_global_cond=None,
521
+ prepend_cond=None,
522
+ prepend_cond_mask=None,
523
+ cfg_scale=1.0,
524
+ cfg_dropout_prob: float = 0.0,
525
+ batch_cfg: bool = True,
526
+ rescale_cfg: bool = False,
527
+ scale_phi: float = 0.0,
528
+ **kwargs):
529
+
530
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
531
+ #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
532
+
533
+ return self.model(
534
+ x,
535
+ t,
536
+ cross_attn_cond=cross_attn_cond,
537
+ cross_attn_cond_mask=cross_attn_mask,
538
+ negative_cross_attn_cond=negative_cross_attn_cond,
539
+ negative_cross_attn_mask=negative_cross_attn_mask,
540
+ input_concat_cond=input_concat_cond,
541
+ prepend_cond=prepend_cond,
542
+ prepend_cond_mask=prepend_cond_mask,
543
+ cfg_scale=cfg_scale,
544
+ cfg_dropout_prob=cfg_dropout_prob,
545
+ scale_phi=scale_phi,
546
+ global_embed=global_cond,
547
+ **kwargs)
548
+
549
+ class DiTUncondWrapper(DiffusionModel):
550
+ def __init__(
551
+ self,
552
+ in_channels,
553
+ *args,
554
+ **kwargs
555
+ ):
556
+ super().__init__()
557
+
558
+ self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)
559
+
560
+ self.io_channels = in_channels
561
+
562
+ with torch.no_grad():
563
+ for param in self.model.parameters():
564
+ param *= 0.5
565
+
566
+ def forward(self, x, t, **kwargs):
567
+ return self.model(x, t, **kwargs)
568
+
569
+ def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
570
+ diffusion_uncond_config = config["model"]
571
+
572
+ model_type = diffusion_uncond_config.get('type', None)
573
+
574
+ diffusion_config = diffusion_uncond_config.get('config', {})
575
+
576
+ assert model_type is not None, "Must specify model type in config"
577
+
578
+ pretransform = diffusion_uncond_config.get("pretransform", None)
579
+
580
+ sample_size = config.get("sample_size", None)
581
+ assert sample_size is not None, "Must specify sample size in config"
582
+
583
+ sample_rate = config.get("sample_rate", None)
584
+ assert sample_rate is not None, "Must specify sample rate in config"
585
+
586
+ if pretransform is not None:
587
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
588
+ min_input_length = pretransform.downsampling_ratio
589
+ else:
590
+ min_input_length = 1
591
+
592
+ if model_type == 'DAU1d':
593
+
594
+ model = DiffusionAttnUnet1D(
595
+ **diffusion_config
596
+ )
597
+
598
+ elif model_type == "adp_uncond_1d":
599
+
600
+ model = UNet1DUncondWrapper(
601
+ **diffusion_config
602
+ )
603
+
604
+ elif model_type == "dit":
605
+ model = DiTUncondWrapper(
606
+ **diffusion_config
607
+ )
608
+
609
+ else:
610
+ raise NotImplementedError(f'Unknown model type: {model_type}')
611
+
612
+ return DiffusionModelWrapper(model,
613
+ io_channels=model.io_channels,
614
+ sample_size=sample_size,
615
+ sample_rate=sample_rate,
616
+ pretransform=pretransform,
617
+ min_input_length=min_input_length)
618
+
619
+ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
620
+
621
+ model_config = config["model"]
622
+
623
+ model_type = config["model_type"]
624
+
625
+ diffusion_config = model_config.get('diffusion', None)
626
+ assert diffusion_config is not None, "Must specify diffusion config"
627
+
628
+ diffusion_model_type = diffusion_config.get('type', None)
629
+ assert diffusion_model_type is not None, "Must specify diffusion model type"
630
+
631
+ diffusion_model_config = diffusion_config.get('config', None)
632
+ assert diffusion_model_config is not None, "Must specify diffusion model config"
633
+ if diffusion_model_config.get('video_fps', None) is not None:
634
+ diffusion_model_config.pop('video_fps')
635
+
636
+ if diffusion_model_type == 'adp_cfg_1d':
637
+ diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
638
+ elif diffusion_model_type == 'adp_1d':
639
+ diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
640
+ elif diffusion_model_type == 'dit':
641
+ diffusion_model = DiTWrapper(**diffusion_model_config)
642
+
643
+ io_channels = model_config.get('io_channels', None)
644
+ assert io_channels is not None, "Must specify io_channels in model config"
645
+
646
+ sample_rate = config.get('sample_rate', None)
647
+ assert sample_rate is not None, "Must specify sample_rate in config"
648
+
649
+ diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
650
+
651
+ conditioning_config = model_config.get('conditioning', None)
652
+
653
+ conditioner = None
654
+ if conditioning_config is not None:
655
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
656
+
657
+ cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
658
+ global_cond_ids = diffusion_config.get('global_cond_ids', [])
659
+ input_concat_ids = diffusion_config.get('input_concat_ids', [])
660
+ prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
661
+
662
+ pretransform = model_config.get("pretransform", None)
663
+
664
+ if pretransform is not None:
665
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
666
+ min_input_length = pretransform.downsampling_ratio
667
+ else:
668
+ min_input_length = 1
669
+
670
+ if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
671
+ min_input_length *= np.prod(diffusion_model_config["factors"])
672
+ elif diffusion_model_type == "dit":
673
+ min_input_length *= diffusion_model.model.patch_size
674
+
675
+ # Get the proper wrapper class
676
+
677
+ extra_kwargs = {}
678
+
679
+ if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
680
+ wrapper_fn = ConditionedDiffusionModelWrapper
681
+
682
+ extra_kwargs["diffusion_objective"] = diffusion_objective
683
+
684
+ elif model_type == "diffusion_prior":
685
+ prior_type = model_config.get("prior_type", None)
686
+ assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
687
+
688
+ if prior_type == "mono_stereo":
689
+ from .diffusion_prior import MonoToStereoDiffusionPrior
690
+ wrapper_fn = MonoToStereoDiffusionPrior
691
+
692
+ return wrapper_fn(
693
+ diffusion_model,
694
+ conditioner,
695
+ min_input_length=min_input_length,
696
+ sample_rate=sample_rate,
697
+ cross_attn_cond_ids=cross_attention_ids,
698
+ global_cond_ids=global_cond_ids,
699
+ input_concat_ids=input_concat_ids,
700
+ prepend_cond_ids=prepend_cond_ids,
701
+ pretransform=pretransform,
702
+ io_channels=io_channels,
703
+ **extra_kwargs
704
+ )
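For reference, the overall shape of a "diffusion_cond" config as read by create_diffusion_cond_from_config; the nesting follows the keys accessed above, while the numeric values and the dit kwargs are placeholders, not a shipped configuration:

model_config = {
    "model_type": "diffusion_cond",
    "sample_rate": 44100,
    "model": {
        "io_channels": 64,
        "pretransform": None,      # or an autoencoder pretransform config
        "conditioning": None,      # or a conditioning config (see conditioners.py)
        "diffusion": {
            "type": "dit",
            "diffusion_objective": "v",
            "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
            "global_cond_ids": ["seconds_start", "seconds_total"],
            "config": {"io_channels": 64, "embed_dim": 1536, "depth": 24, "num_heads": 24},
        },
    },
}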
stable_audio_tools/models/discriminators.py ADDED
@@ -0,0 +1,546 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from functools import reduce
6
+ import typing as tp
7
+ from einops import rearrange
8
+ from audiotools import AudioSignal, STFTParams
9
+ from dac.model.discriminator import WNConv1d, WNConv2d
10
+
11
+ def get_hinge_losses(score_real, score_fake):
12
+ gen_loss = -score_fake.mean()
13
+ dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean()
14
+ return dis_loss, gen_loss
15
+
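A tiny numeric check of the hinge losses above: the discriminator is penalized when real scores fall below +1 or fake scores rise above -1, and the generator simply maximizes the fake scores:

import torch

score_real = torch.tensor([0.5, 2.0])
score_fake = torch.tensor([-2.0, 0.0])
dis_loss, gen_loss = get_hinge_losses(score_real, score_fake)
# dis_loss = mean(relu(1 - real)) + mean(relu(1 + fake)) = 0.25 + 0.5 = 0.75
# gen_loss = -mean(fake) = 1.0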
16
+ class EncodecDiscriminator(nn.Module):
17
+
18
+ def __init__(self, *args, **kwargs):
19
+ super().__init__()
20
+
21
+ from encodec.msstftd import MultiScaleSTFTDiscriminator
22
+
23
+ self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs)
24
+
25
+ def forward(self, x):
26
+ logits, features = self.discriminators(x)
27
+ return logits, features
28
+
29
+ def loss(self, x, y):
30
+ feature_matching_distance = 0.
31
+ logits_true, feature_true = self.forward(x)
32
+ logits_fake, feature_fake = self.forward(y)
33
+
34
+ dis_loss = torch.tensor(0.)
35
+ adv_loss = torch.tensor(0.)
36
+
37
+ for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)):
38
+
39
+ feature_matching_distance = feature_matching_distance + sum(
40
+ map(
41
+ lambda x, y: abs(x - y).mean(),
42
+ scale_true,
43
+ scale_fake,
44
+ )) / len(scale_true)
45
+
46
+ _dis, _adv = get_hinge_losses(
47
+ logits_true[i],
48
+ logits_fake[i],
49
+ )
50
+
51
+ dis_loss = dis_loss + _dis
52
+ adv_loss = adv_loss + _adv
53
+
54
+ return dis_loss, adv_loss, feature_matching_distance
55
+
56
+ # Discriminators from oobleck
57
+
58
+ IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]]
59
+
60
+ TensorDict = tp.Dict[str, torch.Tensor]
61
+
62
+ class SharedDiscriminatorConvNet(nn.Module):
63
+
64
+ def __init__(
65
+ self,
66
+ in_size: int,
67
+ convolution: tp.Union[nn.Conv1d, nn.Conv2d],
68
+ out_size: int = 1,
69
+ capacity: int = 32,
70
+ n_layers: int = 4,
71
+ kernel_size: int = 15,
72
+ stride: int = 4,
73
+ activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(),
74
+ normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm,
75
+ ) -> None:
76
+ super().__init__()
77
+ channels = [in_size]
78
+ channels += list(capacity * 2**np.arange(n_layers))
79
+
80
+ if isinstance(stride, int):
81
+ stride = n_layers * [stride]
82
+
83
+ net = []
84
+ for i in range(n_layers):
85
+ if isinstance(kernel_size, int):
86
+ pad = kernel_size // 2
87
+ s = stride[i]
88
+ else:
89
+ pad = kernel_size[0] // 2
90
+ s = (stride[i], 1)
91
+
92
+ net.append(
93
+ normalization(
94
+ convolution(
95
+ channels[i],
96
+ channels[i + 1],
97
+ kernel_size,
98
+ stride=s,
99
+ padding=pad,
100
+ )))
101
+ net.append(activation())
102
+
103
+ net.append(convolution(channels[-1], out_size, 1))
104
+
105
+ self.net = nn.ModuleList(net)
106
+
107
+ def forward(self, x) -> IndividualDiscriminatorOut:
108
+ features = []
109
+ for layer in self.net:
110
+ x = layer(x)
111
+ if isinstance(layer, nn.modules.conv._ConvNd):
112
+ features.append(x)
113
+ score = x.reshape(x.shape[0], -1).mean(-1)
114
+ return score, features
115
+
116
+
117
+ class MultiScaleDiscriminator(nn.Module):
118
+
119
+ def __init__(self,
120
+ in_channels: int,
121
+ n_scales: int,
122
+ **conv_kwargs) -> None:
123
+ super().__init__()
124
+ layers = []
125
+ for _ in range(n_scales):
126
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs))
127
+ self.layers = nn.ModuleList(layers)
128
+
129
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
130
+ score = 0
131
+ features = []
132
+ for layer in self.layers:
133
+ s, f = layer(x)
134
+ score = score + s
135
+ features.extend(f)
136
+ x = nn.functional.avg_pool1d(x, 2)
137
+ return score, features
138
+
139
+ class MultiPeriodDiscriminator(nn.Module):
140
+
141
+ def __init__(self,
142
+ in_channels: int,
143
+ periods: tp.Sequence[int],
144
+ **conv_kwargs) -> None:
145
+ super().__init__()
146
+ layers = []
147
+ self.periods = periods
148
+
149
+ for _ in periods:
150
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs))
151
+
152
+ self.layers = nn.ModuleList(layers)
153
+
154
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
155
+ score = 0
156
+ features = []
157
+ for layer, n in zip(self.layers, self.periods):
158
+ s, f = layer(self.fold(x, n))
159
+ score = score + s
160
+ features.extend(f)
161
+ return score, features
162
+
163
+ def fold(self, x: torch.Tensor, n: int) -> torch.Tensor:
164
+ pad = (n - (x.shape[-1] % n)) % n
165
+ x = nn.functional.pad(x, (0, pad))
166
+ return x.reshape(*x.shape[:2], -1, n)
167
+
168
+
169
+ class MultiDiscriminator(nn.Module):
170
+ """
171
+ Individual discriminators should take a single tensor as input (NxB C T) and
172
+ return a tuple composed of a score tensor (NxB) and a Sequence of Features
173
+ Sequence[NxB C' T'].
174
+ """
175
+
176
+ def __init__(self, discriminator_list: tp.Sequence[nn.Module],
177
+ keys: tp.Sequence[str]) -> None:
178
+ super().__init__()
179
+ self.discriminators = nn.ModuleList(discriminator_list)
180
+ self.keys = keys
181
+
182
+ def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict:
183
+ features = features.chunk(len(self.keys), 0)
184
+ return {k: features[i] for i, k in enumerate(self.keys)}
185
+
186
+ @staticmethod
187
+ def concat_dicts(dict_a, dict_b):
188
+ out_dict = {}
189
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
190
+ for k in keys:
191
+ out_dict[k] = []
192
+ if k in dict_a:
193
+ if isinstance(dict_a[k], list):
194
+ out_dict[k].extend(dict_a[k])
195
+ else:
196
+ out_dict[k].append(dict_a[k])
197
+ if k in dict_b:
198
+ if isinstance(dict_b[k], list):
199
+ out_dict[k].extend(dict_b[k])
200
+ else:
201
+ out_dict[k].append(dict_b[k])
202
+ return out_dict
203
+
204
+ @staticmethod
205
+ def sum_dicts(dict_a, dict_b):
206
+ out_dict = {}
207
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
208
+ for k in keys:
209
+ out_dict[k] = 0.
210
+ if k in dict_a:
211
+ out_dict[k] = out_dict[k] + dict_a[k]
212
+ if k in dict_b:
213
+ out_dict[k] = out_dict[k] + dict_b[k]
214
+ return out_dict
215
+
216
+ def forward(self, inputs: TensorDict) -> TensorDict:
217
+ discriminator_input = torch.cat([inputs[k] for k in self.keys], 0)
218
+ all_scores = []
219
+ all_features = []
220
+
221
+ for discriminator in self.discriminators:
222
+ score, features = discriminator(discriminator_input)
223
+ scores = self.unpack_tensor_to_dict(score)
224
+ scores = {f"score_{k}": scores[k] for k in scores.keys()}
225
+ all_scores.append(scores)
226
+
227
+ features = map(self.unpack_tensor_to_dict, features)
228
+ features = reduce(self.concat_dicts, features)
229
+ features = {f"features_{k}": features[k] for k in features.keys()}
230
+ all_features.append(features)
231
+
232
+ all_scores = reduce(self.sum_dicts, all_scores)
233
+ all_features = reduce(self.concat_dicts, all_features)
234
+
235
+ inputs.update(all_scores)
236
+ inputs.update(all_features)
237
+
238
+ return inputs
239
+
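MultiDiscriminator packs the keyed inputs into one batch, runs every discriminator once, and splits the scores and features back out by key. The packing itself is just a cat/chunk pair, sketched here with illustrative shapes:

import torch

reals = torch.randn(4, 1, 16384)
fakes = torch.randn(4, 1, 16384)
packed = torch.cat([reals, fakes], dim=0)        # (8, 1, 16384): a single pass covers both
scores = torch.randn(8)                          # per-sample scores from one discriminator
score_reals, score_fakes = scores.chunk(2, 0)    # split back out in key order ("reals", "fakes")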
240
+ class OobleckDiscriminator(nn.Module):
241
+
242
+ def __init__(
243
+ self,
244
+ in_channels=1,
245
+ ):
246
+ super().__init__()
247
+
248
+ multi_scale_discriminator = MultiScaleDiscriminator(
249
+ in_channels=in_channels,
250
+ n_scales=3,
251
+ )
252
+
253
+ multi_period_discriminator = MultiPeriodDiscriminator(
254
+ in_channels=in_channels,
255
+ periods=[2, 3, 5, 7, 11]
256
+ )
257
+
258
+ # multi_resolution_discriminator = MultiScaleSTFTDiscriminator(
259
+ # filters=32,
260
+ # in_channels = in_channels,
261
+ # out_channels = 1,
262
+ # n_ffts = [2048, 1024, 512, 256, 128],
263
+ # hop_lengths = [512, 256, 128, 64, 32],
264
+ # win_lengths = [2048, 1024, 512, 256, 128]
265
+ # )
266
+
267
+ self.multi_discriminator = MultiDiscriminator(
268
+ [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator],
269
+ ["reals", "fakes"]
270
+ )
271
+
272
+ def loss(self, reals, fakes):
273
+ inputs = {
274
+ "reals": reals,
275
+ "fakes": fakes,
276
+ }
277
+
278
+ inputs = self.multi_discriminator(inputs)
279
+
280
+ scores_real = inputs["score_reals"]
281
+ scores_fake = inputs["score_fakes"]
282
+
283
+ features_real = inputs["features_reals"]
284
+ features_fake = inputs["features_fakes"]
285
+
286
+ dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake)
287
+
288
+ feature_matching_distance = torch.tensor(0.)
289
+
290
+ for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)):
291
+
292
+ feature_matching_distance = feature_matching_distance + sum(
293
+ map(
294
+ lambda real, fake: abs(real - fake).mean(),
295
+ scale_real,
296
+ scale_fake,
297
+ )) / len(scale_real)
298
+
299
+ return dis_loss, gen_loss, feature_matching_distance
300
+
301
+
302
+ ## Discriminators from Descript Audio Codec repo
303
+ ## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt
304
+ class MPD(nn.Module):
305
+ def __init__(self, period, channels=1):
306
+ super().__init__()
307
+
308
+ self.period = period
309
+ self.convs = nn.ModuleList(
310
+ [
311
+ WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)),
312
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
313
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
314
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
315
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
316
+ ]
317
+ )
318
+ self.conv_post = WNConv2d(
319
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
320
+ )
321
+
322
+ def pad_to_period(self, x):
323
+ t = x.shape[-1]
324
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
325
+ return x
326
+
327
+ def forward(self, x):
328
+ fmap = []
329
+
330
+ x = self.pad_to_period(x)
331
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
332
+
333
+ for layer in self.convs:
334
+ x = layer(x)
335
+ fmap.append(x)
336
+
337
+ x = self.conv_post(x)
338
+ fmap.append(x)
339
+
340
+ return fmap
341
+
342
+
343
+ class MSD(nn.Module):
344
+ def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1):
345
+ super().__init__()
346
+
347
+ self.convs = nn.ModuleList(
348
+ [
349
+ WNConv1d(channels, 16, 15, 1, padding=7),
350
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
351
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
352
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
353
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
354
+ WNConv1d(1024, 1024, 5, 1, padding=2),
355
+ ]
356
+ )
357
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
358
+ self.sample_rate = sample_rate
359
+ self.rate = rate
360
+
361
+ def forward(self, x):
362
+ x = AudioSignal(x, self.sample_rate)
363
+ x.resample(self.sample_rate // self.rate)
364
+ x = x.audio_data
365
+
366
+ fmap = []
367
+
368
+ for l in self.convs:
369
+ x = l(x)
370
+ fmap.append(x)
371
+ x = self.conv_post(x)
372
+ fmap.append(x)
373
+
374
+ return fmap
375
+
376
+
377
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
378
+
379
+
380
+ class MRD(nn.Module):
381
+ def __init__(
382
+ self,
383
+ window_length: int,
384
+ hop_factor: float = 0.25,
385
+ sample_rate: int = 44100,
386
+ bands: list = BANDS,
387
+ channels: int = 1
388
+ ):
389
+ """Complex multi-band spectrogram discriminator.
390
+ Parameters
391
+ ----------
392
+ window_length : int
393
+ Window length of STFT.
394
+ hop_factor : float, optional
395
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
396
+ sample_rate : int, optional
397
+ Sampling rate of audio in Hz, by default 44100
398
+ bands : list, optional
399
+ Bands to run discriminator over.
400
+ """
401
+ super().__init__()
402
+
403
+ self.window_length = window_length
404
+ self.hop_factor = hop_factor
405
+ self.sample_rate = sample_rate
406
+ self.stft_params = STFTParams(
407
+ window_length=window_length,
408
+ hop_length=int(window_length * hop_factor),
409
+ match_stride=True,
410
+ )
411
+
412
+ self.channels = channels
413
+
414
+ n_fft = window_length // 2 + 1
415
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
416
+ self.bands = bands
417
+
418
+ ch = 32
419
+ convs = lambda: nn.ModuleList(
420
+ [
421
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
422
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
423
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
424
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
425
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
426
+ ]
427
+ )
428
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
429
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
430
+
431
+ def spectrogram(self, x):
432
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
433
+ x = torch.view_as_real(x.stft())
434
+ x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels)
435
+ # Split into bands
436
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
437
+ return x_bands
438
+
439
+ def forward(self, x):
440
+ x_bands = self.spectrogram(x)
441
+ fmap = []
442
+
443
+ x = []
444
+ for band, stack in zip(x_bands, self.band_convs):
445
+ for layer in stack:
446
+ band = layer(band)
447
+ fmap.append(band)
448
+ x.append(band)
449
+
450
+ x = torch.cat(x, dim=-1)
451
+ x = self.conv_post(x)
452
+ fmap.append(x)
453
+
454
+ return fmap
455
+
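The band arithmetic in MRD.__init__ maps the fractional BANDS to STFT bin ranges; for window_length=2048 (the largest default FFT size) this works out as below:

window_length = 2048
n_fft = window_length // 2 + 1          # 1025 frequency bins
bands = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
bins = [(int(lo * n_fft), int(hi * n_fft)) for lo, hi in bands]
# bins == [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]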
456
+
457
+ class DACDiscriminator(nn.Module):
458
+ def __init__(
459
+ self,
460
+ channels: int = 1,
461
+ rates: list = [],
462
+ periods: list = [2, 3, 5, 7, 11],
463
+ fft_sizes: list = [2048, 1024, 512],
464
+ sample_rate: int = 44100,
465
+ bands: list = BANDS,
466
+ ):
467
+ """Discriminator that combines multiple discriminators.
468
+
469
+ Parameters
470
+ ----------
471
+ rates : list, optional
472
+ sampling rates (in Hz) to run MSD at, by default []
473
+ If empty, MSD is not used.
474
+ periods : list, optional
475
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
476
+ fft_sizes : list, optional
477
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
478
+ sample_rate : int, optional
479
+ Sampling rate of audio in Hz, by default 44100
480
+ bands : list, optional
481
+ Bands to run MRD at, by default `BANDS`
482
+ """
483
+ super().__init__()
484
+ discs = []
485
+ discs += [MPD(p, channels=channels) for p in periods]
486
+ discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates]
487
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes]
488
+ self.discriminators = nn.ModuleList(discs)
489
+
490
+ def preprocess(self, y):
491
+ # Remove DC offset
492
+ y = y - y.mean(dim=-1, keepdims=True)
493
+ # Peak normalize the volume of input audio
494
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
495
+ return y
496
+
497
+ def forward(self, x):
498
+ x = self.preprocess(x)
499
+ fmaps = [d(x) for d in self.discriminators]
500
+ return fmaps
501
+
502
+ class DACGANLoss(nn.Module):
503
+ """
504
+ Computes a discriminator loss, given a discriminator on
505
+ generated waveforms/spectrograms compared to ground truth
506
+ waveforms/spectrograms. Computes the loss for both the
507
+ discriminator and the generator in separate functions.
508
+ """
509
+
510
+ def __init__(self, **discriminator_kwargs):
511
+ super().__init__()
512
+ self.discriminator = DACDiscriminator(**discriminator_kwargs)
513
+
514
+ def forward(self, fake, real):
515
+ d_fake = self.discriminator(fake)
516
+ d_real = self.discriminator(real)
517
+ return d_fake, d_real
518
+
519
+ def discriminator_loss(self, fake, real):
520
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
521
+
522
+ loss_d = 0
523
+ for x_fake, x_real in zip(d_fake, d_real):
524
+ loss_d += torch.mean(x_fake[-1] ** 2)
525
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
526
+ return loss_d
527
+
528
+ def generator_loss(self, fake, real):
529
+ d_fake, d_real = self.forward(fake, real)
530
+
531
+ loss_g = 0
532
+ for x_fake in d_fake:
533
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
534
+
535
+ loss_feature = 0
536
+
537
+ for i in range(len(d_fake)):
538
+ for j in range(len(d_fake[i]) - 1):
539
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
540
+ return loss_g, loss_feature
541
+
542
+ def loss(self, fake, real):
543
+ gen_loss, feature_distance = self.generator_loss(fake, real)
544
+ dis_loss = self.discriminator_loss(fake, real)
545
+
546
+ return dis_loss, gen_loss, feature_distance
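A usage sketch for DACGANLoss (it pulls in the dac and audiotools packages); the batch shape and sample rate here are illustrative:

import torch

gan_loss = DACGANLoss(channels=2, sample_rate=44100)
real = torch.randn(2, 2, 44100)                       # (batch, channels, samples)
fake = torch.randn(2, 2, 44100, requires_grad=True)   # e.g. decoder output
dis_loss, gen_loss, feature_distance = gan_loss.loss(fake, real)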
stable_audio_tools/models/dit.py ADDED
@@ -0,0 +1,379 @@
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from x_transformers import ContinuousTransformerWrapper, Encoder
9
+
10
+ from .blocks import FourierFeatures
11
+ from .transformer import ContinuousTransformer
12
+
13
+ class DiffusionTransformer(nn.Module):
14
+ def __init__(self,
15
+ io_channels=32,
16
+ patch_size=1,
17
+ embed_dim=768,
18
+ cond_token_dim=0,
19
+ project_cond_tokens=True,
20
+ global_cond_dim=0,
21
+ project_global_cond=True,
22
+ input_concat_dim=0,
23
+ prepend_cond_dim=0,
24
+ depth=12,
25
+ num_heads=8,
26
+ transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
27
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
28
+ **kwargs):
29
+
30
+ super().__init__()
31
+
32
+ self.cond_token_dim = cond_token_dim
33
+
34
+ # Timestep embeddings
35
+ timestep_features_dim = 256
36
+
37
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
38
+
39
+ self.to_timestep_embed = nn.Sequential(
40
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
41
+ nn.SiLU(),
42
+ nn.Linear(embed_dim, embed_dim, bias=True),
43
+ )
44
+
45
+ if cond_token_dim > 0:
46
+ # Conditioning tokens
47
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
48
+ self.to_cond_embed = nn.Sequential(
49
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
50
+ nn.SiLU(),
51
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
52
+ )
53
+ else:
54
+ cond_embed_dim = 0
55
+
56
+ if global_cond_dim > 0:
57
+ # Global conditioning
58
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
59
+ self.to_global_embed = nn.Sequential(
60
+ nn.Linear(global_cond_dim, global_embed_dim, bias=False),
61
+ nn.SiLU(),
62
+ nn.Linear(global_embed_dim, global_embed_dim, bias=False)
63
+ )
64
+
65
+ if prepend_cond_dim > 0:
66
+ # Prepend conditioning
67
+ self.to_prepend_embed = nn.Sequential(
68
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
69
+ nn.SiLU(),
70
+ nn.Linear(embed_dim, embed_dim, bias=False)
71
+ )
72
+
73
+ self.input_concat_dim = input_concat_dim
74
+
75
+ dim_in = io_channels + self.input_concat_dim
76
+
77
+ self.patch_size = patch_size
78
+
79
+ # Transformer
80
+
81
+ self.transformer_type = transformer_type
82
+
83
+ self.global_cond_type = global_cond_type
84
+
85
+ if self.transformer_type == "x-transformers":
86
+ self.transformer = ContinuousTransformerWrapper(
87
+ dim_in=dim_in * patch_size,
88
+ dim_out=io_channels * patch_size,
89
+ max_seq_len=0, #Not relevant without absolute positional embeds
90
+ attn_layers = Encoder(
91
+ dim=embed_dim,
92
+ depth=depth,
93
+ heads=num_heads,
94
+ attn_flash = True,
95
+ cross_attend = cond_token_dim > 0,
96
+ dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
97
+ zero_init_branch_output=True,
98
+ use_abs_pos_emb = False,
99
+ rotary_pos_emb=True,
100
+ ff_swish = True,
101
+ ff_glu = True,
102
+ **kwargs
103
+ )
104
+ )
105
+
106
+ elif self.transformer_type == "continuous_transformer":
107
+
108
+ global_dim = None
109
+
110
+ if self.global_cond_type == "adaLN":
111
+ # The global conditioning is projected to the embed_dim already at this point
112
+ global_dim = embed_dim
113
+
114
+ self.transformer = ContinuousTransformer(
115
+ dim=embed_dim,
116
+ depth=depth,
117
+ dim_heads=embed_dim // num_heads,
118
+ dim_in=dim_in * patch_size,
119
+ dim_out=io_channels * patch_size,
120
+ cross_attend = cond_token_dim > 0,
121
+ cond_token_dim = cond_embed_dim,
122
+ global_cond_dim=global_dim,
123
+ **kwargs
124
+ )
125
+
126
+ else:
127
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
128
+
129
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
130
+ nn.init.zeros_(self.preprocess_conv.weight)
131
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
132
+ nn.init.zeros_(self.postprocess_conv.weight)
133
+
134
+ def _forward(
135
+ self,
136
+ x,
137
+ t,
138
+ mask=None,
139
+ cross_attn_cond=None,
140
+ cross_attn_cond_mask=None,
141
+ input_concat_cond=None,
142
+ global_embed=None,
143
+ prepend_cond=None,
144
+ prepend_cond_mask=None,
145
+ return_info=False,
146
+ **kwargs):
147
+
148
+ if cross_attn_cond is not None:
149
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond) # Project cross-attention conditioning tokens to the embedding dimension
150
+
151
+ if global_embed is not None:
152
+ # Project the global conditioning to the embedding dimension
153
+ global_embed = self.to_global_embed(global_embed)
154
+
155
+ prepend_inputs = None
156
+ prepend_mask = None
157
+ prepend_length = 0
158
+ if prepend_cond is not None:
159
+ # Project the prepend conditioning to the embedding dimension
160
+ prepend_cond = self.to_prepend_embed(prepend_cond)
161
+
162
+ prepend_inputs = prepend_cond
163
+ if prepend_cond_mask is not None:
164
+ prepend_mask = prepend_cond_mask
165
+
166
+ if input_concat_cond is not None:
167
+
168
+ # Interpolate input_concat_cond to the same length as x
169
+ if input_concat_cond.shape[2] != x.shape[2]:
170
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
171
+
172
+ x = torch.cat([x, input_concat_cond], dim=1)
173
+
174
+ # Get the batch of timestep embeddings
175
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
176
+
177
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
178
+ if global_embed is not None:
179
+ global_embed = global_embed + timestep_embed
180
+ else:
181
+ global_embed = timestep_embed
182
+
183
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
184
+ if self.global_cond_type == "prepend":
185
+ if prepend_inputs is None:
186
+ # Prepend inputs are just the global embed, and the mask is all ones
187
+ prepend_inputs = global_embed.unsqueeze(1) # (batch, 1, embed_dim)
188
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
189
+ else:
190
+ # Prepend inputs are the prepend conditioning + the global embed
191
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
192
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
193
+
194
+ prepend_length = prepend_inputs.shape[1] # Number of prepended tokens
195
+
196
+ x = self.preprocess_conv(x) + x # Zero-initialized 1x1 conv applied residually
197
+
198
+ x = rearrange(x, "b c t -> b t c") # Transformer expects (batch, seq, channels)
199
+
200
+ extra_args = {}
201
+
202
+ if self.global_cond_type == "adaLN":
203
+ extra_args["global_cond"] = global_embed
204
+
205
+ if self.patch_size > 1:
206
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
207
+
208
+ if self.transformer_type == "x-transformers":
209
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
210
+ elif self.transformer_type == "continuous_transformer":
211
+
212
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
213
+
214
+ if return_info:
215
+ output, info = output
216
+ elif self.transformer_type == "mm_transformer":
217
+ output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
218
+
219
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
220
+
221
+ if self.patch_size > 1:
222
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
223
+
224
+ output = self.postprocess_conv(output) + output
225
+
226
+ if return_info:
227
+ return output, info
228
+
229
+ return output
230
+
231
+ def forward(
232
+ self,
233
+ x,
234
+ t,
235
+ cross_attn_cond=None,
236
+ cross_attn_cond_mask=None,
237
+ negative_cross_attn_cond=None,
238
+ negative_cross_attn_mask=None,
239
+ input_concat_cond=None,
240
+ global_embed=None,
241
+ negative_global_embed=None,
242
+ prepend_cond=None,
243
+ prepend_cond_mask=None,
244
+ cfg_scale=1.0,
245
+ cfg_dropout_prob=0.0,
246
+ causal=False,
247
+ scale_phi=0.0,
248
+ mask=None,
249
+ return_info=False,
250
+ **kwargs):
251
+
252
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
253
+
254
+ if cross_attn_cond_mask is not None:
255
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
256
+
257
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
258
+
259
+ if prepend_cond_mask is not None:
260
+ prepend_cond_mask = prepend_cond_mask.bool()
261
+
262
+ # CFG dropout
263
+ if cfg_dropout_prob > 0.0:
264
+ if cross_attn_cond is not None:
265
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
266
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
267
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
268
+
269
+ if prepend_cond is not None:
270
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
271
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
272
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
273
+
274
+
275
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
276
+ # Classifier-free guidance
277
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
278
+ batch_inputs = torch.cat([x, x], dim=0)
279
+ batch_timestep = torch.cat([t, t], dim=0)
280
+
281
+ if global_embed is not None:
282
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
283
+ else:
284
+ batch_global_cond = None
285
+
286
+ if input_concat_cond is not None:
287
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
288
+ else:
289
+ batch_input_concat_cond = None
290
+
291
+ batch_cond = None
292
+ batch_cond_masks = None
293
+
294
+ # Handle CFG for cross-attention conditioning
295
+ if cross_attn_cond is not None:
296
+
297
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
298
+
299
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
300
+ if negative_cross_attn_cond is not None:
301
+
302
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
303
+ if negative_cross_attn_mask is not None:
304
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
305
+
306
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
307
+
308
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
309
+
310
+ else:
311
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
312
+
313
+ if cross_attn_cond_mask is not None:
314
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
315
+
316
+ batch_prepend_cond = None
317
+ batch_prepend_cond_mask = None
318
+
319
+ if prepend_cond is not None:
320
+
321
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
322
+
323
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
324
+
325
+ if prepend_cond_mask is not None:
326
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
327
+
328
+
329
+ if mask is not None:
330
+ batch_masks = torch.cat([mask, mask], dim=0)
331
+ else:
332
+ batch_masks = None
333
+
334
+ batch_output = self._forward(
335
+ batch_inputs,
336
+ batch_timestep,
337
+ cross_attn_cond=batch_cond,
338
+ cross_attn_cond_mask=batch_cond_masks,
339
+ mask = batch_masks,
340
+ input_concat_cond=batch_input_concat_cond,
341
+ global_embed = batch_global_cond,
342
+ prepend_cond = batch_prepend_cond,
343
+ prepend_cond_mask = batch_prepend_cond_mask,
344
+ return_info = return_info,
345
+ **kwargs)
346
+
347
+ if return_info:
348
+ batch_output, info = batch_output
349
+
350
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
351
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
352
+
353
+ # CFG Rescale
354
+ if scale_phi != 0.0:
355
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
356
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
357
+ output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
358
+ else:
359
+ output = cfg_output
360
+
361
+ if return_info:
362
+ return output, info
363
+
364
+ return output
365
+
366
+ else:
367
+ return self._forward(
368
+ x,
369
+ t,
370
+ cross_attn_cond=cross_attn_cond,
371
+ cross_attn_cond_mask=cross_attn_cond_mask,
372
+ input_concat_cond=input_concat_cond,
373
+ global_embed=global_embed,
374
+ prepend_cond=prepend_cond,
375
+ prepend_cond_mask=prepend_cond_mask,
376
+ mask=mask,
377
+ return_info=return_info,
378
+ **kwargs
379
+ )
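
As a rough orientation for the class above, the sketch below instantiates a small DiffusionTransformer and runs a classifier-free-guidance forward pass. All dimensions (64 latent channels, 768-dim conditioning tokens, 1536-dim global conditioning) are illustrative placeholders rather than values from this repo's configs; in practice the model is built from a model config JSON via the factory.

import torch
from stable_audio_tools.models.dit import DiffusionTransformer

model = DiffusionTransformer(
    io_channels=64,                      # latent channels from the pretransform (placeholder)
    embed_dim=768,
    depth=12,
    num_heads=8,
    cond_token_dim=768,                  # e.g. text-encoder tokens used for cross-attention
    global_cond_dim=1536,                # e.g. timing embeddings
    transformer_type="continuous_transformer",
)

x = torch.randn(2, 64, 1024)             # (batch, io_channels, latent length)
t = torch.rand(2)                        # diffusion timesteps
text_tokens = torch.randn(2, 130, 768)   # (batch, num_tokens, cond_token_dim)
timing = torch.randn(2, 1536)            # (batch, global_cond_dim)

# cfg_scale != 1.0 doubles the batch internally and blends cond/uncond outputs
v = model(x, t, cross_attn_cond=text_tokens, global_embed=timing, cfg_scale=6.0)
print(v.shape)                           # expected: torch.Size([2, 64, 1024])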
stable_audio_tools/models/factory.py ADDED
@@ -0,0 +1,153 @@
1
+ import json
2
+
3
+ def create_model_from_config(model_config):
4
+ model_type = model_config.get('model_type', None)
5
+
6
+ assert model_type is not None, 'model_type must be specified in model config'
7
+
8
+ if model_type == 'autoencoder':
9
+ from .autoencoders import create_autoencoder_from_config
10
+ return create_autoencoder_from_config(model_config)
11
+ elif model_type == 'diffusion_uncond':
12
+ from .diffusion import create_diffusion_uncond_from_config
13
+ return create_diffusion_uncond_from_config(model_config)
14
+ elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
15
+ from .diffusion import create_diffusion_cond_from_config
16
+ return create_diffusion_cond_from_config(model_config)
17
+ elif model_type == 'diffusion_autoencoder':
18
+ from .autoencoders import create_diffAE_from_config
19
+ return create_diffAE_from_config(model_config)
20
+ elif model_type == 'lm':
21
+ from .lm import create_audio_lm_from_config
22
+ return create_audio_lm_from_config(model_config)
23
+ else:
24
+ raise NotImplementedError(f'Unknown model type: {model_type}')
25
+
26
+ def create_model_from_config_path(model_config_path):
27
+ with open(model_config_path) as f:
28
+ model_config = json.load(f)
29
+
30
+ return create_model_from_config(model_config)
31
+
32
+ def create_pretransform_from_config(pretransform_config, sample_rate):
33
+ pretransform_type = pretransform_config.get('type', None)
34
+
35
+ assert pretransform_type is not None, 'type must be specified in pretransform config'
36
+
37
+ if pretransform_type == 'autoencoder':
38
+ from .autoencoders import create_autoencoder_from_config
39
+ from .pretransforms import AutoencoderPretransform
40
+
41
+ # Create fake top-level config to pass sample rate to autoencoder constructor
42
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
43
+ autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
44
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
45
+
46
+ scale = pretransform_config.get("scale", 1.0)
47
+ model_half = pretransform_config.get("model_half", False)
48
+ iterate_batch = pretransform_config.get("iterate_batch", False)
49
+ chunked = pretransform_config.get("chunked", False)
50
+
51
+ pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
52
+ elif pretransform_type == 'wavelet':
53
+ from .pretransforms import WaveletPretransform
54
+
55
+ wavelet_config = pretransform_config["config"]
56
+ channels = wavelet_config["channels"]
57
+ levels = wavelet_config["levels"]
58
+ wavelet = wavelet_config["wavelet"]
59
+
60
+ pretransform = WaveletPretransform(channels, levels, wavelet)
61
+ elif pretransform_type == 'pqmf':
62
+ from .pretransforms import PQMFPretransform
63
+ pqmf_config = pretransform_config["config"]
64
+ pretransform = PQMFPretransform(**pqmf_config)
65
+ elif pretransform_type == 'dac_pretrained':
66
+ from .pretransforms import PretrainedDACPretransform
67
+ pretrained_dac_config = pretransform_config["config"]
68
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
69
+ elif pretransform_type == "audiocraft_pretrained":
70
+ from .pretransforms import AudiocraftCompressionPretransform
71
+
72
+ audiocraft_config = pretransform_config["config"]
73
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
74
+ else:
75
+ raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
76
+
77
+ enable_grad = pretransform_config.get('enable_grad', False)
78
+ pretransform.enable_grad = enable_grad
79
+
80
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
81
+
82
+ return pretransform
83
+
84
+ def create_bottleneck_from_config(bottleneck_config):
85
+ bottleneck_type = bottleneck_config.get('type', None)
86
+
87
+ assert bottleneck_type is not None, 'type must be specified in bottleneck config'
88
+
89
+ if bottleneck_type == 'tanh':
90
+ from .bottleneck import TanhBottleneck
91
+ bottleneck = TanhBottleneck()
92
+ elif bottleneck_type == 'vae':
93
+ from .bottleneck import VAEBottleneck
94
+ bottleneck = VAEBottleneck()
95
+ elif bottleneck_type == 'rvq':
96
+ from .bottleneck import RVQBottleneck
97
+
98
+ quantizer_params = {
99
+ "dim": 128,
100
+ "codebook_size": 1024,
101
+ "num_quantizers": 8,
102
+ "decay": 0.99,
103
+ "kmeans_init": True,
104
+ "kmeans_iters": 50,
105
+ "threshold_ema_dead_code": 2,
106
+ }
107
+
108
+ quantizer_params.update(bottleneck_config["config"])
109
+
110
+ bottleneck = RVQBottleneck(**quantizer_params)
111
+ elif bottleneck_type == "dac_rvq":
112
+ from .bottleneck import DACRVQBottleneck
113
+
114
+ bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
115
+
116
+ elif bottleneck_type == 'rvq_vae':
117
+ from .bottleneck import RVQVAEBottleneck
118
+
119
+ quantizer_params = {
120
+ "dim": 128,
121
+ "codebook_size": 1024,
122
+ "num_quantizers": 8,
123
+ "decay": 0.99,
124
+ "kmeans_init": True,
125
+ "kmeans_iters": 50,
126
+ "threshold_ema_dead_code": 2,
127
+ }
128
+
129
+ quantizer_params.update(bottleneck_config["config"])
130
+
131
+ bottleneck = RVQVAEBottleneck(**quantizer_params)
132
+
133
+ elif bottleneck_type == 'dac_rvq_vae':
134
+ from .bottleneck import DACRVQVAEBottleneck
135
+ bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
136
+ elif bottleneck_type == 'l2_norm':
137
+ from .bottleneck import L2Bottleneck
138
+ bottleneck = L2Bottleneck()
139
+ elif bottleneck_type == "wasserstein":
140
+ from .bottleneck import WassersteinBottleneck
141
+ bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
142
+ elif bottleneck_type == "fsq":
143
+ from .bottleneck import FSQBottleneck
144
+ bottleneck = FSQBottleneck(**bottleneck_config["config"])
145
+ else:
146
+ raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
147
+
148
+ requires_grad = bottleneck_config.get('requires_grad', True)
149
+ if not requires_grad:
150
+ for param in bottleneck.parameters():
151
+ param.requires_grad = False
152
+
153
+ return bottleneck
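
A short usage sketch for the factory functions above. It only exercises functions defined in this file; "model_config.json" is a placeholder path, and the config schema is defined by the per-model configs elsewhere in the repo.

import json
from stable_audio_tools.models.factory import create_model_from_config, create_model_from_config_path

# Load the config yourself...
with open("model_config.json") as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)

# ...or let the convenience wrapper read the JSON in one call.
model = create_model_from_config_path("model_config.json")
print(type(model).__name__)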
stable_audio_tools/models/lm.py ADDED
@@ -0,0 +1,542 @@
1
+ from dataclasses import dataclass
2
+ import torch
3
+ from tqdm.auto import trange
4
+ import typing as tp
5
+ from einops import rearrange
6
+ from torch import nn
7
+
8
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
9
+ from .factory import create_pretransform_from_config
10
+ from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
11
+ from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
12
+ from .utils import multinomial, sample_top_k, sample_top_p
13
+
14
+ from .codebook_patterns import (
15
+ CodebooksPatternProvider,
16
+ DelayedPatternProvider,
17
+ MusicLMPattern,
18
+ ParallelPatternProvider,
19
+ UnrolledPatternProvider
20
+ )
21
+
22
+ # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license
23
+ # License can be found in LICENSES/LICENSE_META.txt
24
+
25
+ @dataclass
26
+ class LMOutput:
27
+ # The logits are already re-aligned with the input codes
28
+ # hence no extra shift is required, e.g. when computing CE
29
+ logits: torch.Tensor # [B, K, T, card]
30
+ mask: torch.Tensor # [B, K, T]
31
+
32
+ # Wrapper for a multi-codebook language model
33
+ # Handles patterns and quantizer heads
34
+ class AudioLanguageModel(nn.Module):
35
+ def __init__(
36
+ self,
37
+ pattern_provider: CodebooksPatternProvider,
38
+ backbone: AudioLMBackbone,
39
+ num_quantizers: int,
40
+ codebook_size: int
41
+ ):
42
+ super().__init__()
43
+
44
+ self.pattern_provider = pattern_provider
45
+ self.backbone = backbone
46
+ self.num_quantizers = num_quantizers
47
+ self.codebook_size = codebook_size
48
+
49
+ self.masked_token_id = codebook_size
50
+
51
+ # Per-quantizer embedders
52
+ # Add one for the mask embed
53
+ self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])
54
+
55
+ # Per-quantizer output heads
56
+ self.quantizer_heads = nn.ModuleList([
57
+ nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
58
+ ])
59
+
60
+ def forward(self,
61
+ sequence: torch.Tensor, # [batch, num_quantizers, seq_len]
62
+ prepend_cond=None, #[batch, seq, channels]
63
+ prepend_cond_mask=None,
64
+ cross_attn_cond=None, #[batch, seq, channels],
65
+ **kwargs
66
+ ):
67
+
68
+
69
+ batch, num_quantizers, seq_len = sequence.shape
70
+
71
+ assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"
72
+
73
+ backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim]
74
+
75
+ dtype = next(self.parameters()).dtype
76
+
77
+ if cross_attn_cond is not None:
78
+ cross_attn_cond = cross_attn_cond.to(dtype)
79
+
80
+ if prepend_cond is not None:
81
+ prepend_cond = prepend_cond.to(dtype)
82
+
83
+ if prepend_cond_mask is not None:
84
+ prepend_cond_mask = prepend_cond_mask.to(dtype)
85
+
86
+ backbone_input = backbone_input.to(dtype)
87
+
88
+ output = self.backbone(
89
+ backbone_input,
90
+ cross_attn_cond=cross_attn_cond,
91
+ prepend_cond=prepend_cond,
92
+ prepend_cond_mask=prepend_cond_mask,
93
+ **kwargs
94
+ ) # [batch, seq_len, embed_dim]
95
+
96
+ # Run output through quantizer heads
97
+ logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size]
98
+
99
+ return logits
100
+
101
+ def compute_logits(
102
+ self,
103
+ codes, #[batch, num_quantizers, seq_len]
104
+ **kwargs):
105
+ """
106
+ Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning
107
+ Handles translation between input sequence and pattern-shifted sequence
108
+ Only used during training
109
+ """
110
+
111
+ batch, _, seq_len = codes.shape
112
+
113
+ pattern = self.pattern_provider.get_pattern(seq_len)
114
+
115
+ # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps
116
+ shifted_codes, _, _ = pattern.build_pattern_sequence(
117
+ codes,
118
+ self.masked_token_id,
119
+ keep_only_valid_steps=True
120
+ )
121
+
122
+ # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size]
123
+ logits = self(shifted_codes, **kwargs)
124
+
125
+ # Rearrange logits to prepare to revert pattern
126
+ logits = rearrange(logits, "b n s c -> b c n s")
127
+
128
+ # Revert sequence logits back to original sequence length, removing masked steps
129
+ logits, _, logits_mask = pattern.revert_pattern_logits(
130
+ logits, float('nan'), keep_only_valid_steps=True
131
+ )
132
+
133
+ logits = rearrange(logits, "b c n t -> b n t c")
134
+
135
+ logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len]
136
+
137
+ return LMOutput(logits=logits, mask=logits_mask)
138
+
139
+ # Conditioning and generation wrapper for a multi-codebook language model
140
+ # Handles conditioning, CFG, generation, and encoding/decoding
141
+ class AudioLanguageModelWrapper(nn.Module):
142
+ def __init__(
143
+ self,
144
+ pretransform: Pretransform,
145
+ lm: AudioLanguageModel,
146
+ sample_rate: int,
147
+ min_input_length: int,
148
+ conditioner: MultiConditioner = None,
149
+ cross_attn_cond_ids: tp.List[str] = [],
150
+ prepend_cond_ids: tp.List[str] = [],
151
+ global_cond_ids: tp.List[str] = []
152
+ ):
153
+ super().__init__()
154
+
155
+ assert pretransform.is_discrete, "Pretransform must be discrete"
156
+ self.pretransform = pretransform
157
+
158
+ self.pretransform.requires_grad_(False)
159
+ self.pretransform.eval()
160
+
161
+ if isinstance(self.pretransform, AutoencoderPretransform):
162
+ self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
163
+ self.codebook_size = self.pretransform.model.bottleneck.codebook_size
164
+ elif isinstance(self.pretransform, PretrainedDACPretransform):
165
+ self.num_quantizers = self.pretransform.model.num_quantizers
166
+ self.codebook_size = self.pretransform.model.codebook_size
167
+ elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
168
+ self.num_quantizers = self.pretransform.num_quantizers
169
+ self.codebook_size = self.pretransform.codebook_size
170
+ else:
171
+ raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")
172
+
173
+ self.conditioner = conditioner
174
+
175
+ self.lm = lm
176
+
177
+ self.sample_rate = sample_rate
178
+ self.min_input_length = min_input_length
179
+
180
+ self.cross_attn_cond_ids = cross_attn_cond_ids
181
+ self.prepend_cond_ids = prepend_cond_ids
182
+ self.global_cond_ids = global_cond_ids
183
+
184
+ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
185
+ cross_attention_input = None
186
+ prepend_cond = None
187
+ prepend_cond_mask = None
188
+ global_cond = None
189
+
190
+ if len(self.cross_attn_cond_ids) > 0:
191
+ # Concatenate all cross-attention inputs over the sequence dimension
192
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
193
+ cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)
194
+
195
+ if len(self.prepend_cond_ids) > 0:
196
+ # Concatenate all prepend conditioning inputs over the sequence dimension
197
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
198
+ prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
199
+ prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
200
+
201
+ if len(self.global_cond_ids) > 0:
202
+ # Concatenate all global conditioning inputs over the channel dimension
203
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
204
+ global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
205
+ if len(global_cond.shape) == 3:
206
+ global_cond = global_cond.squeeze(1)
207
+
208
+ if negative:
209
+ return {
210
+ "negative_cross_attn_cond": cross_attention_input,
211
+ "negative_prepend_cond": prepend_cond,
212
+ "negative_prepend_cond_mask": prepend_cond_mask,
213
+ "negative_global_cond": global_cond
214
+ }
215
+ else:
216
+ return {
217
+ "cross_attn_cond": cross_attention_input,
218
+ "prepend_cond": prepend_cond,
219
+ "prepend_cond_mask": prepend_cond_mask,
220
+ "global_cond": global_cond
221
+ }
222
+
223
+ def compute_logits(
224
+ self,
225
+ codes,
226
+ condition_tensors=None,
227
+ cfg_dropout_prob=0.0,
228
+ **kwargs
229
+ ):
230
+ """
231
+ Compute logits for a batch of codes, and translates from conditioning inputs to model inputs
232
+ Handles CFG dropout
233
+ """
234
+
235
+ if condition_tensors is None:
236
+ condition_tensors = {}
237
+
238
+ conditioning_inputs = self.get_conditioning_inputs(condition_tensors)
239
+
240
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
241
+ prepend_cond = conditioning_inputs["prepend_cond"]
242
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
243
+ global_cond = conditioning_inputs["global_cond"]
244
+
245
+ if cfg_dropout_prob > 0.0:
246
+ if cross_attn_cond is not None:
247
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
248
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
249
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
250
+
251
+ if prepend_cond is not None:
252
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
253
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
254
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
255
+
256
+ if global_cond is not None:
257
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
258
+ dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
259
+ global_cond = torch.where(dropout_mask, null_embed, global_cond)
260
+
261
+ return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
262
+
263
+ def _sample_next_token(
264
+ self,
265
+ sequence, #[batch, num_quantizers, seq_len]
266
+ conditioning_tensors=None,
267
+ cross_attn_use_cfg=True,
268
+ prepend_use_cfg=True,
269
+ global_use_cfg=True,
270
+ cfg_scale=1.0,
271
+ top_k=250,
272
+ top_p=0.0,
273
+ temp=1.0,
274
+ **kwargs
275
+ ):
276
+ """
277
+ Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs
278
+ Handles CFG inference
279
+ """
280
+
281
+ if conditioning_tensors is None:
282
+ conditioning_tensors = {}
283
+
284
+ conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)
285
+
286
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
287
+ prepend_cond = conditioning_inputs["prepend_cond"]
288
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
289
+ global_cond = conditioning_inputs["global_cond"]
290
+
291
+ if cfg_scale != 1.0:
292
+
293
+ # Batch size is doubled to account for negative samples
294
+ sequence = torch.cat([sequence, sequence], dim=0)
295
+
296
+ if cross_attn_cond is not None and cross_attn_use_cfg:
297
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
298
+
299
+ cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
300
+
301
+ if prepend_cond is not None and prepend_use_cfg:
302
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
303
+
304
+ prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
305
+
306
+ if prepend_cond_mask is not None:
307
+ prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
308
+
309
+ if global_cond is not None and global_use_cfg:
310
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
311
+
312
+ global_cond = torch.cat([global_cond, null_embed], dim=0)
313
+
314
+ logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
315
+
316
+ if cfg_scale != 1.0:
317
+ cond_logits, uncond_logits = logits.chunk(2, dim=0)
318
+
319
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
320
+
321
+ logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len]
322
+
323
+ # Grab the logits for the last step
324
+ logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size]
325
+
326
+ # Apply top-k or top-p sampling
327
+
328
+ if temp > 0:
329
+ probs = torch.softmax(logits / temp, dim=-1)
330
+
331
+ if top_p > 0.0:
332
+ next_token = sample_top_p(probs, p=top_p)
333
+ elif top_k > 0:
334
+ next_token = sample_top_k(probs, k=top_k)
335
+ else:
336
+ next_token = multinomial(probs, num_samples=1)
337
+
338
+ else:
339
+ next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1]
340
+
341
+ return next_token
342
+
343
+ @torch.no_grad()
344
+ def generate(
345
+ self,
346
+ max_gen_len: int = 256,
347
+ batch_size: tp.Optional[int] = None,
348
+ init_data: tp.Optional[torch.Tensor] = None,
349
+ conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
350
+ conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
351
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
352
+ use_cache: bool = True,
353
+ cfg_scale: float = 1.0,
354
+ **kwargs
355
+ ):
356
+ device = next(self.parameters()).device
357
+
358
+ if conditioning_tensors is None and conditioning is not None:
359
+ # Convert conditioning inputs to conditioning tensors
360
+ conditioning_tensors = self.conditioner(conditioning, device)
361
+
362
+ # Check that batch size is consistent across inputs
363
+ possible_batch_sizes = []
364
+
365
+ if batch_size is not None:
366
+ possible_batch_sizes.append(batch_size)
367
+ elif init_data is not None:
368
+ possible_batch_sizes.append(init_data.shape[0])
369
+ elif conditioning_tensors is not None:
370
+ # Assume that the first conditioning tensor has the batch dimension
371
+ possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
372
+ else:
373
+ possible_batch_sizes.append(1)
374
+
375
+ assert all(x == possible_batch_sizes[0] for x in possible_batch_sizes), "Batch size must be consistent across inputs"
376
+
377
+ batch_size = possible_batch_sizes[0]
378
+
379
+ if init_data is None:
380
+ # Initialize with zeros
381
+ assert batch_size > 0
382
+ init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)
383
+
384
+ batch_size, num_quantizers, seq_len = init_data.shape
385
+
386
+ start_offset = seq_len
387
+ assert start_offset < max_gen_len, "init data longer than max gen length"
388
+
389
+ pattern = self.lm.pattern_provider.get_pattern(max_gen_len)
390
+
391
+ unknown_token = -1
392
+
393
+ # Initialize the generated codes with the init data, padded with unknown tokens
394
+ gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
395
+ gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len]
396
+
397
+ gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len]
398
+
399
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
400
+ assert start_offset_sequence is not None
401
+
402
+ # Generation
403
+ prev_offset = 0
404
+ gen_sequence_len = gen_sequence.shape[-1]
405
+
406
+ # Reset generation cache
407
+ if use_cache and self.lm.backbone.use_generation_cache:
408
+ self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)
409
+
410
+ for offset in trange(start_offset_sequence, gen_sequence_len):
411
+
412
+ # Get the full sequence up to the current offset
413
+ curr_sequence = gen_sequence[..., prev_offset:offset]
414
+
415
+ next_token = self._sample_next_token(
416
+ curr_sequence,
417
+ conditioning_tensors=conditioning_tensors,
418
+ use_cache=use_cache,
419
+ cfg_scale=cfg_scale,
420
+ **kwargs
421
+ )
422
+
423
+ valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
424
+ next_token[~valid_mask] = self.lm.masked_token_id
425
+
426
+ # Update the generated sequence with the next token
427
+ gen_sequence[..., offset:offset+1] = torch.where(
428
+ gen_sequence[..., offset:offset+1] == unknown_token,
429
+ next_token,
430
+ gen_sequence[..., offset:offset+1]
431
+ )
432
+
433
+ if use_cache and self.lm.backbone.use_generation_cache:
434
+ # Only update the offset if caching is being used
435
+ prev_offset = offset
436
+
437
+ self.lm.backbone.update_generation_cache(offset)
438
+
439
+ if callback is not None:
440
+ # Callback to report progress
441
+ # Pass in the offset relative to the start of the sequence, and the length of the current sequence
442
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
443
+
444
+ assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"
445
+
446
+ out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
447
+
448
+ # sanity checks over the returned codes and corresponding masks
449
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
450
+ assert (out_mask[..., :max_gen_len] == 1).all()
451
+
452
+ #out_codes = out_codes[..., 0:max_gen_len]
453
+
454
+ return out_codes
455
+
456
+
457
+ def generate_audio(
458
+ self,
459
+ **kwargs
460
+ ):
461
+ """
462
+ Generate audio from a batch of codes
463
+ """
464
+
465
+ codes = self.generate(**kwargs)
466
+
467
+ audio = self.pretransform.decode_tokens(codes)
468
+
469
+ return audio
470
+
471
+
472
+ def create_audio_lm_from_config(config):
473
+ model_config = config.get('model', None)
474
+ assert model_config is not None, 'model config must be specified in config'
475
+
476
+ sample_rate = config.get('sample_rate', None)
477
+ assert sample_rate is not None, "Must specify sample_rate in config"
478
+
479
+ lm_config = model_config.get('lm', None)
480
+ assert lm_config is not None, 'lm config must be specified in model config'
481
+
482
+ codebook_pattern = lm_config.get("codebook_pattern", "delay")
483
+
484
+ pattern_providers = {
485
+ 'parallel': ParallelPatternProvider,
486
+ 'delay': DelayedPatternProvider,
487
+ 'unroll': UnrolledPatternProvider,
488
+ 'musiclm': MusicLMPattern,
489
+ }
490
+
491
+ pretransform_config = model_config.get("pretransform", None)
492
+
493
+ pretransform = create_pretransform_from_config(pretransform_config, sample_rate)
494
+
495
+ assert pretransform.is_discrete, "Pretransform must be discrete"
496
+
497
+ min_input_length = pretransform.downsampling_ratio
498
+
499
+ pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)
500
+
501
+ conditioning_config = model_config.get('conditioning', None)
502
+
503
+ conditioner = None
504
+ if conditioning_config is not None:
505
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
506
+
507
+ cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
508
+ prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
509
+ global_cond_ids = lm_config.get('global_cond_ids', [])
510
+
511
+ lm_type = lm_config.get("type", None)
512
+ lm_model_config = lm_config.get("config", None)
513
+
514
+ assert lm_type is not None, "Must specify lm type in lm config"
515
+ assert lm_model_config is not None, "Must specify lm model config in lm config"
516
+
517
+ if lm_type == "x-transformers":
518
+ backbone = XTransformersAudioLMBackbone(**lm_model_config)
519
+ elif lm_type == "continuous_transformer":
520
+ backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
521
+ else:
522
+ raise NotImplementedError(f"Unrecognized lm type {lm_type}")
523
+
524
+ lm = AudioLanguageModel(
525
+ pattern_provider=pattern_provider,
526
+ backbone=backbone,
527
+ num_quantizers=pretransform.num_quantizers,
528
+ codebook_size=pretransform.codebook_size
529
+ )
530
+
531
+ model = AudioLanguageModelWrapper(
532
+ pretransform=pretransform,
533
+ lm=lm,
534
+ conditioner=conditioner,
535
+ sample_rate=sample_rate,
536
+ min_input_length=min_input_length,
537
+ cross_attn_cond_ids=cross_attn_cond_ids,
538
+ prepend_cond_ids=prepend_cond_ids,
539
+ global_cond_ids=global_cond_ids
540
+ )
541
+
542
+ return model
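
To show how the wrapper and factory above fit together, here is a hedged generation sketch. The config path, the conditioning keys ("prompt", "seconds_start", "seconds_total") and the sampling settings are assumptions; the conditioner section of the actual config determines which metadata keys are expected.

import json
import torch
from stable_audio_tools.models.lm import create_audio_lm_from_config

with open("lm_config.json") as f:             # placeholder path
    config = json.load(f)

model = create_audio_lm_from_config(config).to("cuda").eval()

conditioning = [{
    "prompt": "warm analog synth arpeggio",   # keys depend on the configured conditioner
    "seconds_start": 0,
    "seconds_total": 10,
}]

audio = model.generate_audio(
    max_gen_len=512,            # number of token steps to sample autoregressively
    conditioning=conditioning,
    cfg_scale=3.0,              # >1.0 doubles the batch for classifier-free guidance
    top_k=250,
    temp=1.0,
)
print(audio.shape)              # (batch, channels, samples) after pretransform.decode_tokens(...)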
stable_audio_tools/models/local_attention.py ADDED
@@ -0,0 +1,278 @@
1
+ import torch
2
+
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+ from .blocks import AdaRMSNorm
7
+ from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
8
+
9
+ def checkpoint(function, *args, **kwargs):
10
+ kwargs.setdefault("use_reentrant", False)
11
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
12
+
13
+ # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
14
+ class ContinuousLocalTransformer(nn.Module):
15
+ def __init__(
16
+ self,
17
+ *,
18
+ dim,
19
+ depth,
20
+ dim_in = None,
21
+ dim_out = None,
22
+ causal = False,
23
+ local_attn_window_size = 64,
24
+ heads = 8,
25
+ ff_mult = 2,
26
+ cond_dim = 0,
27
+ cross_attn_cond_dim = 0,
28
+ **kwargs
29
+ ):
30
+ super().__init__()
31
+
32
+ dim_head = dim//heads
33
+
34
+ self.layers = nn.ModuleList([])
35
+
36
+ self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
37
+
38
+ self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
39
+
40
+ self.local_attn_window_size = local_attn_window_size
41
+
42
+ self.cond_dim = cond_dim
43
+
44
+ self.cross_attn_cond_dim = cross_attn_cond_dim
45
+
46
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
47
+
48
+ for _ in range(depth):
49
+
50
+ self.layers.append(nn.ModuleList([
51
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
52
+ Attention(
53
+ dim=dim,
54
+ dim_heads=dim_head,
55
+ causal=causal,
56
+ zero_init_output=True,
57
+ natten_kernel_size=local_attn_window_size,
58
+ ),
59
+ Attention(
60
+ dim=dim,
61
+ dim_heads=dim_head,
62
+ dim_context = cross_attn_cond_dim,
63
+ zero_init_output=True
64
+ ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
65
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
66
+ FeedForward(dim = dim, mult = ff_mult, no_bias=True)
67
+ ]))
68
+
69
+ def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
70
+
71
+ x = checkpoint(self.project_in, x)
72
+
73
+ if prepend_cond is not None:
74
+ x = torch.cat([prepend_cond, x], dim=1)
75
+
76
+ pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
77
+
78
+ for attn_norm, attn, xattn, ff_norm, ff in self.layers:
79
+
80
+ residual = x
81
+ if cond is not None:
82
+ x = checkpoint(attn_norm, x, cond)
83
+ else:
84
+ x = checkpoint(attn_norm, x)
85
+
86
+ x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
87
+
88
+ if cross_attn_cond is not None:
89
+ x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
90
+
91
+ residual = x
92
+
93
+ if cond is not None:
94
+ x = checkpoint(ff_norm, x, cond)
95
+ else:
96
+ x = checkpoint(ff_norm, x)
97
+
98
+ x = checkpoint(ff, x) + residual
99
+
100
+ return checkpoint(self.project_out, x)
101
+
102
+ class TransformerDownsampleBlock1D(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_channels,
106
+ embed_dim = 768,
107
+ depth = 3,
108
+ heads = 12,
109
+ downsample_ratio = 2,
110
+ local_attn_window_size = 64,
111
+ **kwargs
112
+ ):
113
+ super().__init__()
114
+
115
+ self.downsample_ratio = downsample_ratio
116
+
117
+ self.transformer = ContinuousLocalTransformer(
118
+ dim=embed_dim,
119
+ depth=depth,
120
+ heads=heads,
121
+ local_attn_window_size=local_attn_window_size,
122
+ **kwargs
123
+ )
124
+
125
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
126
+
127
+ self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
128
+
129
+
130
+ def forward(self, x):
131
+
132
+ x = checkpoint(self.project_in, x)
133
+
134
+ # Compute
135
+ x = self.transformer(x)
136
+
137
+ # Trade sequence length for channels
138
+ x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
139
+
140
+ # Project back to embed dim
141
+ x = checkpoint(self.project_down, x)
142
+
143
+ return x
144
+
145
+ class TransformerUpsampleBlock1D(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels,
149
+ embed_dim,
150
+ depth = 3,
151
+ heads = 12,
152
+ upsample_ratio = 2,
153
+ local_attn_window_size = 64,
154
+ **kwargs
155
+ ):
156
+ super().__init__()
157
+
158
+ self.upsample_ratio = upsample_ratio
159
+
160
+ self.transformer = ContinuousLocalTransformer(
161
+ dim=embed_dim,
162
+ depth=depth,
163
+ heads=heads,
164
+ local_attn_window_size = local_attn_window_size,
165
+ **kwargs
166
+ )
167
+
168
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
169
+
170
+ self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
171
+
172
+ def forward(self, x):
173
+
174
+ # Project to embed dim
175
+ x = checkpoint(self.project_in, x)
176
+
177
+ # Project to increase channel dim
178
+ x = checkpoint(self.project_up, x)
179
+
180
+ # Trade channels for sequence length
181
+ x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
182
+
183
+ # Compute
184
+ x = self.transformer(x)
185
+
186
+ return x
187
+
188
+
189
+ class TransformerEncoder1D(nn.Module):
190
+ def __init__(
191
+ self,
192
+ in_channels,
193
+ out_channels,
194
+ embed_dims = [96, 192, 384, 768],
195
+ heads = [12, 12, 12, 12],
196
+ depths = [3, 3, 3, 3],
197
+ ratios = [2, 2, 2, 2],
198
+ local_attn_window_size = 64,
199
+ **kwargs
200
+ ):
201
+ super().__init__()
202
+
203
+ layers = []
204
+
205
+ for layer in range(len(depths)):
206
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
207
+
208
+ layers.append(
209
+ TransformerDownsampleBlock1D(
210
+ in_channels = prev_dim,
211
+ embed_dim = embed_dims[layer],
212
+ heads = heads[layer],
213
+ depth = depths[layer],
214
+ downsample_ratio = ratios[layer],
215
+ local_attn_window_size = local_attn_window_size,
216
+ **kwargs
217
+ )
218
+ )
219
+
220
+ self.layers = nn.Sequential(*layers)
221
+
222
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
223
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
224
+
225
+ def forward(self, x):
226
+ x = rearrange(x, "b c n -> b n c")
227
+ x = checkpoint(self.project_in, x)
228
+ x = self.layers(x)
229
+ x = checkpoint(self.project_out, x)
230
+ x = rearrange(x, "b n c -> b c n")
231
+
232
+ return x
233
+
234
+
235
+ class TransformerDecoder1D(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels,
239
+ out_channels,
240
+ embed_dims = [768, 384, 192, 96],
241
+ heads = [12, 12, 12, 12],
242
+ depths = [3, 3, 3, 3],
243
+ ratios = [2, 2, 2, 2],
244
+ local_attn_window_size = 64,
245
+ **kwargs
246
+ ):
247
+
248
+ super().__init__()
249
+
250
+ layers = []
251
+
252
+ for layer in range(len(depths)):
253
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
254
+
255
+ layers.append(
256
+ TransformerUpsampleBlock1D(
257
+ in_channels = prev_dim,
258
+ embed_dim = embed_dims[layer],
259
+ heads = heads[layer],
260
+ depth = depths[layer],
261
+ upsample_ratio = ratios[layer],
262
+ local_attn_window_size = local_attn_window_size,
263
+ **kwargs
264
+ )
265
+ )
266
+
267
+ self.layers = nn.Sequential(*layers)
268
+
269
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
270
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
271
+
272
+ def forward(self, x):
273
+ x = rearrange(x, "b c n -> b n c")
274
+ x = checkpoint(self.project_in, x)
275
+ x = self.layers(x)
276
+ x = checkpoint(self.project_out, x)
277
+ x = rearrange(x, "b n c -> b c n")
278
+ return x
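
For intuition about the encoder/decoder pair above: each stage trades sequence length for channels by its ratio, so the default four stages with ratio 2 shorten the sequence by 16x and the decoder undoes that. The sketch below mainly illustrates the expected shapes; the channel counts are placeholders, and actually running it depends on the attention backend implemented in transformer.py.

import torch
from stable_audio_tools.models.local_attention import TransformerEncoder1D, TransformerDecoder1D

enc = TransformerEncoder1D(in_channels=2, out_channels=32)   # defaults: 4 stages, ratio 2 each
dec = TransformerDecoder1D(in_channels=32, out_channels=2)

x = torch.randn(1, 2, 1024)      # (batch, channels, sequence); length must divide by 2**4
latents = enc(x)                 # -> (1, 32, 64)
recon = dec(latents)             # -> (1, 2, 1024)
print(latents.shape, recon.shape)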
stable_audio_tools/models/pqmf.py ADDED
@@ -0,0 +1,393 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from scipy.optimize import fmin
7
+ from scipy.signal import firwin, kaiserord
8
+
9
+ class PQMF(nn.Module):
10
+ """
11
+ Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
12
+ Uses a polyphase representation, which is computationally more efficient for real-time processing.
13
+
14
+ Parameters:
15
+ - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
16
+ - num_bands (int): Number of desired frequency bands. It must be a power of 2.
17
+ """
18
+
19
+ def __init__(self, attenuation, num_bands):
20
+ super(PQMF, self).__init__()
21
+
22
+ # Ensure num_bands is a power of 2
23
+ is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
24
+ assert is_power_of_2, "'num_bands' must be a power of 2."
25
+
26
+ # Create the prototype filter
27
+ prototype_filter = design_prototype_filter(attenuation, num_bands)
28
+ filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
29
+ padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
30
+
31
+ # Register filters and settings
32
+ self.register_buffer("filter_bank", padded_filter_bank)
33
+ self.register_buffer("prototype", prototype_filter)
34
+ self.num_bands = num_bands
35
+
36
+ def forward(self, signal):
37
+ """Decompose the signal into multiple frequency bands."""
38
+ # If signal is not a pytorch tensor of Batch x Channels x Length, convert it
39
+ signal = prepare_signal_dimensions(signal)
40
+ # The signal length must be a multiple of num_bands. Pad it with zeros.
41
+ signal = pad_signal(signal, self.num_bands)
42
+ # run it
43
+ signal = polyphase_analysis(signal, self.filter_bank)
44
+ return apply_alias_cancellation(signal)
45
+
46
+ def inverse(self, bands):
47
+ """Reconstruct the original signal from the frequency bands."""
48
+ bands = apply_alias_cancellation(bands)
49
+ return polyphase_synthesis(bands, self.filter_bank)
50
+
51
+
52
+ def prepare_signal_dimensions(signal):
53
+ """
54
+ Rearrange signal into Batch x Channels x Length.
55
+
56
+ Parameters
57
+ ----------
58
+ signal : torch.Tensor or numpy.ndarray
59
+ The input signal.
60
+
61
+ Returns
62
+ -------
63
+ torch.Tensor
64
+ Preprocessed signal tensor.
65
+ """
66
+ # Convert numpy to torch tensor
67
+ if isinstance(signal, np.ndarray):
68
+ signal = torch.from_numpy(signal)
69
+
70
+ # Ensure tensor
71
+ if not isinstance(signal, torch.Tensor):
72
+ raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
73
+
74
+ # Modify dimension of signal to Batch x Channels x Length
75
+ if signal.dim() == 1:
76
+ # This is just a mono signal. Unsqueeze to 1 x 1 x Length
77
+ signal = signal.unsqueeze(0).unsqueeze(0)
78
+ elif signal.dim() == 2:
79
+ # This is a multi-channel signal (e.g. stereo)
80
+ # Rearrange so that larger dimension (Length) is last
81
+ if signal.shape[0] > signal.shape[1]:
82
+ signal = signal.T
83
+ # Unsqueeze to 1 x Channels x Length
84
+ signal = signal.unsqueeze(0)
85
+ return signal
86
+
87
+ def pad_signal(signal, num_bands):
88
+ """
89
+ Pads the signal to make its length divisible by the given number of bands.
90
+
91
+ Parameters
92
+ ----------
93
+ signal : torch.Tensor
94
+ The input signal tensor, where the last dimension represents the signal length.
95
+
96
+ num_bands : int
97
+ The number of bands by which the signal length should be divisible.
98
+
99
+ Returns
100
+ -------
101
+ torch.Tensor
102
+ The padded signal tensor. If the original signal length was already divisible
103
+ by num_bands, returns the original signal unchanged.
104
+ """
105
+ remainder = signal.shape[-1] % num_bands
106
+ if remainder > 0:
107
+ padding_size = num_bands - remainder
108
+ signal = nn.functional.pad(signal, (0, padding_size))
109
+ return signal
110
+
111
+ def generate_modulated_filter_bank(prototype_filter, num_bands):
112
+ """
113
+ Generate a QMF bank of cosine modulated filters based on a given prototype filter.
114
+
115
+ Parameters
116
+ ----------
117
+ prototype_filter : torch.Tensor
118
+ The prototype filter used as the basis for modulation.
119
+ num_bands : int
120
+ The number of desired subbands or filters.
121
+
122
+ Returns
123
+ -------
124
+ torch.Tensor
125
+ A bank of cosine modulated filters.
126
+ """
127
+
128
+ # Initialize indices for modulation.
129
+ subband_indices = torch.arange(num_bands).reshape(-1, 1)
130
+
131
+ # Calculate the length of the prototype filter.
132
+ filter_length = prototype_filter.shape[-1]
133
+
134
+ # Generate symmetric time indices centered around zero.
135
+ time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
136
+
137
+ # Calculate phase offsets to ensure orthogonality between subbands.
138
+ phase_offsets = (-1)**subband_indices * np.pi / 4
139
+
140
+ # Compute the cosine modulation function.
141
+ modulation = torch.cos(
142
+ (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
143
+ )
144
+
145
+ # Apply modulation to the prototype filter.
146
+ modulated_filters = 2 * prototype_filter * modulation
147
+
148
+ return modulated_filters
149
+
150
+
151
+ def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
152
+ """
153
+ Design a lowpass filter using the Kaiser window.
154
+
155
+ Parameters
156
+ ----------
157
+ angular_cutoff : float
158
+ The angular frequency cutoff of the filter.
159
+ attenuation : float
160
+ The desired stopband attenuation in decibels (dB).
161
+ filter_length : int, optional
162
+ Desired length of the filter. If not provided, it's computed based on the given specs.
163
+
164
+ Returns
165
+ -------
166
+ ndarray
167
+ The designed lowpass filter coefficients.
168
+ """
169
+
170
+ estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
171
+
172
+ # Ensure the estimated length is odd.
173
+ estimated_length = 2 * (estimated_length // 2) + 1
174
+
175
+ if filter_length is None:
176
+ filter_length = estimated_length
177
+
178
+ return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi)
179
+
180
+
181
+ def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
182
+ """
183
+ Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
184
+
185
+ Parameters
186
+ ----------
187
+ angular_cutoff : float
188
+ Angular frequency cutoff of the filter.
189
+ attenuation : float
190
+ Desired stopband attenuation in dB.
191
+ num_bands : int
192
+ Number of bands for the multiband filter system.
193
+ filter_length : int, optional
194
+ Desired length of the filter.
195
+
196
+ Returns
197
+ -------
198
+ float
199
+ The computed objective (loss) value for the given filter specs.
200
+ """
201
+
202
+ filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
203
+ convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
204
+
205
+ return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
206
+
207
+
208
+ def design_prototype_filter(attenuation, num_bands, filter_length=None):
209
+ """
210
+ Design the optimal prototype filter for a multiband system given the desired specs.
211
+
212
+ Parameters
213
+ ----------
214
+ attenuation : float
215
+ The desired stopband attenuation in dB.
216
+ num_bands : int
217
+ Number of bands for the multiband filter system.
218
+ filter_length : int, optional
219
+ Desired length of the filter. If not provided, it's computed based on the given specs.
220
+
221
+ Returns
222
+ -------
223
+ ndarray
224
+ The optimal prototype filter coefficients.
225
+ """
226
+
227
+ optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
228
+ 1 / num_bands, disp=0)[0]
229
+
230
+ prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
231
+ return torch.tensor(prototype_filter, dtype=torch.float32)
232
+
233
+ def pad_to_nearest_power_of_two(x):
234
+ """
235
+ Pads the input tensor 'x' on both sides such that its last dimension
236
+ becomes the nearest larger power of two.
237
+
238
+ Parameters:
239
+ -----------
240
+ x : torch.Tensor
241
+ The input tensor to be padded.
242
+
243
+ Returns:
244
+ --------
245
+ torch.Tensor
246
+ The padded tensor.
247
+ """
248
+ current_length = x.shape[-1]
249
+ target_length = 2**math.ceil(math.log2(current_length))
250
+
251
+ total_padding = target_length - current_length
252
+ left_padding = total_padding // 2
253
+ right_padding = total_padding - left_padding
254
+
255
+ return nn.functional.pad(x, (left_padding, right_padding))
256
+
257
+ def apply_alias_cancellation(x):
258
+ """
259
+ Applies alias cancellation by inverting the sign of every
260
+ second element of every second row, starting from the second
261
+ row's first element in a tensor.
262
+
263
+ This operation helps ensure that the aliasing introduced in
264
+ each band during the decomposition will be counteracted during
265
+ the reconstruction.
266
+
267
+ Parameters:
268
+ -----------
269
+ x : torch.Tensor
270
+ The input tensor.
271
+
272
+ Returns:
273
+ --------
274
+ torch.Tensor
275
+ Tensor with specific elements' sign inverted for alias cancellation.
276
+ """
277
+
278
+ # Create a mask of the same shape as 'x', initialized with all ones
279
+ mask = torch.ones_like(x)
280
+
281
+ # Update specific elements in the mask to -1 to perform inversion
282
+ mask[..., 1::2, ::2] = -1
283
+
284
+ # Apply the mask to the input tensor 'x'
285
+ return x * mask
286
+
287
+ def ensure_odd_length(tensor):
288
+ """
289
+ Pads the last dimension of a tensor to ensure its size is odd.
290
+
291
+ Parameters:
292
+ -----------
293
+ tensor : torch.Tensor
294
+ Input tensor whose last dimension might need padding.
295
+
296
+ Returns:
297
+ --------
298
+ torch.Tensor
299
+ The original tensor if its last dimension was already odd,
300
+ or the padded tensor with an odd-sized last dimension.
301
+ """
302
+
303
+ last_dim_size = tensor.shape[-1]
304
+
305
+ if last_dim_size % 2 == 0:
306
+ tensor = nn.functional.pad(tensor, (0, 1))
307
+
308
+ return tensor
309
+
310
+ def polyphase_analysis(signal, filter_bank):
311
+ """
312
+ Applies the polyphase method to efficiently analyze the signal using a filter bank.
313
+
314
+ Parameters:
315
+ -----------
316
+ signal : torch.Tensor
317
+ Input signal tensor with shape (Batch x Channels x Length).
318
+
319
+ filter_bank : torch.Tensor
320
+ Filter bank tensor with shape (Bands x Length).
321
+
322
+ Returns:
323
+ --------
324
+ torch.Tensor
325
+ Signal split into sub-bands. (Batch x Channels x Bands x Length)
326
+ """
327
+
328
+ num_bands = filter_bank.shape[0]
329
+ num_channels = signal.shape[1]
330
+
331
+ # Rearrange signal for polyphase processing.
332
+ # Also combine Batch x Channel into one dimension for now.
333
+ #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
334
+ signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
335
+
336
+ # Rearrange the filter bank for matching signal shape
337
+ filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
338
+
339
+ # Apply convolution with appropriate padding to maintain spatial dimensions
340
+ padding = filter_bank.shape[-1] // 2
341
+ filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
342
+
343
+ # Truncate the last dimension post-convolution to adjust the output shape
344
+ filtered_signal = filtered_signal[..., :-1]
345
+ # Rearrange the first dimension back into Batch x Channels
346
+ filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
347
+
348
+ return filtered_signal
349
+
350
+ def polyphase_synthesis(signal, filter_bank):
351
+ """
352
+ Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
353
+
354
+ Parameters
355
+ ----------
356
+ signal : torch.Tensor
357
+ Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
358
+
359
+ filter_bank : torch.Tensor
360
+ Analysis filter bank (shape: Bands x Length).
361
+
365
+ Returns
366
+ -------
367
+ torch.Tensor
368
+ Reconstructed signal (shape: Batch x Channels x Length).
369
+ """
370
+
371
+ num_bands = filter_bank.shape[0]
372
+ num_channels = signal.shape[1]
373
+
374
+ # Rearrange the filter bank
375
+ filter_bank = filter_bank.flip(-1)
376
+ filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
377
+
378
+ # Combine Batch x Channels into one dimension for now.
379
+ signal = rearrange(signal, "b c n t -> (b c) n t")
380
+
381
+ # Apply convolution with appropriate padding
382
+ padding_amount = filter_bank.shape[-1] // 2 + 1
383
+ reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
384
+
385
+ # Scale the result
386
+ reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
387
+
388
+ # Reorganize the output and truncate
389
+ reconstructed_signal = reconstructed_signal.flip(1)
390
+ reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
391
+ reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
392
+
393
+ return reconstructed_signal
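
The helpers above chain together: design_prototype_filter runs a scalar search (scipy's fmin) over the angular cutoff, scoring each candidate with evaluate_filter_objective, and returns the optimized Kaiser-windowed prototype as a float32 tensor. A minimal sketch of calling them directly, assuming they are module-level functions in stable_audio_tools.models.pqmf and that the package is importable (both are assumptions, since the flattened diff view hides indentation):

import torch
from stable_audio_tools.models.pqmf import design_prototype_filter, evaluate_filter_objective

num_bands = 16
attenuation = 100  # desired stopband attenuation in dB

# Odd-length lowpass prototype optimized for the chosen band count.
prototype = design_prototype_filter(attenuation, num_bands)
print(prototype.shape, prototype.dtype)  # 1D torch.float32 filter

# The quantity the optimizer minimized: the peak aliasing term of the
# prototype's autocorrelation sampled every 2 * num_bands taps.
score = evaluate_filter_objective(1.0 / num_bands, attenuation, num_bands, filter_length=None)
print(float(score))
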
stable_audio_tools/models/pretrained.py ADDED
@@ -0,0 +1,25 @@
1
+ import json
2
+
3
+ from .factory import create_model_from_config
4
+ from .utils import load_ckpt_state_dict
5
+
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ def get_pretrained_model(name: str):
9
+
10
+ model_config_path = hf_hub_download(name, filename="config.json", repo_type='model')
11
+
12
+ with open(model_config_path) as f:
13
+ model_config = json.load(f)
14
+
15
+ model = create_model_from_config(model_config)
16
+
17
+ # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
18
+ try:
19
+ model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model')
20
+ except Exception as e:
21
+ model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model')
22
+
23
+ model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
24
+
25
+ return model, model_config
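
Usage is a one-liner; the repo id below is a placeholder rather than a reference to a real checkpoint, and the target repo is expected to contain config.json plus model.safetensors (or model.ckpt), exactly as the function above downloads them:

from stable_audio_tools.models.pretrained import get_pretrained_model

# Hypothetical repo id, shown only to illustrate the call.
model, model_config = get_pretrained_model("some-user/some-audio-model")
model = model.eval()
print(type(model).__name__, sorted(model_config.keys())[:5])
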
stable_audio_tools/models/pretransforms.py ADDED
@@ -0,0 +1,258 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+
5
+ class Pretransform(nn.Module):
6
+ def __init__(self, enable_grad, io_channels, is_discrete):
7
+ super().__init__()
8
+
9
+ self.is_discrete = is_discrete
10
+ self.io_channels = io_channels
11
+ self.encoded_channels = None
12
+ self.downsampling_ratio = None
13
+
14
+ self.enable_grad = enable_grad
15
+
16
+ def encode(self, x):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, z):
20
+ raise NotImplementedError
21
+
22
+ def tokenize(self, x):
23
+ raise NotImplementedError
24
+
25
+ def decode_tokens(self, tokens):
26
+ raise NotImplementedError
27
+
28
+ class AutoencoderPretransform(Pretransform):
29
+ def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
30
+ super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
31
+ self.model = model
32
+ self.model.requires_grad_(False).eval()
33
+ self.scale=scale
34
+ self.downsampling_ratio = model.downsampling_ratio
35
+ self.io_channels = model.io_channels
36
+ self.sample_rate = model.sample_rate
37
+
38
+ self.model_half = model_half
39
+ self.iterate_batch = iterate_batch
40
+
41
+ self.encoded_channels = model.latent_dim
42
+
43
+ self.chunked = chunked
44
+ self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
45
+ self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
46
+
47
+ if self.model_half:
48
+ self.model.half()
49
+
50
+ def encode(self, x, **kwargs):
51
+
52
+ if self.model_half:
53
+ x = x.half()
54
+ self.model.to(torch.float16)
55
+
56
+ encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
57
+
58
+ if self.model_half:
59
+ encoded = encoded.float()
60
+
61
+ return encoded / self.scale
62
+
63
+ def decode(self, z, **kwargs):
64
+ z = z * self.scale
65
+
66
+ if self.model_half:
67
+ z = z.half()
68
+ self.model.to(torch.float16)
69
+
70
+ decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
71
+
72
+ if self.model_half:
73
+ decoded = decoded.float()
74
+
75
+ return decoded
76
+
77
+ def tokenize(self, x, **kwargs):
78
+ assert self.model.is_discrete, "Cannot tokenize with a continuous model"
79
+
80
+ _, info = self.model.encode(x, return_info = True, **kwargs)
81
+
82
+ return info[self.model.bottleneck.tokens_id]
83
+
84
+ def decode_tokens(self, tokens, **kwargs):
85
+ assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
86
+
87
+ return self.model.decode_tokens(tokens, **kwargs)
88
+
89
+ def load_state_dict(self, state_dict, strict=True):
90
+ self.model.load_state_dict(state_dict, strict=strict)
91
+
92
+ class WaveletPretransform(Pretransform):
93
+ def __init__(self, channels, levels, wavelet):
94
+ super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
95
+
96
+ from .wavelets import WaveletEncode1d, WaveletDecode1d
97
+
98
+ self.encoder = WaveletEncode1d(channels, levels, wavelet)
99
+ self.decoder = WaveletDecode1d(channels, levels, wavelet)
100
+
101
+ self.downsampling_ratio = 2 ** levels
102
+ self.io_channels = channels
103
+ self.encoded_channels = channels * self.downsampling_ratio
104
+
105
+ def encode(self, x):
106
+ return self.encoder(x)
107
+
108
+ def decode(self, z):
109
+ return self.decoder(z)
110
+
111
+ class PQMFPretransform(Pretransform):
112
+ def __init__(self, attenuation=100, num_bands=16):
113
+ # TODO: Fix PQMF to take in in-channels
114
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
115
+ from .pqmf import PQMF
116
+ self.pqmf = PQMF(attenuation, num_bands)
117
+
118
+
119
+ def encode(self, x):
120
+ # x is (Batch x Channels x Time)
121
+ x = self.pqmf.forward(x)
122
+ # pqmf.forward returns (Batch x Channels x Bands x Time)
123
+ # but Pretransform needs Batch x Channels x Time
124
+ # so concatenate channels and bands into one axis
125
+ return rearrange(x, "b c n t -> b (c n) t")
126
+
127
+ def decode(self, x):
128
+ # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
129
+ x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
130
+ # returns (Batch x Channels x Time)
131
+ return self.pqmf.inverse(x)
132
+
133
+ class PretrainedDACPretransform(Pretransform):
134
+ def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
135
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
136
+
137
+ import dac
138
+
139
+ model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
140
+
141
+ self.model = dac.DAC.load(model_path)
142
+
143
+ self.quantize_on_decode = quantize_on_decode
144
+
145
+ if model_type == "44khz":
146
+ self.downsampling_ratio = 512
147
+ else:
148
+ self.downsampling_ratio = 320
149
+
150
+ self.io_channels = 1
151
+
152
+ self.scale = scale
153
+
154
+ self.chunked = chunked
155
+
156
+ self.encoded_channels = self.model.latent_dim
157
+
158
+ self.num_quantizers = self.model.n_codebooks
159
+
160
+ self.codebook_size = self.model.codebook_size
161
+
162
+ def encode(self, x):
163
+
164
+ latents = self.model.encoder(x)
165
+
166
+ if self.quantize_on_decode:
167
+ output = latents
168
+ else:
169
+ z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
170
+ output = z
171
+
172
+ if self.scale != 1.0:
173
+ output = output / self.scale
174
+
175
+ return output
176
+
177
+ def decode(self, z):
178
+
179
+ if self.scale != 1.0:
180
+ z = z * self.scale
181
+
182
+ if self.quantize_on_decode:
183
+ z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
184
+
185
+ return self.model.decode(z)
186
+
187
+ def tokenize(self, x):
188
+ return self.model.encode(x)[1]
189
+
190
+ def decode_tokens(self, tokens):
191
+ latents = self.model.quantizer.from_codes(tokens)
192
+ return self.model.decode(latents)
193
+
194
+ class AudiocraftCompressionPretransform(Pretransform):
195
+ def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
196
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
197
+
198
+ try:
199
+ from audiocraft.models import CompressionModel
200
+ except ImportError:
201
+ raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
202
+
203
+ self.model = CompressionModel.get_pretrained(model_type)
204
+
205
+ self.quantize_on_decode = quantize_on_decode
206
+
207
+ self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
208
+
209
+ self.sample_rate = self.model.sample_rate
210
+
211
+ self.io_channels = self.model.channels
212
+
213
+ self.scale = scale
214
+
215
+ #self.encoded_channels = self.model.latent_dim
216
+
217
+ self.num_quantizers = self.model.num_codebooks
218
+
219
+ self.codebook_size = self.model.cardinality
220
+
221
+ self.model.to(torch.float16).eval().requires_grad_(False)
222
+
223
+ def encode(self, x):
224
+
225
+ assert False, "Audiocraft compression models do not support continuous encoding"
226
+
227
+ # latents = self.model.encoder(x)
228
+
229
+ # if self.quantize_on_decode:
230
+ # output = latents
231
+ # else:
232
+ # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
233
+ # output = z
234
+
235
+ # if self.scale != 1.0:
236
+ # output = output / self.scale
237
+
238
+ # return output
239
+
240
+ def decode(self, z):
241
+
242
+ assert False, "Audiocraft compression models do not support continuous decoding"
243
+
244
+ # if self.scale != 1.0:
245
+ # z = z * self.scale
246
+
247
+ # if self.quantize_on_decode:
248
+ # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
249
+
250
+ # return self.model.decode(z)
251
+
252
+ def tokenize(self, x):
253
+ with torch.cuda.amp.autocast(enabled=False):
254
+ return self.model.encode(x.to(torch.float16))[0]
255
+
256
+ def decode_tokens(self, tokens):
257
+ with torch.cuda.amp.autocast(enabled=False):
258
+ return self.model.decode(tokens)
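
As a rough shape check for PQMFPretransform (a sketch: mono input only, per the TODO above, and the exact time lengths depend on the PQMF padding, so the comments are approximate):

import torch
from stable_audio_tools.models.pretransforms import PQMFPretransform

pretransform = PQMFPretransform(attenuation=100, num_bands=16)

audio = torch.randn(2, 1, 16384)       # (Batch, Channels=1, Time)
latent = pretransform.encode(audio)    # bands folded into channels: roughly (2, 16, Time/16)
recon = pretransform.decode(latent)    # roughly (2, 1, Time); PQMF may trim a few edge samples
print(latent.shape, recon.shape)
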
stable_audio_tools/models/temptransformer.py ADDED
@@ -0,0 +1,190 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+
8
+ class Residual(nn.Module):
9
+ def __init__(self, fn):
10
+ super().__init__()
11
+ self.fn = fn
12
+ def forward(self, x, **kwargs):
13
+ return self.fn(x, **kwargs) + x
14
+
15
+ class SA_PreNorm(nn.Module):
16
+ def __init__(self, dim, fn):
17
+ super().__init__()
18
+ self.norm = nn.LayerNorm(dim)
19
+ self.fn = fn
20
+ def forward(self, x, **kwargs):
21
+ return self.fn(self.norm(x), **kwargs)
22
+
23
+ class SA_FeedForward(nn.Module):
24
+ def __init__(self, dim, hidden_dim, dropout = 0.):
25
+ super().__init__()
26
+ self.net = nn.Sequential(
27
+ nn.Linear(dim, hidden_dim),
28
+ nn.GELU(),
29
+ nn.Dropout(dropout),
30
+ nn.Linear(hidden_dim, dim),
31
+ nn.Dropout(dropout)
32
+ )
33
+ def forward(self, x):
34
+ return self.net(x)
35
+
36
+ class SA_Attention(nn.Module):
37
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
38
+ super().__init__()
39
+ inner_dim = dim_head * heads
40
+ project_out = not (heads == 1 and dim_head == dim)
41
+
42
+ self.heads = heads
43
+ self.scale = dim_head ** -0.5
44
+
45
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
46
+
47
+ self.to_out = nn.Sequential(
48
+ nn.Linear(inner_dim, dim),
49
+ nn.Dropout(dropout)
50
+ ) if project_out else nn.Identity()
51
+
52
+ def forward(self, x):
53
+ b, n, _, h = *x.shape, self.heads
54
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
55
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
56
+
57
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
58
+
59
+ attn = dots.softmax(dim=-1)
60
+
61
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
62
+ out = rearrange(out, 'b h n d -> b n (h d)')
63
+ out = self.to_out(out)
64
+ return out
65
+
66
+
67
+ class ReAttention(nn.Module):
68
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
69
+ super().__init__()
70
+ inner_dim = dim_head * heads
71
+ self.heads = heads
72
+ self.scale = dim_head ** -0.5
73
+
74
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
75
+
76
+ self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
77
+
78
+ self.reattn_norm = nn.Sequential(
79
+ Rearrange('b h i j -> b i j h'),
80
+ nn.LayerNorm(heads),
81
+ Rearrange('b i j h -> b h i j')
82
+ )
83
+
84
+ self.to_out = nn.Sequential(
85
+ nn.Linear(inner_dim, dim),
86
+ nn.Dropout(dropout)
87
+ )
88
+
89
+ def forward(self, x):
90
+ b, n, _, h = *x.shape, self.heads
91
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
92
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
93
+
94
+ # attention
95
+
96
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
97
+ attn = dots.softmax(dim=-1)
98
+
99
+ # re-attention
100
+
101
+ attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
102
+ attn = self.reattn_norm(attn)
103
+
104
+ # aggregate and out
105
+
106
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
107
+ out = rearrange(out, 'b h n d -> b n (h d)')
108
+ out = self.to_out(out)
109
+ return out
110
+
111
+ class LeFF(nn.Module):
112
+
113
+ def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
114
+ super().__init__()
115
+
116
+ scale_dim = dim*scale
117
+ self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
118
+ Rearrange('b n c -> b c n'),
119
+ nn.BatchNorm1d(scale_dim),
120
+ nn.GELU(),
121
+ Rearrange('b c (h w) -> b c h w', h=14, w=14)
122
+ )
123
+
124
+ self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
125
+ nn.BatchNorm2d(scale_dim),
126
+ nn.GELU(),
127
+ Rearrange('b c h w -> b (h w) c', h=14, w=14)
128
+ )
129
+
130
+ self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
131
+ Rearrange('b n c -> b c n'),
132
+ nn.BatchNorm1d(dim),
133
+ nn.GELU(),
134
+ Rearrange('b c n -> b n c')
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = self.up_proj(x)
139
+ x = self.depth_conv(x)
140
+ x = self.down_proj(x)
141
+ return x
142
+
143
+
144
+ class LCAttention(nn.Module):
145
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
146
+ super().__init__()
147
+ inner_dim = dim_head * heads
148
+ project_out = not (heads == 1 and dim_head == dim)
149
+
150
+ self.heads = heads
151
+ self.scale = dim_head ** -0.5
152
+
153
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
154
+
155
+ self.to_out = nn.Sequential(
156
+ nn.Linear(inner_dim, dim),
157
+ nn.Dropout(dropout)
158
+ ) if project_out else nn.Identity()
159
+
160
+ def forward(self, x):
161
+ b, n, _, h = *x.shape, self.heads
162
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
163
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
164
+ q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query
165
+
166
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
167
+
168
+ attn = dots.softmax(dim=-1)
169
+
170
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
171
+ out = rearrange(out, 'b h n d -> b n (h d)')
172
+ out = self.to_out(out)
173
+ return out
174
+
175
+ class SA_Transformer(nn.Module):
176
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
177
+ super().__init__()
178
+ self.layers = nn.ModuleList([])
179
+ self.norm = nn.LayerNorm(dim)
180
+ for _ in range(depth):
181
+ self.layers.append(nn.ModuleList([
182
+ SA_PreNorm(dim, SA_Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
183
+ SA_PreNorm(dim, SA_FeedForward(dim, mlp_dim, dropout = dropout))
184
+ ]))
185
+
186
+ def forward(self, x):
187
+ for attn, ff in self.layers:
188
+ x = attn(x) + x
189
+ x = ff(x) + x
190
+ return self.norm(x)
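
A quick smoke test for SA_Transformer; all sizes below are made up for illustration:

import torch
from stable_audio_tools.models.temptransformer import SA_Transformer

model = SA_Transformer(dim=192, depth=4, heads=8, dim_head=64, mlp_dim=384, dropout=0.1)

tokens = torch.randn(2, 50, 192)  # (Batch, Sequence, dim)
out = model(tokens)               # residual attention/feedforward blocks, then a final LayerNorm
print(out.shape)                  # torch.Size([2, 50, 192])
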
stable_audio_tools/models/transformer.py ADDED
@@ -0,0 +1,812 @@
1
+ from functools import reduce, partial
2
+ from packaging import version
3
+
4
+ from einops import rearrange, repeat
5
+ from einops.layers.torch import Rearrange
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn, einsum
9
+ from torch.cuda.amp import autocast
10
+ from typing import Callable, Literal
11
+ import warnings
12
+ warnings.simplefilter(action='ignore', category=FutureWarning)
13
+
14
+ try:
15
+ from flash_attn import flash_attn_func, flash_attn_kvpacked_func
16
+ except ImportError as e:
17
+ print(e)
18
+ print('flash_attn not installed, disabling Flash Attention')
19
+ flash_attn_kvpacked_func = None
20
+ flash_attn_func = None
21
+
22
+ try:
23
+ import natten
24
+ except ImportError:
25
+ natten = None
26
+
27
+ def checkpoint(function, *args, **kwargs):
28
+ kwargs.setdefault("use_reentrant", False)
29
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
30
+
31
+
32
+ # Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
33
+ # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
34
+
35
+ def create_causal_mask(i, j, device):
36
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
37
+
38
+ def or_reduce(masks):
39
+ head, *body = masks
40
+ for rest in body:
41
+ head = head | rest
42
+ return head
43
+
44
+ # positional embeddings
45
+
46
+ class AbsolutePositionalEmbedding(nn.Module):
47
+ def __init__(self, dim, max_seq_len):
48
+ super().__init__()
49
+ self.scale = dim ** -0.5
50
+ self.max_seq_len = max_seq_len
51
+ self.emb = nn.Embedding(max_seq_len, dim)
52
+
53
+ def forward(self, x, pos = None, seq_start_pos = None):
54
+ seq_len, device = x.shape[1], x.device
55
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
56
+
57
+ if pos is None:
58
+ pos = torch.arange(seq_len, device = device)
59
+
60
+ if seq_start_pos is not None:
61
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
62
+
63
+ pos_emb = self.emb(pos)
64
+ pos_emb = pos_emb * self.scale
65
+ return pos_emb
66
+
67
+ class ScaledSinusoidalEmbedding(nn.Module):
68
+ def __init__(self, dim, theta = 10000):
69
+ super().__init__()
70
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
71
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
72
+
73
+ half_dim = dim // 2
74
+ freq_seq = torch.arange(half_dim).float() / half_dim
75
+ inv_freq = theta ** -freq_seq
76
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
77
+
78
+ def forward(self, x, pos = None, seq_start_pos = None):
79
+ seq_len, device = x.shape[1], x.device
80
+
81
+ if pos is None:
82
+ pos = torch.arange(seq_len, device = device)
83
+
84
+ if seq_start_pos is not None:
85
+ pos = pos - seq_start_pos[..., None]
86
+
87
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
88
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
89
+ return emb * self.scale
90
+
91
+ class RotaryEmbedding(nn.Module):
92
+ def __init__(
93
+ self,
94
+ dim,
95
+ use_xpos = False,
96
+ scale_base = 512,
97
+ interpolation_factor = 1.,
98
+ base = 10000,
99
+ base_rescale_factor = 1.
100
+ ):
101
+ super().__init__()
102
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
103
+ # has some connection to NTK literature
104
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
105
+ base *= base_rescale_factor ** (dim / (dim - 2))
106
+
107
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
108
+ self.register_buffer('inv_freq', inv_freq)
109
+
110
+ assert interpolation_factor >= 1.
111
+ self.interpolation_factor = interpolation_factor
112
+
113
+ if not use_xpos:
114
+ self.register_buffer('scale', None)
115
+ return
116
+
117
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
118
+
119
+ self.scale_base = scale_base
120
+ self.register_buffer('scale', scale)
121
+
122
+ def forward_from_seq_len(self, seq_len):
123
+ device = self.inv_freq.device
124
+
125
+ t = torch.arange(seq_len, device = device)
126
+ return self.forward(t)
127
+
128
+ @autocast(enabled = False)
129
+ def forward(self, t):
130
+ device = self.inv_freq.device
131
+
132
+ t = t.to(torch.float32)
133
+
134
+ t = t / self.interpolation_factor
135
+
136
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
137
+ freqs = torch.cat((freqs, freqs), dim = -1)
138
+
139
+ if self.scale is None:
140
+ return freqs, 1.
141
+
142
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
143
+ scale = self.scale ** rearrange(power, 'n -> n 1')
144
+ scale = torch.cat((scale, scale), dim = -1)
145
+
146
+ return freqs, scale
147
+
148
+ def rotate_half(x):
149
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
150
+ x1, x2 = x.unbind(dim = -2)
151
+ return torch.cat((-x2, x1), dim = -1)
152
+
153
+ @autocast(enabled = False)
154
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
155
+ out_dtype = t.dtype
156
+
157
+ # cast to float32 if necessary for numerical stability
158
+ dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
159
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
160
+ freqs, t = freqs.to(dtype), t.to(dtype)
161
+ freqs = freqs[-seq_len:, :]
162
+
163
+ if t.ndim == 4 and freqs.ndim == 3:
164
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
165
+
166
+ # partial rotary embeddings, Wang et al. GPT-J
167
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
168
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
169
+
170
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
171
+
172
+ return torch.cat((t, t_unrotated), dim = -1)
173
+
174
+ # norms
175
+ class LayerNorm(nn.Module):
176
+ def __init__(self, dim, bias=False, fix_scale=False):
177
+ """
178
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
179
+ """
180
+ super().__init__()
181
+
182
+ if fix_scale:
183
+ self.register_buffer("gamma", torch.ones(dim))
184
+ else:
185
+ self.gamma = nn.Parameter(torch.ones(dim))
186
+
187
+ if bias:
188
+ self.beta = nn.Parameter(torch.zeros(dim))
189
+ else:
190
+ self.register_buffer("beta", torch.zeros(dim))
191
+
192
+
193
+ def forward(self, x):
194
+ return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
195
+
196
+ # feedforward
197
+
198
+ class GLU(nn.Module):
199
+ def __init__(
200
+ self,
201
+ dim_in,
202
+ dim_out,
203
+ activation: Callable,
204
+ use_conv = False,
205
+ conv_kernel_size = 3,
206
+ ):
207
+ super().__init__()
208
+ self.act = activation
209
+ self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
210
+ self.use_conv = use_conv
211
+
212
+ def forward(self, x):
213
+ if self.use_conv:
214
+ x = rearrange(x, 'b n d -> b d n')
215
+ x = self.proj(x)
216
+ x = rearrange(x, 'b d n -> b n d')
217
+ else:
218
+ x = self.proj(x)
219
+
220
+ x, gate = x.chunk(2, dim = -1)
221
+ return x * self.act(gate)
222
+
223
+ class FeedForward(nn.Module):
224
+ def __init__(
225
+ self,
226
+ dim,
227
+ dim_out = None,
228
+ mult = 4,
229
+ no_bias = False,
230
+ glu = True,
231
+ use_conv = False,
232
+ conv_kernel_size = 3,
233
+ zero_init_output = True,
234
+ ):
235
+ super().__init__()
236
+ inner_dim = int(dim * mult)
237
+
238
+ # Default to SwiGLU
239
+
240
+ activation = nn.SiLU()
241
+
242
+ dim_out = dim if dim_out is None else dim_out
243
+
244
+ if glu:
245
+ linear_in = GLU(dim, inner_dim, activation)
246
+ else:
247
+ linear_in = nn.Sequential(
248
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
249
+ nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
250
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
251
+ activation
252
+ )
253
+
254
+ linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
255
+
256
+ # init last linear layer to 0
257
+ if zero_init_output:
258
+ nn.init.zeros_(linear_out.weight)
259
+ if not no_bias:
260
+ nn.init.zeros_(linear_out.bias)
261
+
262
+
263
+ self.ff = nn.Sequential(
264
+ linear_in,
265
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
266
+ linear_out,
267
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
268
+ )
269
+
270
+ def forward(self, x):
271
+ return self.ff(x)
272
+
273
+ class Attention(nn.Module):
274
+ def __init__(
275
+ self,
276
+ dim,
277
+ dim_heads = 64,
278
+ dim_context = None,
279
+ causal = False,
280
+ zero_init_output=True,
281
+ qk_norm: Literal['l2', 'ln', 'none'] = 'none',
282
+ natten_kernel_size = None
283
+ ):
284
+ super().__init__()
285
+ self.dim = dim
286
+ self.dim_heads = dim_heads
287
+ self.causal = causal
288
+
289
+ dim_kv = dim_context if dim_context is not None else dim
290
+
291
+ self.num_heads = dim // dim_heads
292
+ self.kv_heads = dim_kv // dim_heads
293
+
294
+ if dim_context is not None:
295
+ self.to_q = nn.Linear(dim, dim, bias=False)
296
+ self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
297
+ else:
298
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
299
+
300
+ self.to_out = nn.Linear(dim, dim, bias=False)
301
+
302
+ if zero_init_output:
303
+ nn.init.zeros_(self.to_out.weight)
304
+
305
+ self.qk_norm = qk_norm
306
+
307
+ if self.qk_norm == "ln":
308
+ self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
309
+ self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
310
+
311
+ # Using 1d neighborhood attention
312
+ self.natten_kernel_size = natten_kernel_size
313
+ if natten_kernel_size is not None:
314
+ return
315
+
316
+ self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
317
+
318
+ self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
319
+
320
+ self.sdp_kwargs = dict(
321
+ enable_flash = True,
322
+ enable_math = True,
323
+ enable_mem_efficient = True
324
+ )
325
+
326
+ def flash_attn(
327
+ self,
328
+ q,
329
+ k,
330
+ v,
331
+ mask = None,
332
+ causal = None
333
+ ):
334
+ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
335
+ kv_heads = k.shape[1]
336
+ # Recommended for multi-query single-key-value attention by Tri Dao
337
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
338
+
339
+ if heads != kv_heads:
340
+ # Repeat interleave kv_heads to match q_heads
341
+ heads_per_kv_head = heads // kv_heads
342
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
343
+
344
+ if k.ndim == 3:
345
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
346
+
347
+ if v.ndim == 3:
348
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
349
+
350
+ causal = self.causal if causal is None else causal
351
+
352
+ if q_len == 1 and causal:
353
+ causal = False
354
+
355
+ if mask is not None:
356
+ assert mask.ndim == 4
357
+ mask = mask.expand(batch, heads, q_len, k_len)
358
+
359
+ # handle kv cache - this should be bypassable in updated flash attention 2
360
+
361
+ if k_len > q_len and causal:
362
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
363
+ if mask is None:
364
+ mask = ~causal_mask
365
+ else:
366
+ mask = mask & ~causal_mask
367
+ causal = False
368
+
369
+ # manually handle causal mask, if another mask was given
370
+
371
+ row_is_entirely_masked = None
372
+
373
+ if mask is not None and causal:
374
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
375
+ mask = mask & ~causal_mask
376
+
377
+ # protect against an entire row being masked out
378
+
379
+ row_is_entirely_masked = ~mask.any(dim = -1)
380
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
381
+
382
+ causal = False
383
+
384
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
385
+ out = F.scaled_dot_product_attention(
386
+ q, k, v,
387
+ attn_mask = mask,
388
+ is_causal = causal
389
+ )
390
+
391
+ # for a row that is entirely masked out, should zero out the output of that row token
392
+
393
+ if row_is_entirely_masked is not None:
394
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
395
+
396
+ return out
397
+
398
+ def forward(
399
+ self,
400
+ x,
401
+ context = None,
402
+ mask = None,
403
+ context_mask = None,
404
+ rotary_pos_emb = None,
405
+ causal = None
406
+ ):
407
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
408
+
409
+ kv_input = context if has_context else x
410
+
411
+ if hasattr(self, 'to_q'):
412
+ # Use separate linear projections for q and k/v
413
+ q = self.to_q(x)
414
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h) # [B, 24, 1025, 64]
415
+
416
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
417
+
418
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
419
+ else:
420
+ # Use fused linear projection
421
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
422
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
423
+
424
+ # Normalize q and k for cosine sim attention
425
+ if self.qk_norm == "l2":
426
+ q = F.normalize(q, dim=-1)
427
+ k = F.normalize(k, dim=-1)
428
+ elif self.qk_norm == "ln":
429
+ q = self.q_norm(q)
430
+ k = self.k_norm(k)
431
+
432
+ if rotary_pos_emb is not None and not has_context:
433
+ freqs, _ = rotary_pos_emb
434
+
435
+ q_dtype = q.dtype
436
+ k_dtype = k.dtype
437
+
438
+ q = q.to(torch.float32)
439
+ k = k.to(torch.float32)
440
+ freqs = freqs.to(torch.float32)
441
+
442
+ q = apply_rotary_pos_emb(q, freqs)
443
+ k = apply_rotary_pos_emb(k, freqs)
444
+
445
+ q = q.to(q_dtype)
446
+ k = k.to(k_dtype)
447
+
448
+ input_mask = context_mask
449
+
450
+ if input_mask is None and not has_context:
451
+ input_mask = mask
452
+
453
+ # determine masking
454
+ masks = []
455
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
456
+
457
+ if input_mask is not None:
458
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
459
+ masks.append(~input_mask)
460
+
461
+ # Other masks will be added here later
462
+
463
+ if len(masks) > 0:
464
+ final_attn_mask = ~or_reduce(masks)
465
+
466
+ n, device = q.shape[-2], q.device
467
+
468
+ causal = self.causal if causal is None else causal
469
+
470
+ if n == 1 and causal:
471
+ causal = False
472
+
473
+ if self.natten_kernel_size is not None:
474
+ if natten is None:
475
+ raise ImportError('natten not installed, please install natten to use neighborhood attention')
476
+
477
+ dtype_in = q.dtype
478
+ q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
479
+
480
+ attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
481
+
482
+ if final_attn_mask is not None:
483
+ attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
484
+
485
+ attn = F.softmax(attn, dim=-1, dtype=torch.float32)
486
+
487
+ out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
488
+
489
+ # Prioritize Flash Attention 2
490
+ elif self.use_fa_flash:
491
+ assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
492
+ # Flash Attention 2 requires FP16 inputs
493
+ fa_dtype_in = q.dtype
494
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
495
+
496
+ out = flash_attn_func(q, k, v, causal = causal)
497
+
498
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
499
+
500
+ # Fall back to PyTorch implementation
501
+ elif self.use_pt_flash:
502
+ out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
503
+
504
+ else:
505
+ # Fall back to custom implementation
506
+
507
+ if h != kv_h:
508
+ # Repeat interleave kv_heads to match q_heads
509
+ heads_per_kv_head = h // kv_h
510
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
511
+
512
+ scale = 1. / (q.shape[-1] ** 0.5)
513
+
514
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
515
+
516
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
517
+
518
+ i, j, dtype = *dots.shape[-2:], dots.dtype
519
+
520
+ mask_value = -torch.finfo(dots.dtype).max
521
+
522
+ if final_attn_mask is not None:
523
+ dots = dots.masked_fill(~final_attn_mask, mask_value)
524
+
525
+ if causal:
526
+ causal_mask = self.create_causal_mask(i, j, device = device)
527
+ dots = dots.masked_fill(causal_mask, mask_value)
528
+
529
+ attn = F.softmax(dots, dim=-1, dtype=torch.float32)
530
+ attn = attn.type(dtype)
531
+
532
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
533
+
534
+ # merge heads
535
+ out = rearrange(out, ' b h n d -> b n (h d)')
536
+
537
+ # Communicate between heads
538
+ out = self.to_out(out)
539
+
540
+ if mask is not None:
541
+ mask = rearrange(mask, 'b n -> b n 1')
542
+ out = out.masked_fill(~mask, 0.)
543
+
544
+ return out
545
+
546
+
547
+ class ConformerModule(nn.Module):
548
+ def __init__(
549
+ self,
550
+ dim,
551
+ norm_kwargs = {},
552
+ ):
553
+
554
+ super().__init__()
555
+
556
+ self.dim = dim
557
+
558
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
559
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
560
+ self.glu = GLU(dim, dim, nn.SiLU())
561
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
562
+ self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
563
+ self.swish = nn.SiLU()
564
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
565
+
566
+ def forward(self, x):
567
+ x = self.in_norm(x)
568
+ x = rearrange(x, 'b n d -> b d n')
569
+ x = self.pointwise_conv(x)
570
+ x = rearrange(x, 'b d n -> b n d')
571
+ x = self.glu(x)
572
+ x = rearrange(x, 'b n d -> b d n')
573
+ x = self.depthwise_conv(x)
574
+ x = rearrange(x, 'b d n -> b n d')
575
+ x = self.mid_norm(x)
576
+ x = self.swish(x)
577
+ x = rearrange(x, 'b n d -> b d n')
578
+ x = self.pointwise_conv_2(x)
579
+ x = rearrange(x, 'b d n -> b n d')
580
+
581
+ return x
582
+
583
+ class TransformerBlock(nn.Module):
584
+ def __init__(
585
+ self,
586
+ dim,
587
+ dim_heads = 64,
588
+ cross_attend = False,
589
+ dim_context = None,
590
+ global_cond_dim = None,
591
+ causal = False,
592
+ zero_init_branch_outputs = True,
593
+ conformer = False,
594
+ layer_ix = -1,
595
+ remove_norms = False,
596
+ attn_kwargs = {},
597
+ ff_kwargs = {},
598
+ norm_kwargs = {}
599
+ ):
600
+
601
+ super().__init__()
602
+ self.dim = dim
603
+ self.dim_heads = dim_heads
604
+ self.cross_attend = cross_attend
605
+ self.dim_context = dim_context
606
+ self.causal = causal
607
+
608
+ self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
609
+
610
+ self.self_attn = Attention(
611
+ dim,
612
+ dim_heads = dim_heads,
613
+ causal = causal,
614
+ zero_init_output=zero_init_branch_outputs,
615
+ **attn_kwargs
616
+ )
617
+
618
+ if cross_attend:
619
+ self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
620
+ self.cross_attn = Attention(
621
+ dim,
622
+ dim_heads = dim_heads,
623
+ dim_context=dim_context,
624
+ causal = causal,
625
+ zero_init_output=zero_init_branch_outputs,
626
+ **attn_kwargs
627
+ )
628
+
629
+ self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
630
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
631
+
632
+ self.layer_ix = layer_ix
633
+
634
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
635
+
636
+ self.global_cond_dim = global_cond_dim
637
+
638
+ if global_cond_dim is not None:
639
+ self.to_scale_shift_gate = nn.Sequential(
640
+ nn.SiLU(),
641
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
642
+ )
643
+
644
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
645
+
646
+ def forward(
647
+ self,
648
+ x,
649
+ context = None,
650
+ global_cond=None,
651
+ mask = None,
652
+ context_mask = None,
653
+ rotary_pos_emb = None,
654
+ adapter=None
655
+ ):
656
+
657
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: # False
658
+
659
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
660
+
661
+ # self-attention with adaLN
662
+ residual = x
663
+ x = self.pre_norm(x)
664
+ x = x * (1 + scale_self) + shift_self
665
+
666
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
667
+ x = x * torch.sigmoid(1 - gate_self)
668
+ x = x + residual
669
+
670
+ if context is not None:
671
+
672
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
673
+
674
+ if self.conformer is not None:
675
+ x = x + self.conformer(x)
676
+
677
+ # feedforward with adaLN
678
+ residual = x
679
+ x = self.ff_norm(x)
680
+ x = x * (1 + scale_ff) + shift_ff
681
+ x = self.ff(x)
682
+ x = x * torch.sigmoid(1 - gate_ff)
683
+ x = x + residual
684
+
685
+ else:
686
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
687
+
688
+ if context is not None:
689
+
690
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
691
+
692
+ if self.conformer is not None:
693
+ x = x + self.conformer(x)
694
+
695
+ x = x + self.ff(self.ff_norm(x))
696
+
697
+ return x
698
+
699
+ class ContinuousTransformer(nn.Module):
700
+ def __init__(
701
+ self,
702
+ dim,
703
+ depth,
704
+ *,
705
+ dim_in = None,
706
+ dim_out = None,
707
+ dim_heads = 64,
708
+ cross_attend=False,
709
+ cond_token_dim=None,
710
+ global_cond_dim=None,
711
+ causal=False,
712
+ rotary_pos_emb=True,
713
+ zero_init_branch_outputs=True,
714
+ conformer=False,
715
+ use_sinusoidal_emb=False,
716
+ use_abs_pos_emb=False,
717
+ abs_pos_emb_max_length=10000,
718
+ **kwargs
719
+ ):
720
+
721
+ super().__init__()
722
+
723
+ self.dim = dim
724
+ self.depth = depth
725
+ self.causal = causal
726
+ self.layers = nn.ModuleList([])
727
+
728
+ self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
729
+ self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
730
+
731
+ if rotary_pos_emb:
732
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
733
+ else:
734
+ self.rotary_pos_emb = None
735
+
736
+ self.use_sinusoidal_emb = use_sinusoidal_emb
737
+ if use_sinusoidal_emb:
738
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
739
+
740
+ self.use_abs_pos_emb = use_abs_pos_emb
741
+ if use_abs_pos_emb:
742
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
743
+
744
+ for i in range(depth):
745
+ self.layers.append(
746
+ TransformerBlock(
747
+ dim,
748
+ dim_heads = dim_heads,
749
+ cross_attend = cross_attend,
750
+ dim_context = cond_token_dim,
751
+ global_cond_dim = global_cond_dim,
752
+ causal = causal,
753
+ zero_init_branch_outputs = zero_init_branch_outputs,
754
+ conformer=conformer,
755
+ layer_ix=i,
756
+ **kwargs
757
+ )
758
+ )
759
+
760
+ def forward(
761
+ self,
762
+ x,
763
+ mask = None,
764
+ prepend_embeds = None,
765
+ prepend_mask = None,
766
+ global_cond = None,
767
+ return_info = False,
768
+ **kwargs
769
+ ):
770
+ batch, seq, device = *x.shape[:2], x.device
771
+
772
+ info = {
773
+ "hidden_states": [],
774
+ }
775
+
776
+ x = self.project_in(x)
777
+
778
+ if prepend_embeds is not None:
779
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
780
+
781
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
782
+
783
+ x = torch.cat((prepend_embeds, x), dim = -2)
784
+
785
+ if prepend_mask is not None or mask is not None:
786
+ mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
787
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
788
+
789
+ mask = torch.cat((prepend_mask, mask), dim = -1)
790
+
791
+ # Attention layers
792
+ if self.rotary_pos_emb is not None:
793
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
794
+ else:
795
+ rotary_pos_emb = None
796
+
797
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
798
+ x = x + self.pos_emb(x)
799
+
800
+ # Iterate over the transformer layers
801
+ for index, layer in enumerate(self.layers):
802
+ x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
803
+
804
+ if return_info:
805
+ info["hidden_states"].append(x)
806
+
807
+ x = self.project_out(x)
808
+
809
+ if return_info:
810
+ return x, info
811
+
812
+ return x
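
A minimal CPU forward pass through ContinuousTransformer; flash_attn and natten are optional, so without them the einsum fallback in Attention is exercised. Sizes are illustrative only:

import torch
from stable_audio_tools.models.transformer import ContinuousTransformer

model = ContinuousTransformer(
    dim=256,       # model width
    depth=2,       # number of TransformerBlocks
    dim_heads=64,  # 256 / 64 = 4 attention heads
    dim_in=64,     # project 64-dim inputs up to the model width
    dim_out=64,    # and back down on the way out
)

x = torch.randn(2, 128, 64)  # (Batch, Sequence, dim_in)
with torch.no_grad():
    y = model(x)             # (2, 128, 64)
print(y.shape)
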
stable_audio_tools/models/utils.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ from safetensors.torch import load_file
3
+
4
+ from torch.nn.utils import remove_weight_norm
5
+ import warnings
6
+ warnings.simplefilter(action='ignore', category=FutureWarning)
7
+
8
+
9
+ def load_ckpt_state_dict(ckpt_path):
10
+ if ckpt_path.endswith(".safetensors"):
11
+ state_dict = load_file(ckpt_path)
12
+ else:
13
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
14
+
15
+ return state_dict
16
+
17
+ def remove_weight_norm_from_model(model):
18
+ for module in model.modules():
19
+ if hasattr(module, "weight"):
20
+ print(f"Removing weight norm from {module}")
21
+ remove_weight_norm(module)
22
+
23
+ return model
24
+
25
+ # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
26
+ # License can be found in LICENSES/LICENSE_META.txt
27
+
28
+ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
29
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
30
+
31
+ Args:
32
+ input (torch.Tensor): The input tensor containing probabilities.
33
+ num_samples (int): Number of samples to draw.
34
+ replacement (bool): Whether to draw with replacement or not.
35
+ Keywords args:
36
+ generator (torch.Generator): A pseudorandom number generator for sampling.
37
+ Returns:
38
+ torch.Tensor: Last dimension contains num_samples indices
39
+ sampled from the multinomial probability distribution
40
+ located in the last dimension of tensor input.
41
+ """
42
+
43
+ if num_samples == 1:
44
+ q = torch.empty_like(input).exponential_(1, generator=generator)
45
+ return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
46
+
47
+ input_ = input.reshape(-1, input.shape[-1])
48
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
49
+ output = output_.reshape(*list(input.shape[:-1]), -1)
50
+ return output
51
+
52
+
53
+ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
54
+ """Sample next token from top K values along the last dimension of the input probs tensor.
55
+
56
+ Args:
57
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
58
+ k (int): The k in “top-k”.
59
+ Returns:
60
+ torch.Tensor: Sampled tokens.
61
+ """
62
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
63
+ min_value_top_k = top_k_value[..., [-1]]
64
+ probs *= (probs >= min_value_top_k).float()
65
+ probs.div_(probs.sum(dim=-1, keepdim=True))
66
+ next_token = multinomial(probs, num_samples=1)
67
+ return next_token
68
+
69
+
70
+ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
71
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
72
+
73
+ Args:
74
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
75
+ p (float): The p in “top-p”.
76
+ Returns:
77
+ torch.Tensor: Sampled tokens.
78
+ """
79
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
80
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
81
+ mask = probs_sum - probs_sort > p
82
+ probs_sort *= (~mask).float()
83
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
84
+ next_token = multinomial(probs_sort, num_samples=1)
85
+ next_token = torch.gather(probs_idx, -1, next_token)
86
+ return next_token
87
+
88
+ def next_power_of_two(n):
89
+ return 2 ** (n - 1).bit_length()
90
+
91
+ def next_multiple_of_64(n):
92
+ return ((n + 63) // 64) * 64
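
The two sampling helpers expect a probability tensor whose last dimension holds the token candidates; note that sample_top_k renormalizes its input in place, so pass a clone if the original distribution is still needed. A small sketch:

import torch
from stable_audio_tools.models.utils import sample_top_k, sample_top_p

logits = torch.randn(2, 4, 1024)               # (Batch, Codebooks, Vocab) -- arbitrary example shape
probs = torch.softmax(logits, dim=-1)

tokens_k = sample_top_k(probs.clone(), k=250)  # (2, 4, 1) indices into the vocab axis
tokens_p = sample_top_p(probs.clone(), p=0.9)  # (2, 4, 1)
print(tokens_k.shape, tokens_p.shape)
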
stable_audio_tools/models/wavelets.py ADDED
@@ -0,0 +1,82 @@
1
+ """The 1D discrete wavelet transform for PyTorch."""
2
+
3
+ from einops import rearrange
4
+ import pywt
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from typing import Literal
9
+
10
+
11
+ def get_filter_bank(wavelet):
12
+ filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
13
+ if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0):
14
+ filt = filt[:, 1:]
15
+ return filt
16
+
17
+ class WaveletEncode1d(nn.Module):
18
+ def __init__(self,
19
+ channels,
20
+ levels,
21
+ wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
22
+ super().__init__()
23
+ self.wavelet = wavelet
24
+ self.channels = channels
25
+ self.levels = levels
26
+ filt = get_filter_bank(wavelet)
27
+ assert filt.shape[-1] % 2 == 1
28
+ kernel = filt[:2, None]
29
+ kernel = torch.flip(kernel, dims=(-1,))
30
+ index_i = torch.repeat_interleave(torch.arange(2), channels)
31
+ index_j = torch.tile(torch.arange(channels), (2,))
32
+ kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
33
+ kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
34
+ self.register_buffer("kernel", kernel_final)
35
+
36
+ def forward(self, x):
37
+ for i in range(self.levels):
38
+ low, rest = x[:, : self.channels], x[:, self.channels :]
39
+ pad = self.kernel.shape[-1] // 2
40
+ low = F.pad(low, (pad, pad), "reflect")
41
+ low = F.conv1d(low, self.kernel, stride=2)
42
+ rest = rearrange(
43
+ rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
44
+ )
45
+ x = torch.cat([low, rest], dim=1)
46
+ return x
47
+
48
+
49
+ class WaveletDecode1d(nn.Module):
50
+ def __init__(self,
51
+ channels,
52
+ levels,
53
+ wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
54
+ super().__init__()
55
+ self.wavelet = wavelet
56
+ self.channels = channels
57
+ self.levels = levels
58
+ filt = get_filter_bank(wavelet)
59
+ assert filt.shape[-1] % 2 == 1
60
+ kernel = filt[2:, None]
61
+ index_i = torch.repeat_interleave(torch.arange(2), channels)
62
+ index_j = torch.tile(torch.arange(channels), (2,))
63
+ kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
64
+ kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
65
+ self.register_buffer("kernel", kernel_final)
66
+
67
+ def forward(self, x):
68
+ for i in range(self.levels):
69
+ low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :]
70
+ pad = self.kernel.shape[-1] // 2 + 2
71
+ low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2)
72
+ low = F.pad(low, (pad, pad), "reflect")
73
+ low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2)
74
+ low = F.conv_transpose1d(
75
+ low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2
76
+ )
77
+ low = low[..., pad - 1 : -pad]
78
+ rest = rearrange(
79
+ rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
80
+ )
81
+ x = torch.cat([low, rest], dim=1)
82
+ return x
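
A shape-level sketch of the wavelet encode/decode pair above; it needs pywt installed, and only the channel/length bookkeeping is illustrated here, not sample-exact reconstruction:

import torch
from stable_audio_tools.models.wavelets import WaveletEncode1d, WaveletDecode1d

channels, levels = 2, 3
encode = WaveletEncode1d(channels, levels, wavelet="bior4.4")
decode = WaveletDecode1d(channels, levels, wavelet="bior4.4")

x = torch.randn(1, channels, 2048)  # length divisible by 2**levels
z = encode(x)                       # roughly (1, channels * 2**levels, 2048 // 2**levels)
y = decode(z)                       # ideally back to x.shape
print(z.shape, y.shape)
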
stable_audio_tools/training/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_training_wrapper_from_config, create_demo_callback_from_config
stable_audio_tools/training/autoencoders.py ADDED
@@ -0,0 +1,476 @@
1
+ import torch
2
+ import torchaudio
3
+ import wandb
4
+ from einops import rearrange
5
+ from safetensors.torch import save_file, save_model
6
+ from ema_pytorch import EMA
7
+ from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss
8
+ import pytorch_lightning as pl
9
+ from ..models.autoencoders import AudioAutoencoder
10
+ from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
11
+ from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
12
+ from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
13
+ from .utils import create_optimizer_from_config, create_scheduler_from_config
14
+
15
+
16
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
17
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
18
+
19
+ class AutoencoderTrainingWrapper(pl.LightningModule):
20
+ def __init__(
21
+ self,
22
+ autoencoder: AudioAutoencoder,
23
+ lr: float = 1e-4,
24
+ warmup_steps: int = 0,
25
+ encoder_freeze_on_warmup: bool = False,
26
+ sample_rate=48000,
27
+ loss_config: dict = None,
28
+ optimizer_configs: dict = None,
29
+ use_ema: bool = True,
30
+ ema_copy = None,
31
+ force_input_mono = False,
32
+ latent_mask_ratio = 0.0,
33
+ teacher_model: AudioAutoencoder = None
34
+ ):
35
+ super().__init__()
36
+
37
+ self.automatic_optimization = False
38
+
39
+ self.autoencoder = autoencoder
40
+
41
+ self.warmed_up = False
42
+ self.warmup_steps = warmup_steps
43
+ self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
44
+ self.lr = lr
45
+
46
+ self.force_input_mono = force_input_mono
47
+
48
+ self.teacher_model = teacher_model
49
+
50
+ if optimizer_configs is None:
51
+ optimizer_configs ={
52
+ "autoencoder": {
53
+ "optimizer": {
54
+ "type": "AdamW",
55
+ "config": {
56
+ "lr": lr,
57
+ "betas": (.8, .99)
58
+ }
59
+ }
60
+ },
61
+ "discriminator": {
62
+ "optimizer": {
63
+ "type": "AdamW",
64
+ "config": {
65
+ "lr": lr,
66
+ "betas": (.8, .99)
67
+ }
68
+ }
69
+ }
70
+
71
+ }
72
+
73
+ self.optimizer_configs = optimizer_configs
74
+
75
+ if loss_config is None:
76
+ scales = [2048, 1024, 512, 256, 128, 64, 32]
77
+ hop_sizes = []
78
+ win_lengths = []
79
+ overlap = 0.75
80
+ for s in scales:
81
+ hop_sizes.append(int(s * (1 - overlap)))
82
+ win_lengths.append(s)
83
+
84
+ loss_config = {
85
+ "discriminator": {
86
+ "type": "encodec",
87
+ "config": {
88
+ "n_ffts": scales,
89
+ "hop_lengths": hop_sizes,
90
+ "win_lengths": win_lengths,
91
+ "filters": 32
92
+ },
93
+ "weights": {
94
+ "adversarial": 0.1,
95
+ "feature_matching": 5.0,
96
+ }
97
+ },
98
+ "spectral": {
99
+ "type": "mrstft",
100
+ "config": {
101
+ "fft_sizes": scales,
102
+ "hop_sizes": hop_sizes,
103
+ "win_lengths": win_lengths,
104
+ "perceptual_weighting": True
105
+ },
106
+ "weights": {
107
+ "mrstft": 1.0,
108
+ }
109
+ },
110
+ "time": {
111
+ "type": "l1",
112
+ "config": {},
113
+ "weights": {
114
+ "l1": 0.0,
115
+ }
116
+ }
117
+ }
118
+
119
+ self.loss_config = loss_config
120
+
121
+ # Spectral reconstruction loss
122
+
123
+ stft_loss_args = loss_config['spectral']['config']
124
+
125
+ if self.autoencoder.out_channels == 2:
126
+ self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
127
+ self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
128
+ else:
129
+ self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
130
+
131
+ # Discriminator
132
+
133
+ if loss_config['discriminator']['type'] == 'oobleck':
134
+ self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config'])
135
+ elif loss_config['discriminator']['type'] == 'encodec':
136
+ self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config'])
137
+ elif loss_config['discriminator']['type'] == 'dac':
138
+ self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config'])
139
+
140
+ self.gen_loss_modules = []
141
+
142
+ # Adversarial and feature matching losses
143
+ self.gen_loss_modules += [
144
+ ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'),
145
+ ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'),
146
+ ]
147
+
148
+ if self.teacher_model is not None:
149
+ # Distillation losses
150
+
151
+ stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25
152
+ self.gen_loss_modules += [
153
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss
154
+ AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder
155
+ AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder
156
+ AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder
157
+ ]
158
+
159
+ else:
160
+
161
+ # Reconstruction loss
162
+ self.gen_loss_modules += [
163
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
164
+ ]
165
+
166
+ if self.autoencoder.out_channels == 2:
167
+
168
+ # Add left and right channel reconstruction losses in addition to the sum and difference
169
+ self.gen_loss_modules += [
170
+ AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2),
171
+ AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2),
172
+ ]
173
+
174
+ self.gen_loss_modules += [
175
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
176
+ ]
177
+
178
+ if self.loss_config['time']['weights']['l1'] > 0.0:
179
+ self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss'))
180
+
181
+ if self.autoencoder.bottleneck is not None:
182
+ self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)
183
+
184
+ self.losses_gen = MultiLoss(self.gen_loss_modules)
185
+
186
+ self.disc_loss_modules = [
187
+ ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
188
+ ]
189
+
190
+ self.losses_disc = MultiLoss(self.disc_loss_modules)
191
+
192
+ # Set up EMA for model weights
193
+ self.autoencoder_ema = None
194
+
195
+ self.use_ema = use_ema
196
+
197
+ if self.use_ema:
198
+ self.autoencoder_ema = EMA(
199
+ self.autoencoder,
200
+ ema_model=ema_copy,
201
+ beta=0.9999,
202
+ power=3/4,
203
+ update_every=1,
204
+ update_after_step=1
205
+ )
206
+
207
+ self.latent_mask_ratio = latent_mask_ratio
208
+
209
+ def configure_optimizers(self):
210
+
211
+ opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters())
212
+ opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters())
213
+
214
+ if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']:
215
+ sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen)
216
+ sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc)
217
+ return [opt_gen, opt_disc], [sched_gen, sched_disc]
218
+
219
+ return [opt_gen, opt_disc]
220
+
221
+ def training_step(self, batch, batch_idx):
222
+ reals, _ = batch
223
+
224
+ # Remove extra dimension added by WebDataset
225
+ if reals.ndim == 4 and reals.shape[0] == 1:
226
+ reals = reals[0]
227
+
228
+ if self.global_step >= self.warmup_steps:
229
+ self.warmed_up = True
230
+
231
+ loss_info = {}
232
+
233
+ loss_info["reals"] = reals
234
+
235
+ encoder_input = reals
236
+
237
+ if self.force_input_mono and encoder_input.shape[1] > 1:
238
+ encoder_input = encoder_input.mean(dim=1, keepdim=True)
239
+
240
+ loss_info["encoder_input"] = encoder_input
241
+
242
+ data_std = encoder_input.std()
243
+
244
+ if self.warmed_up and self.encoder_freeze_on_warmup:
245
+ with torch.no_grad():
246
+ latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
247
+ else:
248
+ latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
249
+
250
+ loss_info["latents"] = latents
251
+
252
+ loss_info.update(encoder_info)
253
+
254
+ # Encode with teacher model for distillation
255
+ if self.teacher_model is not None:
256
+ with torch.no_grad():
257
+ teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
258
+ loss_info['teacher_latents'] = teacher_latents
259
+
260
+ if self.latent_mask_ratio > 0.0:
261
+ mask = torch.rand_like(latents) < self.latent_mask_ratio
262
+ latents = torch.where(mask, torch.zeros_like(latents), latents)
263
+
264
+ decoded = self.autoencoder.decode(latents)
265
+
266
+ loss_info["decoded"] = decoded
267
+
268
+ if self.autoencoder.out_channels == 2:
269
+ loss_info["decoded_left"] = decoded[:, 0:1, :]
270
+ loss_info["decoded_right"] = decoded[:, 1:2, :]
271
+ loss_info["reals_left"] = reals[:, 0:1, :]
272
+ loss_info["reals_right"] = reals[:, 1:2, :]
273
+
274
+ # Distillation
275
+ if self.teacher_model is not None:
276
+ with torch.no_grad():
277
+ teacher_decoded = self.teacher_model.decode(teacher_latents)
278
+ own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher
279
+ teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model
280
+
281
+ loss_info['teacher_decoded'] = teacher_decoded
282
+ loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
283
+ loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded
284
+
285
+
286
+ if self.warmed_up:
287
+ loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
288
+ else:
289
+ loss_dis = torch.tensor(0.).to(reals)
290
+ loss_adv = torch.tensor(0.).to(reals)
291
+ feature_matching_distance = torch.tensor(0.).to(reals)
292
+
293
+ loss_info["loss_dis"] = loss_dis
294
+ loss_info["loss_adv"] = loss_adv
295
+ loss_info["feature_matching_distance"] = feature_matching_distance
296
+
297
+ opt_gen, opt_disc = self.optimizers()
298
+
299
+ lr_schedulers = self.lr_schedulers()
300
+
301
+ sched_gen = None
302
+ sched_disc = None
303
+
304
+ if lr_schedulers is not None:
305
+ sched_gen, sched_disc = lr_schedulers
306
+
307
+ # Train the discriminator
308
+ if self.global_step % 2 and self.warmed_up:
309
+ loss, losses = self.losses_disc(loss_info)
310
+
311
+ log_dict = {
312
+ 'train/disc_lr': opt_disc.param_groups[0]['lr']
313
+ }
314
+
315
+ opt_disc.zero_grad()
316
+ self.manual_backward(loss)
317
+ opt_disc.step()
318
+
319
+ if sched_disc is not None:
320
+ # sched step every step
321
+ sched_disc.step()
322
+
323
+ # Train the generator
324
+ else:
325
+
326
+ loss, losses = self.losses_gen(loss_info)
327
+
328
+ if self.use_ema:
329
+ self.autoencoder_ema.update()
330
+
331
+ opt_gen.zero_grad()
332
+ self.manual_backward(loss)
333
+ opt_gen.step()
334
+
335
+ if sched_gen is not None:
336
+ # scheduler step every step
337
+ sched_gen.step()
338
+
339
+ log_dict = {
340
+ 'train/loss': loss.detach(),
341
+ 'train/latent_std': latents.std().detach(),
342
+ 'train/data_std': data_std.detach(),
343
+ 'train/gen_lr': opt_gen.param_groups[0]['lr']
344
+ }
345
+
346
+ for loss_name, loss_value in losses.items():
347
+ log_dict[f'train/{loss_name}'] = loss_value.detach()
348
+
349
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
350
+
351
+ return loss
352
+
353
+ def export_model(self, path, use_safetensors=False):
354
+ if self.autoencoder_ema is not None:
355
+ model = self.autoencoder_ema.ema_model
356
+ else:
357
+ model = self.autoencoder
358
+
359
+ if use_safetensors:
360
+ save_model(model, path)
361
+ else:
362
+ torch.save({"state_dict": model.state_dict()}, path)
363
+
364
+
365
+ class AutoencoderDemoCallback(pl.Callback):
366
+ def __init__(
367
+ self,
368
+ demo_dl,
369
+ demo_every=2000,
370
+ sample_size=65536,
371
+ sample_rate=48000
372
+ ):
373
+ super().__init__()
374
+ self.demo_every = demo_every
375
+ self.demo_samples = sample_size
376
+ self.demo_dl = iter(demo_dl)
377
+ self.sample_rate = sample_rate
378
+ self.last_demo_step = -1
379
+
380
+ @rank_zero_only
381
+ @torch.no_grad()
382
+ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
383
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
384
+ return
385
+
386
+ self.last_demo_step = trainer.global_step
387
+
388
+ module.eval()
389
+
390
+ try:
391
+ demo_reals, _ = next(self.demo_dl)
392
+
393
+ # Remove extra dimension added by WebDataset
394
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
395
+ demo_reals = demo_reals[0]
396
+
397
+ encoder_input = demo_reals
398
+
399
+ encoder_input = encoder_input.to(module.device)
400
+
401
+ if module.force_input_mono:
402
+ encoder_input = encoder_input.mean(dim=1, keepdim=True)
403
+
404
+ demo_reals = demo_reals.to(module.device)
405
+
406
+ with torch.no_grad():
407
+ if module.use_ema:
408
+
409
+ latents = module.autoencoder_ema.ema_model.encode(encoder_input)
410
+
411
+ fakes = module.autoencoder_ema.ema_model.decode(latents)
412
+ else:
413
+ latents = module.autoencoder.encode(encoder_input)
414
+
415
+ fakes = module.autoencoder.decode(latents)
416
+
417
+ #Interleave reals and fakes
418
+ reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
419
+
420
+ # Put the demos together
421
+ reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
422
+
423
+ log_dict = {}
424
+
425
+ filename = f'recon_{trainer.global_step:08}.wav'
426
+ reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
427
+ torchaudio.save(filename, reals_fakes, self.sample_rate)
428
+
429
+ log_dict[f'recon'] = wandb.Audio(filename,
430
+ sample_rate=self.sample_rate,
431
+ caption=f'Reconstructed')
432
+
433
+ log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
434
+ log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
435
+
436
+ log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
437
+
438
+ trainer.logger.experiment.log(log_dict)
439
+ except Exception as e:
440
+ print(f'{type(e).__name__}: {e}')
441
+ raise e
442
+ finally:
443
+ module.train()
444
+
445
+ def create_loss_modules_from_bottleneck(bottleneck, loss_config):
446
+ losses = []
447
+
448
+ if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
449
+ try:
450
+ kl_weight = loss_config['bottleneck']['weights']['kl']
451
+ except:
452
+ kl_weight = 1e-6
453
+
454
+ kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss')
455
+ losses.append(kl_loss)
456
+
457
+ if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
458
+ quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss')
459
+ losses.append(quantizer_loss)
460
+
461
+ if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck):
462
+ codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss')
463
+ commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss')
464
+ losses.append(codebook_loss)
465
+ losses.append(commitment_loss)
466
+
467
+ if isinstance(bottleneck, WassersteinBottleneck):
468
+ try:
469
+ mmd_weight = loss_config['bottleneck']['weights']['mmd']
470
+ except:
471
+ mmd_weight = 100
472
+
473
+ mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss')
474
+ losses.append(mmd_loss)
475
+
476
+ return losses
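
To make the bottleneck loss wiring above concrete, here is a small illustrative example of what `create_loss_modules_from_bottleneck` returns for a VAE bottleneck. The KL weight is hypothetical; as the code shows, the function falls back to 1e-6 when the config omits it.

```python
# Illustrative only: a VAE bottleneck contributes a single weighted KL ValueLoss.
from stable_audio_tools.models.bottleneck import VAEBottleneck
from stable_audio_tools.training.autoencoders import create_loss_modules_from_bottleneck

bottleneck = VAEBottleneck()
loss_config = {"bottleneck": {"weights": {"kl": 1e-4}}}   # hypothetical weight

losses = create_loss_modules_from_bottleneck(bottleneck, loss_config)
for loss in losses:
    print(loss.name, loss.weight)   # expected: kl_loss 0.0001
```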
stable_audio_tools/training/diffusion.py ADDED
@@ -0,0 +1,1656 @@
+ import pytorch_lightning as pl
2
+ import sys, gc
3
+ import random
4
+ import torch
5
+ import torchaudio
6
+ import typing as tp
7
+ import wandb
8
+
9
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
10
+ import auraloss
11
+ from ema_pytorch import EMA
12
+ from einops import rearrange
13
+ from safetensors.torch import save_file
14
+ from torch import optim
15
+ from torch.nn import functional as F
16
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
17
+
18
+ from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
19
+ from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
20
+ from ..models.autoencoders import DiffusionAutoencoder
21
+ from ..models.diffusion_prior import PriorType
22
+ from .autoencoders import create_loss_modules_from_bottleneck
23
+ from .losses import AuralossLoss, MSELoss, MultiLoss
24
+ from .utils import create_optimizer_from_config, create_scheduler_from_config
25
+
26
+ from time import time
27
+
28
+
29
+ class Profiler:
30
+
31
+ def __init__(self):
32
+ self.ticks = [[time(), None]]
33
+
34
+ def tick(self, msg):
35
+ self.ticks.append([time(), msg])
36
+
37
+ def __repr__(self):
38
+ rep = 80 * "=" + "\n"
39
+ for i in range(1, len(self.ticks)):
40
+ msg = self.ticks[i][1]
41
+ ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
42
+ rep += msg + f": {ellapsed*1000:.2f}ms\n"
43
+ rep += 80 * "=" + "\n\n\n"
44
+ return rep
45
+
46
+ class DiffusionUncondTrainingWrapper(pl.LightningModule):
47
+ '''
48
+ Wrapper for training an unconditional audio diffusion model (like Dance Diffusion).
49
+ '''
50
+ def __init__(
51
+ self,
52
+ model: DiffusionModelWrapper,
53
+ lr: float = 1e-4,
54
+ pre_encoded: bool = False
55
+ ):
56
+ super().__init__()
57
+
58
+ self.diffusion = model
59
+
60
+ self.diffusion_ema = EMA(
61
+ self.diffusion.model,
62
+ beta=0.9999,
63
+ power=3/4,
64
+ update_every=1,
65
+ update_after_step=1
66
+ )
67
+
68
+ self.lr = lr
69
+
70
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
71
+
72
+ loss_modules = [
73
+ MSELoss("v",
74
+ "targets",
75
+ weight=1.0,
76
+ name="mse_loss"
77
+ )
78
+ ]
79
+
80
+ self.losses = MultiLoss(loss_modules)
81
+
82
+ self.pre_encoded = pre_encoded
83
+
84
+ def configure_optimizers(self):
85
+ return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
86
+
87
+ def training_step(self, batch, batch_idx):
88
+ reals = batch[0]
89
+
90
+ if reals.ndim == 4 and reals.shape[0] == 1:
91
+ reals = reals[0]
92
+
93
+ diffusion_input = reals
94
+
95
+ loss_info = {}
96
+
97
+ if not self.pre_encoded:
98
+ loss_info["audio_reals"] = diffusion_input
99
+
100
+ if self.diffusion.pretransform is not None:
101
+ if not self.pre_encoded:
102
+ with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
103
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
104
+ else:
105
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
106
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
107
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
108
+
109
+ loss_info["reals"] = diffusion_input
110
+
111
+ # Draw uniformly distributed continuous timesteps
112
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
113
+
114
+ # Calculate the noise schedule parameters for those timesteps
115
+ alphas, sigmas = get_alphas_sigmas(t)
116
+
117
+ # Combine the ground truth data and the noise
118
+ alphas = alphas[:, None, None]
119
+ sigmas = sigmas[:, None, None]
120
+ noise = torch.randn_like(diffusion_input)
121
+ noised_inputs = diffusion_input * alphas + noise * sigmas
122
+ targets = noise * alphas - diffusion_input * sigmas
123
+
124
+ with torch.cuda.amp.autocast():
125
+ v = self.diffusion(noised_inputs, t)
126
+
127
+ loss_info.update({
128
+ "v": v,
129
+ "targets": targets
130
+ })
131
+
132
+ loss, losses = self.losses(loss_info)
133
+
134
+ log_dict = {
135
+ 'train/loss': loss.detach(),
136
+ 'train/std_data': diffusion_input.std(),
137
+ }
138
+
139
+ for loss_name, loss_value in losses.items():
140
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
141
+
142
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
143
+ return loss
144
+
145
+ def on_before_zero_grad(self, *args, **kwargs):
146
+ self.diffusion_ema.update()
147
+
148
+ def export_model(self, path, use_safetensors=False):
149
+
150
+ self.diffusion.model = self.diffusion_ema.ema_model
151
+
152
+ if use_safetensors:
153
+ save_file(self.diffusion.state_dict(), path)
154
+ else:
155
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
156
+
157
+ class DiffusionUncondDemoCallback(pl.Callback):
158
+ def __init__(self,
159
+ demo_every=2000,
160
+ num_demos=8,
161
+ demo_steps=250,
162
+ sample_rate=48000
163
+ ):
164
+ super().__init__()
165
+
166
+ self.demo_every = demo_every
167
+ self.num_demos = num_demos
168
+ self.demo_steps = demo_steps
169
+ self.sample_rate = sample_rate
170
+ self.last_demo_step = -1
171
+
172
+ @rank_zero_only
173
+ @torch.no_grad()
174
+ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
175
+
176
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
177
+ return
178
+
179
+ self.last_demo_step = trainer.global_step
180
+
181
+ demo_samples = module.diffusion.sample_size
182
+
183
+ if module.diffusion.pretransform is not None:
184
+ demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
185
+
186
+ noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
187
+
188
+ try:
189
+ with torch.cuda.amp.autocast():
190
+ fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
191
+
192
+ if module.diffusion.pretransform is not None:
193
+ fakes = module.diffusion.pretransform.decode(fakes)
194
+
195
+ # Put the demos together
196
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
197
+
198
+ log_dict = {}
199
+
200
+ filename = f'demo_{trainer.global_step:08}.wav'
201
+ fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
202
+ torchaudio.save(filename, fakes, self.sample_rate)
203
+
204
+ log_dict[f'demo'] = wandb.Audio(filename,
205
+ sample_rate=self.sample_rate,
206
+ caption=f'Reconstructed')
207
+
208
+ log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))
209
+
210
+ trainer.logger.experiment.log(log_dict)
211
+
212
+ del fakes
213
+
214
+ except Exception as e:
215
+ print(f'{type(e).__name__}: {e}')
216
+ finally:
217
+ gc.collect()
218
+ torch.cuda.empty_cache()
219
+
220
+ class DiffusionCondTrainingWrapper(pl.LightningModule):
221
+ '''
222
+ Wrapper for training a conditional audio diffusion model.
223
+ '''
224
+ def __init__(
225
+ self,
226
+ model: ConditionedDiffusionModelWrapper,
227
+ lr: float = None,
228
+ mask_padding: bool = False,
229
+ mask_padding_dropout: float = 0.0,
230
+ use_ema: bool = True,
231
+ log_loss_info: bool = True,
232
+ optimizer_configs: dict = None,
233
+ pre_encoded: bool = False,
234
+ cfg_dropout_prob = 0.1,
235
+ timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
236
+ ):
237
+ super().__init__()
238
+
239
+ self.diffusion = model
240
+
241
+ if use_ema:
242
+ self.diffusion_ema = EMA(
243
+ self.diffusion.model,
244
+ beta=0.9999,
245
+ power=3/4,
246
+ update_every=1,
247
+ update_after_step=1,
248
+ include_online_model=False
249
+ )
250
+ else:
251
+ self.diffusion_ema = None
252
+
253
+ self.mask_padding = mask_padding
254
+ self.mask_padding_dropout = mask_padding_dropout
255
+
256
+ self.cfg_dropout_prob = cfg_dropout_prob
257
+
258
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
259
+
260
+ self.timestep_sampler = timestep_sampler
261
+
262
+ self.diffusion_objective = model.diffusion_objective
263
+
264
+ if 'av_loss' in optimizer_configs and optimizer_configs['av_loss']['if_add_av_loss']:
265
+ av_align_weight = optimizer_configs['av_loss']['config']['weight']
266
+ self.loss_modules = [
267
+ MSELoss("output",
268
+ "targets",
269
+ weight=1.0 - av_align_weight,
270
+ mask_key="padding_mask" if self.mask_padding else None,
271
+ name="mse_loss"
272
+ )
273
+ ]
274
+ else:
275
+ self.loss_modules = [
276
+ MSELoss("output",
277
+ "targets",
278
+ weight=1.0,
279
+ mask_key="padding_mask" if self.mask_padding else None,
280
+ name="mse_loss"
281
+ )
282
+ ]
283
+
284
+
285
+ self.losses = MultiLoss(self.loss_modules)
286
+
287
+ self.log_loss_info = log_loss_info
288
+
289
+ assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
290
+
291
+ if optimizer_configs is None:
292
+ optimizer_configs = {
293
+ "diffusion": {
294
+ "optimizer": {
295
+ "type": "Adam",
296
+ "config": {
297
+ "lr": lr
298
+ }
299
+ }
300
+ }
301
+ }
302
+ else:
303
+ if lr is not None:
304
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
305
+
306
+ self.optimizer_configs = optimizer_configs
307
+
308
+ self.pre_encoded = pre_encoded
309
+
310
+ def configure_optimizers(self):
311
+ diffusion_opt_config = self.optimizer_configs['diffusion']
312
+ opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
313
+
314
+ if "scheduler" in diffusion_opt_config:
315
+ sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
316
+ sched_diff_config = {
317
+ "scheduler": sched_diff,
318
+ "interval": "step"
319
+ }
320
+ return [opt_diff], [sched_diff_config]
321
+
322
+ return [opt_diff]
323
+
324
+ def training_step(self, batch, batch_idx):
325
+
326
+
327
+ reals, metadata = batch
328
+
329
+ p = Profiler()
330
+
331
+ if reals.ndim == 4 and reals.shape[0] == 1:
332
+ reals = reals[0]
333
+
334
+ loss_info = {}
335
+
336
+ diffusion_input = reals
337
+ if not self.pre_encoded:
338
+ loss_info["audio_reals"] = diffusion_input
339
+
340
+ p.tick("setup")
341
+
342
+ with torch.cuda.amp.autocast():
343
+ conditioning = self.diffusion.conditioner(metadata, self.device)
344
+
345
+ use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout
346
+
347
+ # Create batch tensor of attention masks from the "mask" field of the metadata array
348
+ if use_padding_mask:
349
+ padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device)
350
+
351
+ p.tick("conditioning")
352
+
353
+ if self.diffusion.pretransform is not None:
354
+ self.diffusion.pretransform.to(self.device)
355
+
356
+ if not self.pre_encoded:
357
+ with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
358
+ self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
359
+
360
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
361
+ p.tick("pretransform")
362
+
363
+ # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
364
+ if use_padding_mask:
365
+ padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
366
+ else:
367
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
368
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
369
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
370
+
371
+ if self.timestep_sampler == "uniform":
372
+ # Draw uniformly distributed continuous timesteps
373
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) # [0.1360, 0.5232]
374
+ elif self.timestep_sampler == "logit_normal":
375
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
376
+
377
+ # Calculate the noise schedule parameters for those timesteps
378
+ if self.diffusion_objective == "v":
379
+ alphas, sigmas = get_alphas_sigmas(t)
380
+ elif self.diffusion_objective == "rectified_flow":
381
+ alphas, sigmas = 1-t, t
382
+
383
+ # Combine the ground truth data and the noise
384
+ alphas = alphas[:, None, None]
385
+ sigmas = sigmas[:, None, None]
386
+ noise = torch.randn_like(diffusion_input)
387
+ noised_inputs = diffusion_input * alphas + noise * sigmas
388
+
389
+ if self.diffusion_objective == "v":
390
+ targets = noise * alphas - diffusion_input * sigmas
391
+ elif self.diffusion_objective == "rectified_flow":
392
+ targets = noise - diffusion_input
393
+
394
+ p.tick("noise")
395
+
396
+ extra_args = {}
397
+
398
+ if use_padding_mask:
399
+ extra_args["mask"] = padding_masks
400
+
401
+ with torch.cuda.amp.autocast():
402
+ p.tick("amp")
403
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
404
+ p.tick("diffusion")
405
+
406
+ loss_info.update({
407
+ "output": output,
408
+ "targets": targets,
409
+ "padding_mask": padding_masks if use_padding_mask else None,
410
+ })
411
+
412
+ loss, losses = self.losses(loss_info)
413
+
414
+ p.tick("loss")
415
+
416
+ if self.log_loss_info:
417
+ # Loss debugging logs
418
+ num_loss_buckets = 10
419
+ bucket_size = 1 / num_loss_buckets
420
+ loss_all = F.mse_loss(output, targets, reduction="none")
421
+
422
+ sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze()
423
+
424
+ # gather loss_all across all GPUs
425
+ loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n")
426
+
427
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
428
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
429
+
430
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
431
+ debug_log_dict = {
432
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
433
+ }
434
+
435
+ self.log_dict(debug_log_dict)
436
+
437
+
438
+ log_dict = {
439
+ 'train/loss': loss.detach(),
440
+ 'train/std_data': diffusion_input.std(),
441
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
442
+ }
443
+
444
+ for loss_name, loss_value in losses.items():
445
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
446
+
447
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
448
+ p.tick("log")
449
+ #print(f"Profiler: {p}")
450
+ return loss
451
+
452
+ def validation_step(self, batch, batch_idx):
453
+ reals, metadata = batch
454
+
455
+ p = Profiler()
456
+
457
+ if reals.ndim == 4 and reals.shape[0] == 1:
458
+ reals = reals[0]
459
+
460
+ loss_info = {}
461
+
462
+ diffusion_input = reals
463
+
464
+ if not self.pre_encoded:
465
+ loss_info["audio_reals"] = diffusion_input
466
+
467
+ p.tick("setup")
468
+ with torch.cuda.amp.autocast():
469
+ conditioning = self.diffusion.conditioner(metadata, self.device)
470
+
471
+ # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding
472
+ use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout
473
+
474
+ # Create batch tensor of attention masks from the "mask" field of the metadata array
475
+ if use_padding_mask:
476
+ padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length)
477
+
478
+ p.tick("conditioning")
479
+
480
+ if self.diffusion.pretransform is not None:
481
+ self.diffusion.pretransform.to(self.device)
482
+
483
+ if not self.pre_encoded:
484
+ with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
485
+ self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
486
+
487
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
488
+ p.tick("pretransform")
489
+
490
+ # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
491
+ if use_padding_mask:
492
+ padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
493
+ else:
494
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
495
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
496
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
497
+
498
+ if self.timestep_sampler == "uniform":
499
+ # Draw uniformly distributed continuous timesteps
500
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
501
+ elif self.timestep_sampler == "logit_normal":
502
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
503
+
504
+ # Calculate the noise schedule parameters for those timesteps
505
+ if self.diffusion_objective == "v":
506
+ alphas, sigmas = get_alphas_sigmas(t)
507
+ elif self.diffusion_objective == "rectified_flow":
508
+ alphas, sigmas = 1-t, t
509
+
510
+ # Combine the ground truth data and the noise
511
+ alphas = alphas[:, None, None]
512
+ sigmas = sigmas[:, None, None]
513
+ noise = torch.randn_like(diffusion_input)
514
+ noised_inputs = diffusion_input * alphas + noise * sigmas
515
+
516
+ if self.diffusion_objective == "v":
517
+ targets = noise * alphas - diffusion_input * sigmas
518
+ elif self.diffusion_objective == "rectified_flow":
519
+ targets = noise - diffusion_input
520
+
521
+ p.tick("noise")
522
+
523
+ extra_args = {}
524
+
525
+ if use_padding_mask:
526
+ extra_args["mask"] = padding_masks
527
+
528
+ with torch.cuda.amp.autocast():
529
+ p.tick("amp")
530
+
531
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
532
+ p.tick("diffusion")
533
+
534
+ loss_info.update({
535
+ "output": output,
536
+ "targets": targets,
537
+ "padding_mask": padding_masks if use_padding_mask else None,
538
+ })
539
+
540
+ loss, losses = self.losses(loss_info)
541
+
542
+ p.tick("loss")
543
+
544
+ if self.log_loss_info:
545
+ # Loss debugging logs
546
+ num_loss_buckets = 10
547
+ bucket_size = 1 / num_loss_buckets
548
+ loss_all = F.mse_loss(output, targets, reduction="none")
549
+ # loss_all = F.binary_cross_entropy_with_logits(output, targets, reduction="none")
550
+
551
+
552
+ sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze()
553
+ # sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
554
+
555
+ # gather loss_all across all GPUs
556
+ loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n")
557
+ # loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
558
+
559
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
560
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
561
+
562
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
563
+ debug_log_dict = {
564
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
565
+ }
566
+
567
+ self.log_dict(debug_log_dict)
568
+
569
+
570
+ log_dict = {
571
+ 'valid/loss': loss.detach(),
572
+ 'valid/std_data': diffusion_input.std(),
573
+ 'valid/lr': self.trainer.optimizers[0].param_groups[0]['lr']
574
+ }
575
+
576
+
577
+ for loss_name, loss_value in losses.items():
578
+ log_dict[f"valid/{loss_name}"] = loss_value.detach()
579
+
580
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
581
+ # self.log('val_loss', val_loss, on_epoch=True, on_step=True)
582
+
583
+ p.tick("log")
584
+ #print(f"Profiler: {p}")
585
+ return loss
586
+
587
+ def on_before_zero_grad(self, *args, **kwargs):
588
+ if self.diffusion_ema is not None:
589
+ self.diffusion_ema.update()
590
+
591
+ def export_model(self, path, use_safetensors=False):
592
+ if self.diffusion_ema is not None:
593
+ self.diffusion.model = self.diffusion_ema.ema_model
594
+
595
+ if use_safetensors:
596
+ save_file(self.diffusion.state_dict(), path)
597
+ else:
598
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
599
+
600
+ class DiffusionCondDemoCallback(pl.Callback):
601
+ def __init__(self,
602
+ demo_every=2000,
603
+ num_demos=8,
604
+ sample_size=65536,
605
+ demo_steps=250,
606
+ sample_rate=48000,
607
+ demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {},
608
+ demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
609
+ demo_cond_from_batch: bool = False,
610
+ display_audio_cond: bool = False
611
+ ):
612
+ super().__init__()
613
+
614
+ self.demo_every = demo_every
615
+ self.num_demos = num_demos
616
+ self.demo_samples = sample_size
617
+ self.demo_steps = demo_steps
618
+ self.sample_rate = sample_rate
619
+ self.last_demo_step = -1
620
+ self.demo_conditioning = demo_conditioning
621
+ self.demo_cfg_scales = demo_cfg_scales
622
+
623
+ # If true, the callback will use the metadata from the batch to generate the demo conditioning
624
+ self.demo_cond_from_batch = demo_cond_from_batch
625
+
626
+ # If true, the callback will display the audio conditioning
627
+ self.display_audio_cond = display_audio_cond
628
+
629
+ @rank_zero_only
630
+ @torch.no_grad()
631
+ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
632
+
633
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
634
+ return
635
+
636
+ module.eval()
637
+
638
+ print(f"Generating demo")
639
+ self.last_demo_step = trainer.global_step
640
+
641
+ demo_samples = self.demo_samples
642
+
643
+ demo_cond = self.demo_conditioning
644
+
645
+ if self.demo_cond_from_batch:
646
+ # Get metadata from the batch
647
+ demo_cond = batch[1][:self.num_demos]
648
+
649
+ if module.diffusion.pretransform is not None:
650
+ demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
651
+
652
+ noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
653
+
654
+ try:
655
+ print("Getting conditioning")
656
+ with torch.cuda.amp.autocast():
657
+ conditioning = module.diffusion.conditioner(demo_cond, module.device)
658
+
659
+
660
+ cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
661
+
662
+ log_dict = {}
663
+
664
+ if self.display_audio_cond:
665
+ audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0)
666
+ audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)')
667
+
668
+ filename = f'demo_audio_cond_{trainer.global_step:08}.wav'
669
+ audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu()
670
+ torchaudio.save(filename, audio_inputs, self.sample_rate)
671
+ log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning")
672
+ log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs))
673
+ trainer.logger.experiment.log(log_dict)
674
+
675
+ for cfg_scale in self.demo_cfg_scales:
676
+
677
+ print(f"Generating demo for cfg scale {cfg_scale}")
678
+
679
+ with torch.cuda.amp.autocast():
680
+ model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
681
+
682
+ if module.diffusion_objective == "v":
683
+ fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
684
+ elif module.diffusion_objective == "rectified_flow":
685
+ fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
686
+
687
+ if module.diffusion.pretransform is not None:
688
+ fakes = module.diffusion.pretransform.decode(fakes)
689
+
690
+ # Put the demos together
691
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
692
+
693
+ log_dict = {}
694
+
695
+ filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
696
+ fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
697
+ torchaudio.save(filename, fakes, self.sample_rate)
698
+
699
+ log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
700
+ sample_rate=self.sample_rate,
701
+ caption=f'Reconstructed')
702
+
703
+ log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
704
+
705
+ trainer.logger.experiment.log(log_dict)
706
+
707
+ del fakes
708
+
709
+ except Exception as e:
710
+ raise e
711
+ finally:
712
+ gc.collect()
713
+ torch.cuda.empty_cache()
714
+ module.train()
715
+
716
+ class DiffusionCondInpaintTrainingWrapper(pl.LightningModule):
717
+ '''
718
+ Wrapper for training a conditional audio diffusion model.
719
+ '''
720
+ def __init__(
721
+ self,
722
+ model: ConditionedDiffusionModelWrapper,
723
+ lr: float = 1e-4,
724
+ max_mask_segments = 10,
725
+ log_loss_info: bool = False,
726
+ optimizer_configs: dict = None,
727
+ use_ema: bool = True,
728
+ pre_encoded: bool = False,
729
+ cfg_dropout_prob = 0.1,
730
+ timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
731
+ ):
732
+ super().__init__()
733
+
734
+ self.diffusion = model
735
+
736
+ self.use_ema = use_ema
737
+
738
+ if self.use_ema:
739
+ self.diffusion_ema = EMA(
740
+ self.diffusion.model,
741
+ beta=0.9999,
742
+ power=3/4,
743
+ update_every=1,
744
+ update_after_step=1,
745
+ include_online_model=False
746
+ )
747
+ else:
748
+ self.diffusion_ema = None
749
+
750
+ self.cfg_dropout_prob = cfg_dropout_prob
751
+
752
+ self.lr = lr
753
+ self.max_mask_segments = max_mask_segments
754
+
755
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
756
+
757
+ self.timestep_sampler = timestep_sampler
758
+
759
+ self.diffusion_objective = model.diffusion_objective
760
+
761
+ self.loss_modules = [
762
+ MSELoss("output",
763
+ "targets",
764
+ weight=1.0,
765
+ name="mse_loss"
766
+ )
767
+ ]
768
+
769
+ self.losses = MultiLoss(self.loss_modules)
770
+
771
+ self.log_loss_info = log_loss_info
772
+
773
+ assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
774
+
775
+ if optimizer_configs is None:
776
+ optimizer_configs = {
777
+ "diffusion": {
778
+ "optimizer": {
779
+ "type": "Adam",
780
+ "config": {
781
+ "lr": lr
782
+ }
783
+ }
784
+ }
785
+ }
786
+ else:
787
+ if lr is not None:
788
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
789
+
790
+ self.optimizer_configs = optimizer_configs
791
+
792
+ self.pre_encoded = pre_encoded
793
+
794
+ def configure_optimizers(self):
795
+ diffusion_opt_config = self.optimizer_configs['diffusion']
796
+ opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
797
+
798
+ if "scheduler" in diffusion_opt_config:
799
+ sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
800
+ sched_diff_config = {
801
+ "scheduler": sched_diff,
802
+ "interval": "step"
803
+ }
804
+ return [opt_diff], [sched_diff_config]
805
+
806
+ return [opt_diff]
807
+
808
+ def random_mask(self, sequence, max_mask_length):
809
+ b, _, sequence_length = sequence.size()
810
+
811
+ # Create a mask tensor for each batch element
812
+ masks = []
813
+
814
+ for i in range(b):
815
+ mask_type = random.randint(0, 2)
816
+
817
+ if mask_type == 0: # Random mask with multiple segments
818
+ num_segments = random.randint(1, self.max_mask_segments)
819
+ max_segment_length = max_mask_length // num_segments
820
+
821
+ segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
822
+
823
+ mask = torch.ones((1, 1, sequence_length))
824
+ for length in segment_lengths:
825
+ mask_start = random.randint(0, sequence_length - length)
826
+ mask[:, :, mask_start:mask_start + length] = 0
827
+
828
+ elif mask_type == 1: # Full mask
829
+ mask = torch.zeros((1, 1, sequence_length))
830
+
831
+ elif mask_type == 2: # Causal mask
832
+ mask = torch.ones((1, 1, sequence_length))
833
+ mask_length = random.randint(1, max_mask_length)
834
+ mask[:, :, -mask_length:] = 0
835
+
836
+ mask = mask.to(sequence.device)
837
+ masks.append(mask)
838
+
839
+ # Concatenate the mask tensors into a single tensor
840
+ mask = torch.cat(masks, dim=0).to(sequence.device)
841
+
842
+ # Apply the mask to the sequence tensor for each batch element
843
+ masked_sequence = sequence * mask
844
+
845
+ return masked_sequence, mask
846
+
847
+ def training_step(self, batch, batch_idx):
848
+ reals, metadata = batch
849
+
850
+ p = Profiler()
851
+
852
+ if reals.ndim == 4 and reals.shape[0] == 1:
853
+ reals = reals[0]
854
+
855
+ loss_info = {}
856
+
857
+ diffusion_input = reals
858
+
859
+ if not self.pre_encoded:
860
+ loss_info["audio_reals"] = diffusion_input
861
+
862
+ p.tick("setup")
863
+
864
+ with torch.cuda.amp.autocast():
865
+ conditioning = self.diffusion.conditioner(metadata, self.device)
866
+
867
+ p.tick("conditioning")
868
+
869
+ if self.diffusion.pretransform is not None:
870
+ self.diffusion.pretransform.to(self.device)
871
+
872
+ if not self.pre_encoded:
873
+ with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
874
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
875
+ p.tick("pretransform")
876
+
877
+ # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
878
+ # if use_padding_mask:
879
+ # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
880
+ else:
881
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
882
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
883
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
884
+
885
+ # Max mask size is the full sequence length
886
+ max_mask_length = diffusion_input.shape[2]
887
+
888
+ # Create a mask of random length for a random slice of the input
889
+ masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
890
+
891
+ conditioning['inpaint_mask'] = [mask]
892
+ conditioning['inpaint_masked_input'] = [masked_input]
893
+
894
+ if self.timestep_sampler == "uniform":
895
+ # Draw uniformly distributed continuous timesteps
896
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
897
+ elif self.timestep_sampler == "logit_normal":
898
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
899
+
900
+ # Calculate the noise schedule parameters for those timesteps
901
+ if self.diffusion_objective == "v":
902
+ alphas, sigmas = get_alphas_sigmas(t)
903
+ elif self.diffusion_objective == "rectified_flow":
904
+ alphas, sigmas = 1-t, t
905
+
906
+ # Combine the ground truth data and the noise
907
+ alphas = alphas[:, None, None]
908
+ sigmas = sigmas[:, None, None]
909
+ noise = torch.randn_like(diffusion_input)
910
+ noised_inputs = diffusion_input * alphas + noise * sigmas
911
+
912
+ if self.diffusion_objective == "v":
913
+ targets = noise * alphas - diffusion_input * sigmas
914
+ elif self.diffusion_objective == "rectified_flow":
915
+ targets = noise - diffusion_input
916
+
917
+ p.tick("noise")
918
+
919
+ extra_args = {}
920
+
921
+ with torch.cuda.amp.autocast():
922
+ p.tick("amp")
923
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
924
+ p.tick("diffusion")
925
+
926
+ loss_info.update({
927
+ "output": output,
928
+ "targets": targets,
929
+ })
930
+
931
+ loss, losses = self.losses(loss_info)
932
+
933
+ if self.log_loss_info:
934
+ # Loss debugging logs
935
+ num_loss_buckets = 10
936
+ bucket_size = 1 / num_loss_buckets
937
+ loss_all = F.mse_loss(output, targets, reduction="none")
938
+
939
+ sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
940
+
941
+ # gather loss_all across all GPUs
942
+ loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
943
+
944
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
945
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
946
+
947
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
948
+ debug_log_dict = {
949
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
950
+ }
951
+
952
+ self.log_dict(debug_log_dict)
953
+
954
+ log_dict = {
955
+ 'train/loss': loss.detach(),
956
+ 'train/std_data': diffusion_input.std(),
957
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
958
+ }
959
+
960
+ for loss_name, loss_value in losses.items():
961
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
962
+
963
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
964
+ p.tick("log")
965
+ #print(f"Profiler: {p}")
966
+ return loss
967
+
968
+ def on_before_zero_grad(self, *args, **kwargs):
969
+ if self.diffusion_ema is not None:
970
+ self.diffusion_ema.update()
971
+
972
+ def export_model(self, path, use_safetensors=False):
973
+ if self.diffusion_ema is not None:
974
+ self.diffusion.model = self.diffusion_ema.ema_model
975
+
976
+ if use_safetensors:
977
+ save_file(self.diffusion.state_dict(), path)
978
+ else:
979
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
980
+
981
+ class DiffusionCondInpaintDemoCallback(pl.Callback):
982
+ def __init__(
983
+ self,
984
+ demo_dl,
985
+ demo_every=2000,
986
+ demo_steps=250,
987
+ sample_size=65536,
988
+ sample_rate=48000,
989
+ demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7]
990
+ ):
991
+ super().__init__()
992
+ self.demo_every = demo_every
993
+ self.demo_steps = demo_steps
994
+ self.demo_samples = sample_size
995
+ self.demo_dl = iter(demo_dl)
996
+ self.sample_rate = sample_rate
997
+ self.demo_cfg_scales = demo_cfg_scales
998
+ self.last_demo_step = -1
999
+
1000
+ @rank_zero_only
1001
+ @torch.no_grad()
1002
+ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
1003
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1004
+ return
1005
+
1006
+ self.last_demo_step = trainer.global_step
1007
+
1008
+ try:
1009
+ log_dict = {}
1010
+
1011
+ demo_reals, metadata = next(self.demo_dl)
1012
+
1013
+ # Remove extra dimension added by WebDataset
1014
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1015
+ demo_reals = demo_reals[0]
1016
+
1017
+ demo_reals = demo_reals.to(module.device)
1018
+
1019
+ if not module.pre_encoded:
1020
+ # Log the real audio
1021
+ log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()))
1022
+ # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals")
1023
+
1024
+ if module.diffusion.pretransform is not None:
1025
+ module.diffusion.pretransform.to(module.device)
1026
+ with torch.cuda.amp.autocast():
1027
+ demo_reals = module.diffusion.pretransform.encode(demo_reals)
1028
+
1029
+ demo_samples = demo_reals.shape[2]
1030
+
1031
+ # Get conditioning
1032
+ conditioning = module.diffusion.conditioner(metadata, module.device)
1033
+
1034
+ masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2])
1035
+
1036
+ conditioning['inpaint_mask'] = [mask]
1037
+ conditioning['inpaint_masked_input'] = [masked_input]
1038
+
1039
+ if module.diffusion.pretransform is not None:
1040
+ log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu()))
1041
+ else:
1042
+ log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu()))
1043
+
1044
+ cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
1045
+
1046
+ noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device)
1047
+
1048
+ trainer.logger.experiment.log(log_dict)
1049
+
1050
+ for cfg_scale in self.demo_cfg_scales:
1051
+ model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
1052
+ print(f"Generating demo for cfg scale {cfg_scale}")
1053
+
1054
+ if module.diffusion_objective == "v":
1055
+ fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1056
+ elif module.diffusion_objective == "rectified_flow":
1057
+ fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1058
+
1059
+ if module.diffusion.pretransform is not None:
1060
+ with torch.cuda.amp.autocast():
1061
+ fakes = module.diffusion.pretransform.decode(fakes)
1062
+
1063
+ # Put the demos together
1064
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
1065
+
1066
+ log_dict = {}
1067
+
1068
+ filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
1069
+ fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
1070
+ torchaudio.save(filename, fakes, self.sample_rate)
1071
+
1072
+ log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
1073
+ sample_rate=self.sample_rate,
1074
+ caption=f'Reconstructed')
1075
+
1076
+ log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
1077
+
1078
+ trainer.logger.experiment.log(log_dict)
1079
+ except Exception as e:
1080
+ print(f'{type(e).__name__}: {e}')
1081
+ raise e
1082
+
1083
+ class DiffusionAutoencoderTrainingWrapper(pl.LightningModule):
1084
+ '''
1085
+ Wrapper for training a diffusion autoencoder
1086
+ '''
1087
+ def __init__(
1088
+ self,
1089
+ model: DiffusionAutoencoder,
1090
+ lr: float = 1e-4,
1091
+ ema_copy = None,
1092
+ use_reconstruction_loss: bool = False
1093
+ ):
1094
+ super().__init__()
1095
+
1096
+ self.diffae = model
1097
+
1098
+ self.diffae_ema = EMA(
1099
+ self.diffae,
1100
+ ema_model=ema_copy,
1101
+ beta=0.9999,
1102
+ power=3/4,
1103
+ update_every=1,
1104
+ update_after_step=1,
1105
+ include_online_model=False
1106
+ )
1107
+
1108
+ self.lr = lr
1109
+
1110
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1111
+
1112
+ loss_modules = [
1113
+ MSELoss("v",
1114
+ "targets",
1115
+ weight=1.0,
1116
+ name="mse_loss"
1117
+ )
1118
+ ]
1119
+
1120
+ if model.bottleneck is not None:
1121
+ # TODO: Use loss config for configurable bottleneck weights and reconstruction losses
1122
+ loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {})
1123
+
1124
+ self.use_reconstruction_loss = use_reconstruction_loss
1125
+
1126
+ if use_reconstruction_loss:
1127
+ scales = [2048, 1024, 512, 256, 128, 64, 32]
1128
+ hop_sizes = []
1129
+ win_lengths = []
1130
+ overlap = 0.75
1131
+ for s in scales:
1132
+ hop_sizes.append(int(s * (1 - overlap)))
1133
+ win_lengths.append(s)
1134
+
1135
+ sample_rate = model.sample_rate
1136
+
1137
+ stft_loss_args = {
1138
+ "fft_sizes": scales,
1139
+ "hop_sizes": hop_sizes,
1140
+ "win_lengths": win_lengths,
1141
+ "perceptual_weighting": True
1142
+ }
1143
+
1144
+ out_channels = model.out_channels
1145
+
1146
+ if model.pretransform is not None:
1147
+ out_channels = model.pretransform.io_channels
1148
+
1149
+ if out_channels == 2:
1150
+ self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1151
+ else:
1152
+ self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1153
+
1154
+ loss_modules.append(
1155
+ AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1156
+ )
1157
+
1158
+ self.losses = MultiLoss(loss_modules)
1159
+
1160
+ def configure_optimizers(self):
1161
+ return optim.Adam([*self.diffae.parameters()], lr=self.lr)
1162
+
1163
+ def training_step(self, batch, batch_idx):
1164
+ reals = batch[0]
1165
+
1166
+ if reals.ndim == 4 and reals.shape[0] == 1:
1167
+ reals = reals[0]
1168
+
1169
+ loss_info = {}
1170
+
1171
+ loss_info["audio_reals"] = reals
1172
+
1173
+ if self.diffae.pretransform is not None:
1174
+ with torch.no_grad():
1175
+ reals = self.diffae.pretransform.encode(reals)
1176
+
1177
+ loss_info["reals"] = reals
1178
+
1179
+ #Encode reals, skipping the pretransform since it was already applied
1180
+ latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True)
1181
+
1182
+ loss_info["latents"] = latents
1183
+ loss_info.update(encoder_info)
1184
+
1185
+ if self.diffae.decoder is not None:
1186
+ latents = self.diffae.decoder(latents)
1187
+
1188
+ # Upsample latents to match diffusion length
1189
+ if latents.shape[2] != reals.shape[2]:
1190
+ latents = F.interpolate(latents, size=reals.shape[2], mode='nearest')
1191
+
1192
+ loss_info["latents_upsampled"] = latents
1193
+
1194
+ # Draw uniformly distributed continuous timesteps
1195
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1196
+
1197
+ # Calculate the noise schedule parameters for those timesteps
1198
+ alphas, sigmas = get_alphas_sigmas(t)
1199
+
1200
+ # Combine the ground truth data and the noise
1201
+ alphas = alphas[:, None, None]
1202
+ sigmas = sigmas[:, None, None]
1203
+ noise = torch.randn_like(reals)
1204
+ noised_reals = reals * alphas + noise * sigmas
1205
+ targets = noise * alphas - reals * sigmas
1206
+
1207
+ with torch.cuda.amp.autocast():
1208
+ v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents)
1209
+
1210
+ loss_info.update({
1211
+ "v": v,
1212
+ "targets": targets
1213
+ })
1214
+
1215
+ if self.use_reconstruction_loss:
1216
+ pred = noised_reals * alphas - v * sigmas
1217
+
1218
+ loss_info["pred"] = pred
1219
+
1220
+ if self.diffae.pretransform is not None:
1221
+ pred = self.diffae.pretransform.decode(pred)
1222
+ loss_info["audio_pred"] = pred
1223
+
1224
+ loss, losses = self.losses(loss_info)
1225
+
1226
+ log_dict = {
1227
+ 'train/loss': loss.detach(),
1228
+ 'train/std_data': reals.std(),
1229
+ 'train/latent_std': latents.std(),
1230
+ }
1231
+
1232
+ for loss_name, loss_value in losses.items():
1233
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
1234
+
1235
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
1236
+ return loss
1237
+
1238
+ def on_before_zero_grad(self, *args, **kwargs):
1239
+ self.diffae_ema.update()
1240
+
1241
+ def export_model(self, path, use_safetensors=False):
1242
+
1243
+ model = self.diffae_ema.ema_model
1244
+
1245
+ if use_safetensors:
1246
+ save_file(model.state_dict(), path)
1247
+ else:
1248
+ torch.save({"state_dict": model.state_dict()}, path)
1249
+
1250
+ class DiffusionAutoencoderDemoCallback(pl.Callback):
1251
+ def __init__(
1252
+ self,
1253
+ demo_dl,
1254
+ demo_every=2000,
1255
+ demo_steps=250,
1256
+ sample_size=65536,
1257
+ sample_rate=48000
1258
+ ):
1259
+ super().__init__()
1260
+ self.demo_every = demo_every
1261
+ self.demo_steps = demo_steps
1262
+ self.demo_samples = sample_size
1263
+ self.demo_dl = iter(demo_dl)
1264
+ self.sample_rate = sample_rate
1265
+ self.last_demo_step = -1
1266
+
1267
+ @rank_zero_only
1268
+ @torch.no_grad()
1269
+ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
1270
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1271
+ return
1272
+
1273
+ self.last_demo_step = trainer.global_step
1274
+
1275
+ demo_reals, _ = next(self.demo_dl)
1276
+
1277
+ # Remove extra dimension added by WebDataset
1278
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1279
+ demo_reals = demo_reals[0]
1280
+
1281
+ encoder_input = demo_reals
1282
+
1283
+ encoder_input = encoder_input.to(module.device)
1284
+
1285
+ demo_reals = demo_reals.to(module.device)
1286
+
1287
+ with torch.no_grad(), torch.cuda.amp.autocast():
1288
+ latents = module.diffae_ema.ema_model.encode(encoder_input).float()
1289
+ fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps)
1290
+
1291
+ #Interleave reals and fakes
1292
+ reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
1293
+
1294
+ # Put the demos together
1295
+ reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
1296
+
1297
+ log_dict = {}
1298
+
1299
+ filename = f'recon_{trainer.global_step:08}.wav'
1300
+ reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
1301
+ torchaudio.save(filename, reals_fakes, self.sample_rate)
1302
+
1303
+ log_dict[f'recon'] = wandb.Audio(filename,
1304
+ sample_rate=self.sample_rate,
1305
+ caption=f'Reconstructed')
1306
+
1307
+ log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
1308
+ log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
1309
+
1310
+ log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
1311
+
1312
+ if module.diffae_ema.ema_model.pretransform is not None:
1313
+ with torch.no_grad(), torch.cuda.amp.autocast():
1314
+ initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input)
1315
+ first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents)
1316
+ first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)')
1317
+ first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu()
1318
+ first_stage_filename = f'first_stage_{trainer.global_step:08}.wav'
1319
+ torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate)
1320
+
1321
+ log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents))
1322
+
1323
+ log_dict[f'first_stage'] = wandb.Audio(first_stage_filename,
1324
+ sample_rate=self.sample_rate,
1325
+ caption=f'First Stage Reconstructed')
1326
+
1327
+ log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes))
1328
+
1329
+
1330
+ trainer.logger.experiment.log(log_dict)
1331
+
1332
+ def create_source_mixture(reals, num_sources=2):
1333
+ # Create a fake mixture source by mixing elements from the training batch together with random offsets
1334
+ source = torch.zeros_like(reals)
1335
+ for i in range(reals.shape[0]):
1336
+ sources_added = 0
1337
+
1338
+ js = list(range(reals.shape[0]))
1339
+ random.shuffle(js)
1340
+ for j in js:
1341
+ if i == j or (i != j and sources_added < num_sources):
1342
+ # Randomly offset the mixed element between 0 and the length of the source
1343
+ seq_len = reals.shape[2]
1344
+ offset = random.randint(0, seq_len-1)
1345
+ source[i, :, offset:] += reals[j, :, :seq_len - offset]  # also handles offset == 0
1346
+ if i == j:
1347
+ # If this is the real one, shift the reals as well to ensure alignment
1348
+ new_reals = torch.zeros_like(reals[i])
1349
+ new_reals[:, offset:] = reals[i, :, :seq_len - offset]
1350
+ reals[i] = new_reals
1351
+ sources_added += 1
1352
+
1353
+ return source
1354
+
1355
+ class DiffusionPriorTrainingWrapper(pl.LightningModule):
1356
+ '''
1357
+ Wrapper for training a diffusion prior for inverse problems
1358
+ Prior types:
1359
+ mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version
1360
+ '''
1361
+ def __init__(
1362
+ self,
1363
+ model: ConditionedDiffusionModelWrapper,
1364
+ lr: float = 1e-4,
1365
+ ema_copy = None,
1366
+ prior_type: PriorType = PriorType.MonoToStereo,
1367
+ use_reconstruction_loss: bool = False,
1368
+ log_loss_info: bool = False,
1369
+ ):
1370
+ super().__init__()
1371
+
1372
+ self.diffusion = model
1373
+
1374
+ self.diffusion_ema = EMA(
1375
+ self.diffusion,
1376
+ ema_model=ema_copy,
1377
+ beta=0.9999,
1378
+ power=3/4,
1379
+ update_every=1,
1380
+ update_after_step=1,
1381
+ include_online_model=False
1382
+ )
1383
+
1384
+ self.lr = lr
1385
+
1386
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1387
+
1388
+ self.log_loss_info = log_loss_info
1389
+
1390
+ loss_modules = [
1391
+ MSELoss("v",
1392
+ "targets",
1393
+ weight=1.0,
1394
+ name="mse_loss"
1395
+ )
1396
+ ]
1397
+
1398
+ self.use_reconstruction_loss = use_reconstruction_loss
1399
+
1400
+ if use_reconstruction_loss:
1401
+ scales = [2048, 1024, 512, 256, 128, 64, 32]
1402
+ hop_sizes = []
1403
+ win_lengths = []
1404
+ overlap = 0.75
1405
+ for s in scales:
1406
+ hop_sizes.append(int(s * (1 - overlap)))
1407
+ win_lengths.append(s)
1408
+
1409
+ sample_rate = model.sample_rate
1410
+
1411
+ stft_loss_args = {
1412
+ "fft_sizes": scales,
1413
+ "hop_sizes": hop_sizes,
1414
+ "win_lengths": win_lengths,
1415
+ "perceptual_weighting": True
1416
+ }
1417
+
1418
+ out_channels = model.io_channels
1419
+
1420
+ self.audio_out_channels = out_channels
1421
+
1422
+ if model.pretransform is not None:
1423
+ out_channels = model.pretransform.io_channels
1424
+
1425
+ if self.audio_out_channels == 2:
1426
+ self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1427
+ self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1428
+
1429
+ # Add left and right channel reconstruction losses in addition to the sum and difference
1430
+ loss_modules += [
1431
+ AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05),
1432
+ AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05),
1433
+ ]
1434
+
1435
+ else:
1436
+ self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1437
+
1438
+ loss_modules.append(
1439
+ AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1440
+ )
1441
+
1442
+ self.losses = MultiLoss(loss_modules)
1443
+
1444
+ self.prior_type = prior_type
1445
+
1446
+ def configure_optimizers(self):
1447
+ return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
1448
+
1449
+ def training_step(self, batch, batch_idx):
1450
+ reals, metadata = batch
1451
+
1452
+ if reals.ndim == 4 and reals.shape[0] == 1:
1453
+ reals = reals[0]
1454
+
1455
+ loss_info = {}
1456
+
1457
+ loss_info["audio_reals"] = reals
1458
+
1459
+ if self.prior_type == PriorType.MonoToStereo:
1460
+ source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device)
1461
+ loss_info["audio_reals_mono"] = source
1462
+ else:
1463
+ raise ValueError(f"Unknown prior type {self.prior_type}")
1464
+
1465
+ if self.diffusion.pretransform is not None:
1466
+ with torch.no_grad():
1467
+ reals = self.diffusion.pretransform.encode(reals)
1468
+
1469
+ if self.prior_type in [PriorType.MonoToStereo]:
1470
+ source = self.diffusion.pretransform.encode(source)
1471
+
1472
+ if self.diffusion.conditioner is not None:
1473
+ with torch.cuda.amp.autocast():
1474
+ conditioning = self.diffusion.conditioner(metadata, self.device)
1475
+ else:
1476
+ conditioning = {}
1477
+
1478
+ loss_info["reals"] = reals
1479
+
1480
+ # Draw uniformly distributed continuous timesteps
1481
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1482
+
1483
+ # Calculate the noise schedule parameters for those timesteps
1484
+ alphas, sigmas = get_alphas_sigmas(t)
1485
+
1486
+ # Combine the ground truth data and the noise
1487
+ alphas = alphas[:, None, None]
1488
+ sigmas = sigmas[:, None, None]
1489
+ noise = torch.randn_like(reals)
1490
+ noised_reals = reals * alphas + noise * sigmas
1491
+ targets = noise * alphas - reals * sigmas
1492
+
1493
+ with torch.cuda.amp.autocast():
1494
+
1495
+ conditioning['source'] = [source]
1496
+
1497
+ v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1)
1498
+
1499
+ loss_info.update({
1500
+ "v": v,
1501
+ "targets": targets
1502
+ })
1503
+
1504
+ if self.use_reconstruction_loss:
1505
+ pred = noised_reals * alphas - v * sigmas
1506
+
1507
+ loss_info["pred"] = pred
1508
+
1509
+ if self.diffusion.pretransform is not None:
1510
+ pred = self.diffusion.pretransform.decode(pred)
1511
+ loss_info["audio_pred"] = pred
1512
+
1513
+ if self.audio_out_channels == 2:
1514
+ loss_info["pred_left"] = pred[:, 0:1, :]
1515
+ loss_info["pred_right"] = pred[:, 1:2, :]
1516
+ loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :]
1517
+ loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :]
1518
+
1519
+ loss, losses = self.losses(loss_info)
1520
+
1521
+ if self.log_loss_info:
1522
+ # Loss debugging logs
1523
+ num_loss_buckets = 10
1524
+ bucket_size = 1 / num_loss_buckets
1525
+ loss_all = F.mse_loss(v, targets, reduction="none")
1526
+
1527
+ sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
1528
+
1529
+ # gather loss_all across all GPUs
1530
+ loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
1531
+
1532
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
1533
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
1534
+
1535
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
1536
+ debug_log_dict = {
1537
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
1538
+ }
1539
+
1540
+ self.log_dict(debug_log_dict)
1541
+
1542
+ log_dict = {
1543
+ 'train/loss': loss.detach(),
1544
+ 'train/std_data': reals.std()
1545
+ }
1546
+
1547
+ for loss_name, loss_value in losses.items():
1548
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
1549
+
1550
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
1551
+ return loss
1552
+
1553
+ def on_before_zero_grad(self, *args, **kwargs):
1554
+ self.diffusion_ema.update()
1555
+
1556
+ def export_model(self, path, use_safetensors=False):
1557
+
1558
+ #model = self.diffusion_ema.ema_model
1559
+ model = self.diffusion
1560
+
1561
+ if use_safetensors:
1562
+ save_file(model.state_dict(), path)
1563
+ else:
1564
+ torch.save({"state_dict": model.state_dict()}, path)
1565
+
1566
+ class DiffusionPriorDemoCallback(pl.Callback):
1567
+ def __init__(
1568
+ self,
1569
+ demo_dl,
1570
+ demo_every=2000,
1571
+ demo_steps=250,
1572
+ sample_size=65536,
1573
+ sample_rate=48000
1574
+ ):
1575
+ super().__init__()
1576
+ self.demo_every = demo_every
1577
+ self.demo_steps = demo_steps
1578
+ self.demo_samples = sample_size
1579
+ self.demo_dl = iter(demo_dl)
1580
+ self.sample_rate = sample_rate
1581
+ self.last_demo_step = -1
1582
+
1583
+ @rank_zero_only
1584
+ @torch.no_grad()
1585
+ def on_train_batch_end(self, trainer, module: DiffusionPriorTrainingWrapper, outputs, batch, batch_idx):
1586
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1587
+ return
1588
+
1589
+ self.last_demo_step = trainer.global_step
1590
+
1591
+ demo_reals, metadata = next(self.demo_dl)
1592
+
1593
+ # Remove extra dimension added by WebDataset
1594
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1595
+ demo_reals = demo_reals[0]
1596
+
1597
+ demo_reals = demo_reals.to(module.device)
1598
+
1599
+ encoder_input = demo_reals
1600
+
1601
+ if module.diffusion.conditioner is not None:
1602
+ with torch.cuda.amp.autocast():
1603
+ conditioning_tensors = module.diffusion.conditioner(metadata, module.device)
1604
+
1605
+ else:
1606
+ conditioning_tensors = {}
1607
+
1608
+
1609
+ with torch.no_grad(), torch.cuda.amp.autocast():
1610
+ if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1:
1611
+ source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device)
1612
+
1613
+ if module.diffusion.pretransform is not None:
1614
+ encoder_input = module.diffusion.pretransform.encode(encoder_input)
1615
+ source_input = module.diffusion.pretransform.encode(source)
1616
+ else:
1617
+ source_input = source
1618
+
1619
+ conditioning_tensors['source'] = [source_input]
1620
+
1621
+ fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors)
1622
+
1623
+ if module.diffusion.pretransform is not None:
1624
+ fakes = module.diffusion.pretransform.decode(fakes)
1625
+
1626
+ #Interleave reals and fakes
1627
+ reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
1628
+
1629
+ # Put the demos together
1630
+ reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
1631
+
1632
+ log_dict = {}
1633
+
1634
+ filename = f'recon_{trainer.global_step:08}.wav'
1635
+ reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
1636
+ torchaudio.save(filename, reals_fakes, self.sample_rate)
1637
+
1638
+ log_dict[f'recon'] = wandb.Audio(filename,
1639
+ sample_rate=self.sample_rate,
1640
+ caption=f'Reconstructed')
1641
+
1642
+ log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
1643
+
1644
+ #Log the source
1645
+ filename = f'source_{trainer.global_step:08}.wav'
1646
+ source = rearrange(source, 'b d n -> d (b n)')
1647
+ source = source.to(torch.float32).mul(32767).to(torch.int16).cpu()
1648
+ torchaudio.save(filename, source, self.sample_rate)
1649
+
1650
+ log_dict[f'source'] = wandb.Audio(filename,
1651
+ sample_rate=self.sample_rate,
1652
+ caption=f'Source')
1653
+
1654
+ log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source))
1655
+
1656
+ trainer.logger.experiment.log(log_dict)
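Note on the shared diffusion setup: every training wrapper in this file follows the same v-objective bookkeeping. Sobol-sampled timesteps are mapped to (alphas, sigmas), the data is noised, and the model is regressed onto noise * alphas - reals * sigmas; the denoised estimate fed to the reconstruction losses is recovered from the model output in the same way. The sketch below is a minimal standalone illustration of that step, assuming the cosine schedule get_alphas_sigmas(t) = (cos(pi*t/2), sin(pi*t/2)) used by the v-objective samplers; the toy tensor stands in for encoded audio latents.

import math
import torch

def get_alphas_sigmas(t):
    # Assumed cosine schedule: alphas**2 + sigmas**2 == 1 for every t in [0, 1]
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

reals = torch.randn(4, 64, 256)  # toy latents: (batch, channels, length)

rng = torch.quasirandom.SobolEngine(1, scramble=True)
t = rng.draw(reals.shape[0])[:, 0]  # quasi-random timesteps in [0, 1)

alphas, sigmas = get_alphas_sigmas(t)
alphas = alphas[:, None, None]
sigmas = sigmas[:, None, None]

noise = torch.randn_like(reals)
noised_reals = reals * alphas + noise * sigmas
targets = noise * alphas - reals * sigmas  # the "v" regression target

v = torch.zeros_like(reals)  # placeholder for the model output
pred = noised_reals * alphas - v * sigmas  # denoised estimate used by the reconstruction losses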
stable_audio_tools/training/factory.py ADDED
@@ -0,0 +1,240 @@
1
+ import torch
2
+ from torch.nn import Parameter
3
+ from ..models.factory import create_model_from_config
4
+
5
+ def create_training_wrapper_from_config(model_config, model):
6
+ model_type = model_config.get('model_type', None)
7
+ assert model_type is not None, 'model_type must be specified in model config'
8
+
9
+ training_config = model_config.get('training', None)
10
+ assert training_config is not None, 'training config must be specified in model config'
11
+
12
+ if model_type == 'autoencoder':
13
+ from .autoencoders import AutoencoderTrainingWrapper
14
+
15
+ ema_copy = None
16
+
17
+ if training_config.get("use_ema", False):
18
+ ema_copy = create_model_from_config(model_config)
19
+ ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once
20
+ # Copy each weight to the ema copy
21
+ for name, param in model.state_dict().items():
22
+ if isinstance(param, Parameter):
23
+ # backwards compatibility for serialized parameters
24
+ param = param.data
25
+ ema_copy.state_dict()[name].copy_(param)
26
+
27
+ use_ema = training_config.get("use_ema", False)
28
+
29
+ latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)
30
+
31
+ teacher_model = training_config.get("teacher_model", None)
32
+ if teacher_model is not None:
33
+ teacher_model = create_model_from_config(teacher_model)
34
+ teacher_model = teacher_model.eval().requires_grad_(False)
35
+
36
+ teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
37
+ if teacher_model_ckpt is not None:
38
+ teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"])
39
+ else:
40
+ raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")
41
+
42
+ return AutoencoderTrainingWrapper(
43
+ model,
44
+ lr=training_config["learning_rate"],
45
+ warmup_steps=training_config.get("warmup_steps", 0),
46
+ encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
47
+ sample_rate=model_config["sample_rate"],
48
+ loss_config=training_config.get("loss_configs", None),
49
+ optimizer_configs=training_config.get("optimizer_configs", None),
50
+ use_ema=use_ema,
51
+ ema_copy=ema_copy if use_ema else None,
52
+ force_input_mono=training_config.get("force_input_mono", False),
53
+ latent_mask_ratio=latent_mask_ratio,
54
+ teacher_model=teacher_model
55
+ )
56
+ elif model_type == 'diffusion_uncond':
57
+ from .diffusion import DiffusionUncondTrainingWrapper
58
+ return DiffusionUncondTrainingWrapper(
59
+ model,
60
+ lr=training_config["learning_rate"],
61
+ pre_encoded=training_config.get("pre_encoded", False),
62
+ )
63
+ elif model_type == 'diffusion_cond':
64
+ from .diffusion import DiffusionCondTrainingWrapper
65
+ return DiffusionCondTrainingWrapper(
66
+ model,
67
+ lr=training_config.get("learning_rate", None),
68
+ mask_padding=training_config.get("mask_padding", False),
69
+ mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
70
+ use_ema = training_config.get("use_ema", True),
71
+ log_loss_info=training_config.get("log_loss_info", False),
72
+ optimizer_configs=training_config.get("optimizer_configs", None),
73
+ pre_encoded=training_config.get("pre_encoded", False),
74
+ cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
75
+ timestep_sampler = training_config.get("timestep_sampler", "uniform")
76
+ )
77
+ elif model_type == 'diffusion_prior':
78
+ from .diffusion import DiffusionPriorTrainingWrapper
79
+ from ..models.diffusion_prior import PriorType
80
+
81
+ ema_copy = create_model_from_config(model_config)
82
+
83
+ # Copy each weight to the ema copy
84
+ for name, param in model.state_dict().items():
85
+ if isinstance(param, Parameter):
86
+ # backwards compatibility for serialized parameters
87
+ param = param.data
88
+ ema_copy.state_dict()[name].copy_(param)
89
+
90
+ prior_type = training_config.get("prior_type", "mono_stereo")
91
+
92
+ if prior_type == "mono_stereo":
93
+ prior_type_enum = PriorType.MonoToStereo
94
+ else:
95
+ raise ValueError(f"Unknown prior type: {prior_type}")
96
+
97
+ return DiffusionPriorTrainingWrapper(
98
+ model,
99
+ lr=training_config["learning_rate"],
100
+ ema_copy=ema_copy,
101
+ prior_type=prior_type_enum,
102
+ log_loss_info=training_config.get("log_loss_info", False),
103
+ use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
104
+ )
105
+ elif model_type == 'diffusion_cond_inpaint':
106
+ from .diffusion import DiffusionCondInpaintTrainingWrapper
107
+ return DiffusionCondInpaintTrainingWrapper(
108
+ model,
109
+ lr=training_config.get("learning_rate", None),
110
+ max_mask_segments = training_config.get("max_mask_segments", 10),
111
+ log_loss_info=training_config.get("log_loss_info", False),
112
+ optimizer_configs=training_config.get("optimizer_configs", None),
113
+ use_ema=training_config.get("use_ema", True),
114
+ pre_encoded=training_config.get("pre_encoded", False),
115
+ cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
116
+ timestep_sampler = training_config.get("timestep_sampler", "uniform")
117
+ )
118
+ elif model_type == 'diffusion_autoencoder':
119
+ from .diffusion import DiffusionAutoencoderTrainingWrapper
120
+
121
+ ema_copy = create_model_from_config(model_config)
122
+
123
+ # Copy each weight to the ema copy
124
+ for name, param in model.state_dict().items():
125
+ if isinstance(param, Parameter):
126
+ # backwards compatibility for serialized parameters
127
+ param = param.data
128
+ ema_copy.state_dict()[name].copy_(param)
129
+
130
+ return DiffusionAutoencoderTrainingWrapper(
131
+ model,
132
+ ema_copy=ema_copy,
133
+ lr=training_config["learning_rate"],
134
+ use_reconstruction_loss=training_config.get("use_reconstruction_loss", False)
135
+ )
136
+ elif model_type == 'lm':
137
+ from .lm import AudioLanguageModelTrainingWrapper
138
+
139
+ ema_copy = create_model_from_config(model_config)
140
+
141
+ for name, param in model.state_dict().items():
142
+ if isinstance(param, Parameter):
143
+ # backwards compatibility for serialized parameters
144
+ param = param.data
145
+ ema_copy.state_dict()[name].copy_(param)
146
+
147
+ return AudioLanguageModelTrainingWrapper(
148
+ model,
149
+ ema_copy=ema_copy,
150
+ lr=training_config.get("learning_rate", None),
151
+ use_ema=training_config.get("use_ema", False),
152
+ optimizer_configs=training_config.get("optimizer_configs", None),
153
+ pre_encoded=training_config.get("pre_encoded", False),
154
+ )
155
+
156
+ else:
157
+ raise NotImplementedError(f'Unknown model type: {model_type}')
158
+
159
+ def create_demo_callback_from_config(model_config, **kwargs):
160
+ model_type = model_config.get('model_type', None)
161
+ assert model_type is not None, 'model_type must be specified in model config'
162
+
163
+ training_config = model_config.get('training', None)
164
+ assert training_config is not None, 'training config must be specified in model config'
165
+
166
+ demo_config = training_config.get("demo", {})
167
+
168
+ if model_type == 'autoencoder':
169
+ from .autoencoders import AutoencoderDemoCallback
170
+ return AutoencoderDemoCallback(
171
+ demo_every=demo_config.get("demo_every", 2000),
172
+ sample_size=model_config["sample_size"],
173
+ sample_rate=model_config["sample_rate"],
174
+ **kwargs
175
+ )
176
+ elif model_type == 'diffusion_uncond':
177
+ from .diffusion import DiffusionUncondDemoCallback
178
+ return DiffusionUncondDemoCallback(
179
+ demo_every=demo_config.get("demo_every", 2000),
180
+ demo_steps=demo_config.get("demo_steps", 250),
181
+ sample_rate=model_config["sample_rate"]
182
+ )
183
+ elif model_type == "diffusion_autoencoder":
184
+ from .diffusion import DiffusionAutoencoderDemoCallback
185
+ return DiffusionAutoencoderDemoCallback(
186
+ demo_every=demo_config.get("demo_every", 2000),
187
+ demo_steps=demo_config.get("demo_steps", 250),
188
+ sample_size=model_config["sample_size"],
189
+ sample_rate=model_config["sample_rate"],
190
+ **kwargs
191
+ )
192
+ elif model_type == "diffusion_prior":
193
+ from .diffusion import DiffusionPriorDemoCallback
194
+ return DiffusionPriorDemoCallback(
195
+ demo_every=demo_config.get("demo_every", 2000),
196
+ demo_steps=demo_config.get("demo_steps", 250),
197
+ sample_size=model_config["sample_size"],
198
+ sample_rate=model_config["sample_rate"],
199
+ **kwargs
200
+ )
201
+ elif model_type == "diffusion_cond":
202
+ from .diffusion import DiffusionCondDemoCallback
203
+
204
+ return DiffusionCondDemoCallback(
205
+ demo_every=demo_config.get("demo_every", 2000),
206
+ sample_size=model_config["sample_size"],
207
+ sample_rate=model_config["sample_rate"],
208
+ demo_steps=demo_config.get("demo_steps", 250),
209
+ num_demos=demo_config["num_demos"],
210
+ demo_cfg_scales=demo_config["demo_cfg_scales"],
211
+ demo_conditioning=demo_config.get("demo_cond", {}),
212
+ demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False),
213
+ display_audio_cond=demo_config.get("display_audio_cond", False),
214
+ )
215
+ elif model_type == "diffusion_cond_inpaint":
216
+ from .diffusion import DiffusionCondInpaintDemoCallback
217
+
218
+ return DiffusionCondInpaintDemoCallback(
219
+ demo_every=demo_config.get("demo_every", 2000),
220
+ sample_size=model_config["sample_size"],
221
+ sample_rate=model_config["sample_rate"],
222
+ demo_steps=demo_config.get("demo_steps", 250),
223
+ demo_cfg_scales=demo_config["demo_cfg_scales"],
224
+ **kwargs
225
+ )
226
+
227
+ elif model_type == "lm":
228
+ from .lm import AudioLanguageModelDemoCallback
229
+
230
+ return AudioLanguageModelDemoCallback(
231
+ demo_every=demo_config.get("demo_every", 2000),
232
+ sample_size=model_config["sample_size"],
233
+ sample_rate=model_config["sample_rate"],
234
+ demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]),
235
+ demo_conditioning=demo_config.get("demo_cond", None),
236
+ num_demos=demo_config.get("num_demos", 8),
237
+ **kwargs
238
+ )
239
+ else:
240
+ raise NotImplementedError(f'Unknown model type: {model_type}')
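For orientation, a hedged sketch of how the two factory entry points above are typically wired together. The config filename is hypothetical, and create_demo_callback_from_config is only referenced in a comment because its extra keyword arguments (such as a demo dataloader) depend on the model type.

import json

from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.training.factory import create_training_wrapper_from_config

# Hypothetical filename: any model config JSON with "model_type", "sample_size",
# "sample_rate" and a "training" block, as read by the branches above.
with open("model_config.json") as f:
    model_config = json.load(f)

# Build the model, then wrap it in the matching pl.LightningModule.
model = create_model_from_config(model_config)
training_wrapper = create_training_wrapper_from_config(model_config, model)

# create_demo_callback_from_config(model_config, demo_dl=...) builds the matching
# demo callback; callback-specific kwargs are only forwarded where the branch accepts them.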
stable_audio_tools/training/lm.py ADDED
@@ -0,0 +1,267 @@
1
+ import pytorch_lightning as pl
2
+ import sys, gc
3
+ import random
4
+ import torch
5
+ import torchaudio
6
+ import typing as tp
7
+ import wandb
8
+
9
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
10
+ from ema_pytorch import EMA
11
+ from einops import rearrange
12
+ from safetensors.torch import save_file
13
+ from torch import optim
14
+ from torch.nn import functional as F
15
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
16
+
17
+ from ..models.lm import AudioLanguageModelWrapper
18
+ from .utils import create_optimizer_from_config, create_scheduler_from_config
19
+
20
+ class AudioLanguageModelTrainingWrapper(pl.LightningModule):
21
+ def __init__(
22
+ self,
23
+ model: AudioLanguageModelWrapper,
24
+ lr = 1e-4,
25
+ use_ema=False,
26
+ ema_copy=None,
27
+ optimizer_configs: dict = None,
28
+ pre_encoded=False
29
+ ):
30
+ super().__init__()
31
+
32
+ self.model = model
33
+
34
+ self.model.pretransform.requires_grad_(False)
35
+
36
+ self.model_ema = None
37
+ if use_ema:
38
+ self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)
39
+
40
+ assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
41
+
42
+ if optimizer_configs is None:
43
+ optimizer_configs = {
44
+ "lm": {
45
+ "optimizer": {
46
+ "type": "AdamW",
47
+ "config": {
48
+ "lr": lr,
49
+ "betas": (0.9, 0.95),
50
+ "weight_decay": 0.1
51
+ }
52
+ }
53
+ }
54
+ }
55
+ else:
56
+ if lr is not None:
57
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
58
+
59
+ self.optimizer_configs = optimizer_configs
60
+
61
+ self.pre_encoded = pre_encoded
62
+
63
+ def configure_optimizers(self):
64
+ lm_opt_config = self.optimizer_configs['lm']
65
+ opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters())
66
+
67
+ if "scheduler" in lm_opt_config:
68
+ sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm)
69
+ sched_lm_config = {
70
+ "scheduler": sched_lm,
71
+ "interval": "step"
72
+ }
73
+ return [opt_lm], [sched_lm_config]
74
+
75
+ return [opt_lm]
76
+
77
+ # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license
78
+ # License can be found in LICENSES/LICENSE_META.txt
79
+
80
+ def _compute_cross_entropy(
81
+ self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
82
+ ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
83
+ """Compute cross entropy between multi-codebook targets and model's logits.
84
+ The cross entropy is computed per codebook to provide codebook-level cross entropy.
85
+ Valid timesteps for each of the codebook are pulled from the mask, where invalid
86
+ timesteps are set to 0.
87
+
88
+ Args:
89
+ logits (torch.Tensor): Model's logits of shape [B, K, T, card].
90
+ targets (torch.Tensor): Target codes, of shape [B, K, T].
91
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
92
+ Returns:
93
+ ce (torch.Tensor): Cross entropy averaged over the codebooks
94
+ ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
95
+ """
96
+ B, K, T = targets.shape
97
+ assert logits.shape[:-1] == targets.shape
98
+ assert mask.shape == targets.shape
99
+ ce = torch.zeros([], device=targets.device)
100
+ ce_per_codebook: tp.List[torch.Tensor] = []
101
+ for k in range(K):
102
+ logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
103
+ targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
104
+ mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
105
+ ce_targets = targets_k[mask_k]
106
+ ce_logits = logits_k[mask_k]
107
+ q_ce = F.cross_entropy(ce_logits, ce_targets)
108
+ ce += q_ce
109
+ ce_per_codebook.append(q_ce.detach())
110
+ # average cross entropy across codebooks
111
+ ce = ce / K
112
+ return ce, ce_per_codebook
113
+
114
+ def training_step(self, batch, batch_idx):
115
+ reals, metadata = batch
116
+
117
+ if reals.ndim == 4 and reals.shape[0] == 1:
118
+ reals = reals[0]
119
+
120
+ if not self.pre_encoded:
121
+ codes = self.model.pretransform.tokenize(reals)
122
+ else:
123
+ codes = reals
124
+
125
+ padding_masks = []
126
+ for md in metadata:
127
+ if md["padding_mask"].ndim == 1:
128
+ padding_masks.append(md["padding_mask"])
129
+ else:
130
+ padding_masks.append(md["padding_mask"][0])
131
+
132
+ padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length)
133
+
134
+ # Interpolate padding masks to the same length as the codes
135
+ padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool()
136
+
137
+ condition_tensors = None
138
+
139
+ # If the model is conditioned, get the conditioning tensors
140
+ if self.model.conditioner is not None:
141
+ condition_tensors = self.model.conditioner(metadata, self.device)
142
+
143
+ lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)
144
+
145
+ logits = lm_output.logits # [b, k, t, c]
146
+ logits_mask = lm_output.mask # [b, k, t]
147
+
148
+ logits_mask = logits_mask & padding_masks
149
+
150
+ cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask)
151
+
152
+ loss = cross_entropy
153
+
154
+ log_dict = {
155
+ 'train/loss': loss.detach(),
156
+ 'train/cross_entropy': cross_entropy.detach(),
157
+ 'train/perplexity': torch.exp(cross_entropy).detach(),
158
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
159
+ }
160
+
161
+ for k, ce_q in enumerate(cross_entropy_per_codebook):
162
+ log_dict[f'cross_entropy_q{k + 1}'] = ce_q
163
+ log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q)
164
+
165
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
166
+ return loss
167
+
168
+ def on_before_zero_grad(self, *args, **kwargs):
169
+ if self.model_ema is not None:
170
+ self.model_ema.update()
171
+
172
+ def export_model(self, path, use_safetensors=False):
173
+
174
+ model = self.model_ema.ema_model if self.model_ema is not None else self.model
175
+
176
+ if use_safetensors:
177
+ save_file(model.state_dict(), path)
178
+ else:
179
+ torch.save({"state_dict": model.state_dict()}, path)
180
+
181
+
182
+ class AudioLanguageModelDemoCallback(pl.Callback):
183
+ def __init__(self,
184
+ demo_every=2000,
185
+ num_demos=8,
186
+ sample_size=65536,
187
+ sample_rate=48000,
188
+ demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
189
+ demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
190
+ **kwargs
191
+ ):
192
+ super().__init__()
193
+
194
+ self.demo_every = demo_every
195
+ self.num_demos = num_demos
196
+ self.demo_samples = sample_size
197
+ self.sample_rate = sample_rate
198
+ self.last_demo_step = -1
199
+ self.demo_conditioning = demo_conditioning
200
+ self.demo_cfg_scales = demo_cfg_scales
201
+
202
+ @rank_zero_only
203
+ @torch.no_grad()
204
+ def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx):
205
+
206
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
207
+ return
208
+
209
+ module.eval()
210
+
211
+ print(f"Generating demo")
212
+ self.last_demo_step = trainer.global_step
213
+
214
+ demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio
215
+
216
+ #demo_reals = batch[0][:self.num_demos]
217
+
218
+ # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
219
+ # demo_reals = demo_reals[0]
220
+
221
+ #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals)
222
+
223
+ ##Limit to first 50 tokens
224
+ #demo_reals_tokens = demo_reals_tokens[:, :, :50]
225
+
226
+ try:
227
+ print("Getting conditioning")
228
+
229
+ for cfg_scale in self.demo_cfg_scales:
230
+
231
+ model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model
232
+
233
+ print(f"Generating demo for cfg scale {cfg_scale}")
234
+ fakes = model.generate_audio(
235
+ batch_size=self.num_demos,
236
+ max_gen_len=demo_length_tokens,
237
+ conditioning=self.demo_conditioning,
238
+ #init_data = demo_reals_tokens,
239
+ cfg_scale=cfg_scale,
240
+ temp=1.0,
241
+ top_p=0.95
242
+ )
243
+
244
+ # Put the demos together
245
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
246
+
247
+ log_dict = {}
248
+
249
+ filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
250
+ fakes = fakes / fakes.abs().max()
251
+ fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu()
252
+ torchaudio.save(filename, fakes, self.sample_rate)
253
+
254
+ log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
255
+ sample_rate=self.sample_rate,
256
+ caption=f'Reconstructed')
257
+
258
+ log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
259
+
260
+ trainer.logger.experiment.log(log_dict)
261
+
262
+ except Exception as e:
263
+ raise e
264
+ finally:
265
+ gc.collect()
266
+ torch.cuda.empty_cache()
267
+ module.train()
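The _compute_cross_entropy helper above averages a standard token-level cross entropy over the K codebooks, masking out invalid (padded) positions per codebook before each term is computed. Below is a small self-contained sketch of the same computation on random tensors; the shapes are hypothetical and not tied to any particular tokenizer.

import torch
import torch.nn.functional as F

B, K, T, card = 2, 4, 16, 1024          # batch, codebooks, timesteps, vocabulary size

logits = torch.randn(B, K, T, card)     # model logits, as in lm_output.logits
targets = torch.randint(0, card, (B, K, T))
mask = torch.rand(B, K, T) > 0.1        # True where the target token is valid

ce = torch.zeros([])
ce_per_codebook = []
for k in range(K):
    logits_k = logits[:, k].reshape(-1, card)
    targets_k = targets[:, k].reshape(-1)
    mask_k = mask[:, k].reshape(-1)
    # Only valid positions contribute to the per-codebook cross entropy
    q_ce = F.cross_entropy(logits_k[mask_k], targets_k[mask_k])
    ce = ce + q_ce
    ce_per_codebook.append(q_ce.detach())

ce = ce / K                              # codebook-averaged loss; exp(ce) is the logged perplexity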